Commit aef945a

Merge commit 'b93eefd2b6108cebe58c79fdcb71421d542c23ab'
2 parents: ef4973b + b93eefd

File tree: 22 files changed, +461 -216 lines

.github/workflows/build-macos.yml

Lines changed: 7 additions & 0 deletions
@@ -106,6 +106,13 @@ jobs:
           source ~/.venv/bin/activate
           echo "PATH is '$PATH'"
           ccache --zero-stats
+          export PATH="/opt/homebrew/opt/llvm@19/bin:$PATH"
+          export CC="/opt/homebrew/opt/llvm@19/bin/clang"
+          export CXX="/opt/homebrew/opt/llvm@19/bin/clang++"
+          export CXXFLAGS="-stdlib=libc++"
+          export LDFLAGS="-L/opt/homebrew/opt/llvm@19/lib"
+          which clang++
+          clang++ --version
           make dev-install
       - name: CCache Stats
         run: ccache --print-stats

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 2 deletions
@@ -1051,15 +1051,17 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
   let arguments = (ins
     TT_Ptr:$base,
     Variadic<I32>:$shape,
-    Variadic<I64>:$strides
+    Variadic<I64>:$strides,
+    DefaultValuedAttr<TT_PaddingOptionAttr, "::mlir::triton::PaddingOption::PAD_ZERO">:$padding
   );

   let results = (outs TT_TensorDescType:$result);

   let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";

   let builders = [
-    OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape, "bool":$isSignedInteger)>
+    OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape, "bool":$isSignedInteger,
+               "triton::PaddingOption":$padding)>
   ];

   let extraClassDeclaration = [{

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 4 additions & 3 deletions
@@ -1019,8 +1019,8 @@ OpFoldResult AdvanceOp::fold(FoldAdaptor adaptor) {
 //-- MakeTensorDescOp --
 void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
                              Value base, ValueRange shape, ValueRange strides,
-                             ArrayRef<int32_t> blockShape,
-                             bool isSignedInteger) {
+                             ArrayRef<int32_t> blockShape, bool isSignedInteger,
+                             triton::PaddingOption padding) {
   auto ptrTy = dyn_cast<triton::PointerType>(base.getType());
   if (!ptrTy) {
     llvm::report_fatal_error("Expected pointer type");
@@ -1030,7 +1030,8 @@ void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
   auto blockTy = RankedTensorType::get(blockShape64, elemTy);
   auto descTy =
       TensorDescType::get(builder.getContext(), blockTy, isSignedInteger);
-  return build(builder, state, descTy, base, shape, strides);
+  auto paddingAttr = PaddingOptionAttr::get(builder.getContext(), padding);
+  return build(builder, state, descTy, base, shape, strides, paddingAttr);
 }

 // The following ops, including `call`, `func`, and `return` are copied and
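
For context, a minimal caller-side sketch of the extended builder. This is not part of the commit: the header path is assumed, and `loc`, `base`, `shape`, `strides`, and the block shape values are illustrative placeholders supplied by surrounding pass code.

#include "triton/Dialect/Triton/IR/Dialect.h" // assumed header location

// Sketch only: build a tensor descriptor that fills out-of-bounds elements
// with NaN instead of the default zero padding.
static mlir::Value buildNaNPaddedDesc(mlir::OpBuilder &builder,
                                      mlir::Location loc, mlir::Value base,
                                      mlir::ValueRange shape,
                                      mlir::ValueRange strides) {
  auto desc = builder.create<mlir::triton::MakeTensorDescOp>(
      loc, base, shape, strides,
      /*blockShape=*/llvm::ArrayRef<int32_t>{128, 64},
      /*isSignedInteger=*/false,
      /*padding=*/mlir::triton::PaddingOption::PAD_NAN);
  return desc.getResult();
}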

lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp

Lines changed: 42 additions & 20 deletions
@@ -59,18 +59,21 @@ struct Descriptor {
   Value base;
   ValueRange shape;
   ValueRange strides;
+  Value paddingOption;
 };

 Descriptor unpackDescriptor(TensorDescType type, ValueRange pack) {
   int rank = type.getBlockType().getRank();
-  assert(pack.size() == 1 + 2 * static_cast<size_t>(rank) &&
+  assert(pack.size() == 1 + 2 * static_cast<size_t>(rank) + 1 &&
          "Expected tensor descriptors to consist of a pointer, "
-         "followed by 'rank' shape values and 'rank' stride values.");
+         "followed by 'rank' shape values and 'rank' stride values, "
+         "followed by a padding option value.");

   Descriptor res;
   res.base = pack[0];
   res.shape = pack.slice(1, rank);
   res.strides = pack.slice(1 + rank, rank);
+  res.paddingOption = pack[1 + 2 * rank];
   return res;
 }

@@ -211,16 +214,30 @@ Value generateMask(OpBuilder &builder, const Location &loc,
 }

 Value generateOther(OpBuilder &builder, Location loc, Type scalarTy,
-                    ArrayRef<int64_t> blockShape) {
+                    ArrayRef<int64_t> blockShape,
+                    Value paddingOption = nullptr) {
   auto blockTy = RankedTensorType::get(blockShape, scalarTy);
-  auto attr = builder.getZeroAttr(blockTy);
-  return builder.create<arith::ConstantOp>(loc, attr);
+  if (paddingOption && mlir::isa<FloatType>(scalarTy)) {
+    auto floatTy = mlir::cast<FloatType>(scalarTy);
+    auto nan = llvm::APFloat::getNaN(floatTy.getFloatSemantics());
+    auto nanValue = builder.create<arith::ConstantOp>(
+        loc,
+        SplatElementsAttr::get(blockTy, builder.getFloatAttr(floatTy, nan)));
+    auto zeroValue = builder.create<arith::ConstantOp>(
+        loc, SplatElementsAttr::get(blockTy, builder.getZeroAttr(floatTy)));
+    return builder.create<mlir::arith::SelectOp>(loc, paddingOption, nanValue,
+                                                 zeroValue);
+  } else {
+    auto attr = builder.getZeroAttr(blockTy);
+    return builder.create<arith::ConstantOp>(loc, attr);
+  }
 }

-Value generateOther(OpBuilder &builder, Location loc, TensorDescType descTy) {
+Value generateOther(OpBuilder &builder, Location loc, TensorDescType descTy,
+                    Value paddingOption = nullptr) {
   auto blockTy = descTy.getSignlessBlockType();
   return generateOther(builder, loc, blockTy.getElementType(),
-                       blockTy.getShape());
+                       blockTy.getShape(), paddingOption);
 }

 SmallVector<mlir::Value> castToI64(OpBuilder &builder,
@@ -237,12 +254,17 @@ struct RewriteMakeTensorDesc : OpConversionPattern<triton::MakeTensorDescOp> {
   llvm::LogicalResult
   matchAndRewrite(triton::MakeTensorDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<mlir::Value> ptrShapeStrides;
-    llvm::append_values(ptrShapeStrides, adaptor.getBase());
-    llvm::append_range(ptrShapeStrides,
+    SmallVector<mlir::Value> ptrShapeStridesPaddingOption;
+    llvm::append_values(ptrShapeStridesPaddingOption, adaptor.getBase());
+    llvm::append_range(ptrShapeStridesPaddingOption,
                        castToI64(rewriter, adaptor.getShape()));
-    llvm::append_range(ptrShapeStrides, adaptor.getStrides());
-    rewriter.replaceOpWithMultiple(op, {ptrShapeStrides});
+    llvm::append_range(ptrShapeStridesPaddingOption, adaptor.getStrides());
+    auto paddingOption = rewriter.create<mlir::arith::ConstantOp>(
+        op.getLoc(), rewriter.getI1Type(),
+        rewriter.getBoolAttr(adaptor.getPadding() ==
+                             triton::PaddingOption::PAD_NAN));
+    llvm::append_values(ptrShapeStridesPaddingOption, paddingOption);
+    rewriter.replaceOpWithMultiple(op, {ptrShapeStridesPaddingOption});
     return mlir::success();
   }
 };
@@ -258,12 +280,11 @@ struct RewriteLoadPattern : OpConversionPattern<triton::DescriptorLoadOp> {
     auto descTy = op.getDesc().getType();
     auto desc = unpackDescriptor(descTy, adaptor.getDesc());
     auto offsets = castToI64(rewriter, op.getIndices());
-
+    auto other = generateOther(rewriter, loc, descTy, desc.paddingOption);
     auto newLoad = rewriter.replaceOpWithNewOp<triton::LoadOp>(
         op, generatePtr(rewriter, loc, blockShape, desc, offsets),
-        generateMask(rewriter, loc, blockShape, desc, offsets),
-        generateOther(rewriter, loc, descTy), triton::CacheModifier::NONE,
-        triton::EvictionPolicy::NORMAL, false);
+        generateMask(rewriter, loc, blockShape, desc, offsets), other,
+        triton::CacheModifier::NONE, triton::EvictionPolicy::NORMAL, false);
     newLoad->setAttrs(filterSegmentSizes(op->getAttrs()));

     return llvm::success();
@@ -327,7 +348,7 @@ struct RewriteGatherPattern : OpConversionPattern<triton::DescriptorGatherOp> {
         rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset());
     auto other = generateOther(rewriter, loc,
                                descTy.getSignlessBlockType().getElementType(),
-                               blockShape);
+                               blockShape, desc.paddingOption);
     auto newLoad = rewriter.replaceOpWithNewOp<triton::LoadOp>(
         op, ptr, mask, other, triton::CacheModifier::NONE,
         triton::EvictionPolicy::NORMAL, false);
@@ -471,13 +492,14 @@ class TritonRewriteTensorDescriptorToPointerPass
     converter.addConversion([](mlir::triton::TensorDescType t,
                                llvm::SmallVectorImpl<mlir::Type> &out) {
       // We convert a tensor descriptor into an pointer, and a shape and stride
-      // for each dimension, i.e., we create 1+2*rank values. Note that tensor
-      // descriptors may be signed/unsigned integers whereas pointers should
-      // always be signless.
+      // for each dimension, and padding option. i.e., we create 1+2*rank+1
+      // values. Note that tensor descriptors may be signed/unsigned integers
+      // whereas pointers should always be signless.
       auto tensorType = t.getSignlessBlockType();
       out.push_back(triton::getPointerType(tensorType.getElementType()));
       out.insert(out.end(), 2 * tensorType.getRank(),
                  mlir::IntegerType::get(t.getContext(), 64));
+      out.push_back(mlir::IntegerType::get(t.getContext(), 1));
       return mlir::success();
     });
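
After this change the flattened descriptor is laid out as [base pointer, rank shape values, rank stride values, one i1 padding flag], i.e. 1+2*rank+1 values. Below is a self-contained sketch of the same index arithmetic using plain integers as stand-ins for MLIR values; the struct and function names are illustrative, not the pass's actual types.

#include <cassert>
#include <cstddef>
#include <vector>

// Mirrors the unpacking convention above: one base value, `rank` shape
// values, `rank` stride values, then a single padding flag at the end.
struct FlatDescriptor {
  int base;
  std::vector<int> shape;
  std::vector<int> strides;
  int paddingFlag;
};

static FlatDescriptor unpackFlat(const std::vector<int> &pack, int rank) {
  assert(pack.size() == 1 + 2 * static_cast<size_t>(rank) + 1 &&
         "base + rank shapes + rank strides + padding flag");
  FlatDescriptor d;
  d.base = pack[0];
  d.shape.assign(pack.begin() + 1, pack.begin() + 1 + rank);
  d.strides.assign(pack.begin() + 1 + rank, pack.begin() + 1 + 2 * rank);
  d.paddingFlag = pack[1 + 2 * rank];
  return d;
}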

lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp

Lines changed: 3 additions & 1 deletion
@@ -304,6 +304,8 @@ LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op,
     return failure();
   }

+  auto fillMode = (op.getPadding() == triton::PaddingOption::PAD_NAN) ? 1 : 0;
+
   builder.create<TensormapCreateOp>(
       loc,
       /*desc_ptr=*/tmaPtr,
@@ -315,7 +317,7 @@ LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op,
       /*elem_type*/ builder.getI32IntegerAttr(*elemTypeEnum),
       /*interleave_layout*/ builder.getI32IntegerAttr(0),
       /*swizzle_mode=*/builder.getI32IntegerAttr(swizzleMode),
-      /*fill_mode=*/builder.getI32IntegerAttr(0));
+      /*fill_mode=*/builder.getI32IntegerAttr(fillMode));
   return success();
 }
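
For quick reference, the mapping introduced here is simply PAD_NAN -> fill_mode 1 and the default PAD_ZERO -> fill_mode 0. A hedged standalone sketch of that translation (the helper name and header path are illustrative, not part of the commit):

#include "triton/Dialect/Triton/IR/Dialect.h" // assumed header location

// Sketch: translate the descriptor's padding option into the fill_mode
// immediate passed to TensormapCreateOp (0 = zero fill, 1 = NaN fill).
static int32_t toTmaFillMode(mlir::triton::PaddingOption padding) {
  return padding == mlir::triton::PaddingOption::PAD_NAN ? 1 : 0;
}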

python/src/ir.cc

Lines changed: 65 additions & 56 deletions
@@ -64,6 +64,23 @@ llvm::raw_ostream &mlir_dumps_or_dbgs() {
   }
 }

+// Function to parse a comma-separated string into a vector of C-style strings
+llvm::SmallVector<const char *, 3>
+parseCommaSeparatedValues(const std::string &input,
+                          llvm::SmallVector<std::string, 3> &storage) {
+  llvm::SmallVector<StringRef, 3> split;
+  llvm::SmallVector<const char *, 3> result;
+  StringRef(input.c_str()).split(split, ',');
+  llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) {
+    // StringRefs are not always null-terminated.
+    // The purpose for this storage pattern is to
+    // produce a collection of C-strings that are.
+    storage.push_back(str.str());
+    return storage.back().c_str();
+  });
+  return result;
+}
+
 // Run the pass manager under a source manager diagnostic handler, which
 // enables emitted MLIR diagnostics to directly reference Python source
 // code. This diagnostic handler supports filtering diagnostic info by
@@ -100,6 +117,43 @@ struct TritonSourceMgrDiagnosticHandler : public SourceMgrDiagnosticHandler {
   llvm::SourceMgr sourceMgr;
 };

+TritonSourceMgrDiagnosticHandler
+setupTritonDiagnosticHandler(MLIRContext *context) {
+  bool showOperations = false, showStacktraces = false, showRemarks = false,
+       showWarnings = false;
+
+  if (auto enableDiagnostics =
+          triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS");
+      !enableDiagnostics.empty()) {
+    llvm::SmallVector<std::string, 3> storage;
+    parseCommaSeparatedValues(enableDiagnostics, storage);
+    for (auto &str : storage) {
+      if (str == "warnings") {
+        showWarnings = true;
+      } else if (str == "remarks") {
+        showRemarks = true;
+      } else if (str == "stacktraces") {
+        showStacktraces = true;
+      } else if (str == "operations") {
+        showOperations = true;
+      }
+      // we show errors by default, so no need to set it
+    }
+  }
+
+  DiagnosticSeverity minSeverity =
+      showWarnings ? DiagnosticSeverity::Warning : DiagnosticSeverity::Error;
+  minSeverity = showRemarks ? DiagnosticSeverity::Remark : minSeverity;
+
+  context->printOpOnDiagnostic(showOperations);
+  context->printStackTraceOnDiagnostic(showStacktraces);
+  if (showStacktraces) {
+    context->disableMultithreading();
+  }
+
+  return TritonSourceMgrDiagnosticHandler(context, minSeverity);
+}
+
 std::string locationToString(Location loc) {
   std::string str;
   llvm::raw_string_ostream os(str);
@@ -108,23 +162,6 @@ std::string locationToString(Location loc) {
   return str;
 }

-// Function to parse a comma-separated string into a vector of C-style strings
-llvm::SmallVector<const char *, 3>
-parseCommaSeparatedValues(const std::string &input,
-                          llvm::SmallVector<std::string, 3> &storage) {
-  llvm::SmallVector<StringRef, 3> split;
-  llvm::SmallVector<const char *, 3> result;
-  StringRef(input.c_str()).split(split, ',');
-  llvm::transform(split, std::back_inserter(result), [&storage](StringRef str) {
-    // StringRefs are not always null-terminated.
-    // The purpose for this storage pattern is to
-    // produce a collection of C-strings that are.
-    storage.push_back(str.str());
-    return storage.back().c_str();
-  });
-  return result;
-}
-
 void outputWarning(Location loc, const std::string &msg) {
   std::string locStr = locationToString(loc);

@@ -663,7 +700,12 @@ void init_triton_ir(py::module &&m) {
       .def("walk",
            [](ModuleOp &self, const std::function<void(Operation *)> &fn) {
              self.walk(fn);
-           });
+           })
+      .def("verify_with_diagnostics", [](ModuleOp &self) {
+        TritonSourceMgrDiagnosticHandler handler =
+            setupTritonDiagnosticHandler(self.getContext());
+        return succeeded(verify(self.getOperation()));
+      });

   m.def("make_attr", [](const std::vector<int> &values, MLIRContext &context) {
     return mlir::cast<Attribute>(DenseIntElementsAttr::get(
@@ -1762,9 +1804,10 @@ void init_triton_ir(py::module &&m) {
      .def("create_make_tensor_descriptor",
           [](TritonOpBuilder &self, Value &base, std::vector<Value> &shape,
              std::vector<Value> &strides, std::vector<int32_t> &tensorShape,
-             bool isSignedInteger) -> Value {
+             bool isSignedInteger, PaddingOption paddingOption) -> Value {
            return self.create<MakeTensorDescOp>(base, shape, strides,
-                                                tensorShape, isSignedInteger);
+                                                tensorShape, isSignedInteger,
+                                                paddingOption);
          });

   py::class_<PassManager>(m, "pass_manager", py::module_local())
@@ -1862,42 +1905,8 @@ void init_triton_ir(py::module &&m) {
            self.enableTiming();
          }

-          // setting up diagnostics
-          bool showOperations = false, showStacktraces = false,
-               showRemarks = false, showWarnings = false;
-
-          if (auto enableDiagnostics =
-                  triton::tools::getStrEnv("MLIR_ENABLE_DIAGNOSTICS");
-              !enableDiagnostics.empty()) {
-            llvm::SmallVector<std::string, 3> storage;
-            parseCommaSeparatedValues(enableDiagnostics, storage);
-            for (auto &str : storage) {
-              if (str == "warnings") {
-                showWarnings = true;
-              } else if (str == "remarks") {
-                showRemarks = true;
-              } else if (str == "stacktraces") {
-                showStacktraces = true;
-              } else if (str == "operations") {
-                showOperations = true;
-              }
-              // we show errors by default, so no need to set it
-            }
-          }
-
-          DiagnosticSeverity minSeverity = showWarnings
-                                               ? DiagnosticSeverity::Warning
-                                               : DiagnosticSeverity::Error;
-          minSeverity =
-              showRemarks ? DiagnosticSeverity::Remark : minSeverity;
-
-          TritonSourceMgrDiagnosticHandler diagHandler(context, minSeverity);
-
-          context->printOpOnDiagnostic(showOperations);
-          context->printStackTraceOnDiagnostic(showStacktraces);
-          if (showStacktraces) {
-            context->disableMultithreading();
-          }
+          TritonSourceMgrDiagnosticHandler diagHandler =
+              setupTritonDiagnosticHandler(context);
           if (failed(self.run(mod.getOperation())))
             throw std::runtime_error("PassManager::run failed");
         },
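
This refactor only moves the MLIR_ENABLE_DIAGNOSTICS parsing (accepted values per the code above: warnings, remarks, stacktraces, operations) into the reusable setupTritonDiagnosticHandler helper so the new verify_with_diagnostics binding can share it with pass_manager::run. For reference, a minimal standalone sketch of the split-with-owned-storage pattern used by parseCommaSeparatedValues; the main function and variable names here are illustrative only, not part of this commit.

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <cstdio>
#include <string>

int main() {
  // Split a comma-separated option string into C strings while keeping the
  // owning std::string copies alive in `storage`, mirroring the helper above.
  std::string input = "warnings,remarks,stacktraces";
  llvm::SmallVector<std::string, 3> storage;
  llvm::SmallVector<llvm::StringRef, 3> pieces;
  llvm::SmallVector<const char *, 3> flags;
  llvm::StringRef(input).split(pieces, ',');
  for (llvm::StringRef piece : pieces) {
    storage.push_back(piece.str());          // null-terminated owning copy
    flags.push_back(storage.back().c_str()); // valid while `storage` lives
  }
  for (const char *flag : flags)
    std::printf("%s\n", flag);
  return 0;
}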

python/test/gluon/test_frontend.py

Lines changed: 8 additions & 0 deletions
@@ -2251,3 +2251,11 @@ def test_infer_layout_for_padded_shared(target):
     }
     }
     """)
+
+
+@filecheck_test
+@gluon.jit
+def test_layout_zeros():
+    # CHECK: #blocked = #ttg.blocked
+    # CHECK: arith.constant dense<0.000000e+00> : tensor<128xf32, #blocked>
+    ttgl.zeros([128], ttgl.float32, layout=ttgl.BlockedLayout([1], [32], [4], [0]))
