@@ -12,7 +12,8 @@ namespace mlir::triton::gpu {
1212
1313namespace {
1414
15- template <typename T> bool hasEncoding (Value value) {
15+ template <typename T>
16+ bool hasEncoding (Value value) {
1617 auto type = value.getType ();
1718 if (auto tensorType = dyn_cast<TensorOrMemDesc>(type)) {
1819 auto encoding = tensorType.getEncoding ();
@@ -25,7 +26,7 @@ bool hasDotOperandEncoding(Value value) {
2526 return hasEncoding<triton::gpu::DotOperandEncodingAttr>(value);
2627}
2728
28- } // namespace
29+ } // namespace
2930
3031// ===----------------------------------------------------------------------===//
3132// Canonicalizer
@@ -36,16 +37,13 @@ struct CanonicalizeConvertFromReshape
3637 : public mlir::OpRewritePattern<triton::ReshapeOp> {
3738 using OpRewritePattern::OpRewritePattern;
3839
39- mlir::LogicalResult
40- matchAndRewrite (triton::ReshapeOp op,
41- PatternRewriter &rewriter) const override {
40+ mlir::LogicalResult matchAndRewrite (
41+ triton::ReshapeOp op, PatternRewriter &rewriter) const override {
4242 auto convert = op.getSrc ().getDefiningOp <ConvertLayoutOp>();
43- if (!convert)
44- return failure ();
43+ if (!convert) return failure ();
4544 if (isExpensiveView (convert.getSrc ().getType (), op.getType ()))
4645 return failure ();
47- if (!op.getAllowReorder () || op.getEfficientLayout ())
48- return failure ();
46+ if (!op.getAllowReorder () || op.getEfficientLayout ()) return failure ();
4947
5048 rewriter.replaceOpWithNewOp <triton::ReshapeOp>(
5149 op, op.getType (), convert.getSrc (), op.getAllowReorder ());
@@ -58,12 +56,10 @@ struct CanonicalizeConvertFromHistogram
5856 : public mlir::OpRewritePattern<triton::HistogramOp> {
5957 using OpRewritePattern::OpRewritePattern;
6058
61- mlir::LogicalResult
62- matchAndRewrite (triton::HistogramOp op,
63- PatternRewriter &rewriter) const override {
59+ mlir::LogicalResult matchAndRewrite (
60+ triton::HistogramOp op, PatternRewriter &rewriter) const override {
6461 auto convert = op.getSrc ().getDefiningOp <ConvertLayoutOp>();
65- if (!convert)
66- return failure ();
62+ if (!convert) return failure ();
6763 rewriter.replaceOpWithNewOp <triton::HistogramOp>(
6864 op, op->getResult (0 ).getType (), convert.getSrc ());
6965 return mlir::success ();
@@ -79,15 +75,13 @@ struct CanonicalizeConvertFromHistogram
7975struct CanonicalizeConvertFromGatherSource : public OpRewritePattern <GatherOp> {
8076 using OpRewritePattern::OpRewritePattern;
8177
82- mlir::LogicalResult
83- matchAndRewrite ( GatherOp op, PatternRewriter &rewriter) const override {
78+ mlir::LogicalResult matchAndRewrite (
79+ GatherOp op, PatternRewriter &rewriter) const override {
8480 // Don't do this if the compiler picked an optimized layout.
85- if (op.getEfficientLayout ())
86- return failure ();
81+ if (op.getEfficientLayout ()) return failure ();
8782
8883 auto convert = op.getSrc ().getDefiningOp <ConvertLayoutOp>();
89- if (!convert)
90- return failure ();
84+ if (!convert) return failure ();
9185
9286 rewriter.replaceOpWithNewOp <GatherOp>(op, convert.getSrc (), op.getIndices (),
9387 op.getAxis ());
@@ -100,13 +94,15 @@ struct CanonicalizeConvertFromAlloc
10094 : public mlir::OpRewritePattern<triton::gpu::LocalAllocOp> {
10195 using OpRewritePattern::OpRewritePattern;
10296
103- mlir::LogicalResult
104- matchAndRewrite (triton::gpu::LocalAllocOp op,
105- PatternRewriter &rewriter) const override {
106- if (!op.getSrc ())
107- return failure ();
97+ mlir::LogicalResult matchAndRewrite (
98+ triton::gpu::LocalAllocOp op, PatternRewriter &rewriter) const override {
99+ if (!op.getSrc ()) return failure ();
108100 auto convert = op.getSrc ().getDefiningOp <ConvertLayoutOp>();
109- if (!convert)
101+ if (!convert) return failure ();
102+ // LocalAllocOp lowering doesn't support going from DotOperandEncoding
103+ // to SharedEncoding, so we want to keep this layout conversion.
104+ if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
105+ convert.getSrc ().getType ().getEncoding ()))
110106 return failure ();
111107 rewriter.replaceOpWithNewOp <triton::gpu::LocalAllocOp>(
112108 op, op->getResult (0 ).getType (), convert.getSrc ());
@@ -119,12 +115,10 @@ struct CanonicalizeConvertFromLocalStore
119115 : public mlir::OpRewritePattern<triton::gpu::LocalStoreOp> {
120116 using OpRewritePattern::OpRewritePattern;
121117
122- mlir::LogicalResult
123- matchAndRewrite (triton::gpu::LocalStoreOp op,
124- PatternRewriter &rewriter) const override {
118+ mlir::LogicalResult matchAndRewrite (
119+ triton::gpu::LocalStoreOp op, PatternRewriter &rewriter) const override {
125120 auto convert = op.getSrc ().getDefiningOp <ConvertLayoutOp>();
126- if (!convert)
127- return failure ();
121+ if (!convert) return failure ();
128122 rewriter.replaceOpWithNewOp <triton::gpu::LocalStoreOp>(op, convert.getSrc (),
129123 op.getDst ());
130124 return mlir::success ();
@@ -135,19 +129,16 @@ struct CanonicalizeConvertFromSplit
135129 : public mlir::OpRewritePattern<triton::SplitOp> {
136130 using OpRewritePattern::OpRewritePattern;
137131
138- mlir::LogicalResult
139- matchAndRewrite (triton::SplitOp op,
140- PatternRewriter &rewriter) const override {
132+ mlir::LogicalResult matchAndRewrite (
133+ triton::SplitOp op, PatternRewriter &rewriter) const override {
141134 auto convert = op.getSrc ().getDefiningOp <ConvertLayoutOp>();
142- if (!convert)
143- return failure ();
135+ if (!convert) return failure ();
144136 auto srcEncoding = convert.getSrc ().getType ().getEncoding ();
145137 // Multiple source layout can give the same output layout, if the source
146138 // layout of the convert gives the same destination layout we can skip the
147139 // convert.
148140 auto dstEncoding = inferDstEncoding (op, srcEncoding);
149- if (dstEncoding != op.getOutLHS ().getType ().getEncoding ())
150- return failure ();
141+ if (dstEncoding != op.getOutLHS ().getType ().getEncoding ()) return failure ();
151142 rewriter.replaceOpWithNewOp <triton::SplitOp>(op, convert.getSrc ());
152143 return mlir::success ();
153144 }
@@ -157,9 +148,8 @@ struct CanonicalizeConvertFromConvert
157148 : public OpRewritePattern<ConvertLayoutOp> {
158149 using OpRewritePattern::OpRewritePattern;
159150
160- mlir::LogicalResult
161- matchAndRewrite (ConvertLayoutOp op,
162- PatternRewriter &rewriter) const override {
151+ mlir::LogicalResult matchAndRewrite (
152+ ConvertLayoutOp op, PatternRewriter &rewriter) const override {
163153 // Convert to the same layout is redundant.
164154 if (op->getResultTypes () == op->getOperandTypes ()) {
165155 rewriter.replaceOp (op, op->getOperands ());
@@ -170,22 +160,21 @@ struct CanonicalizeConvertFromConvert
170160 // heuristic to accommodate fused attention.
171161 auto srcType = op.getSrc ().getType ();
172162 auto dstType = op.getType ();
173- if (mlir::isa <DotOperandEncodingAttr>(dstType.getEncoding ()) &&
174- mlir::isa <NvidiaMmaEncodingAttr>(srcType.getEncoding ()))
163+ if (mlir::isa_and_nonnull <DotOperandEncodingAttr>(dstType.getEncoding ()) &&
164+ mlir::isa_and_nonnull <NvidiaMmaEncodingAttr>(srcType.getEncoding ()))
175165 return failure ();
176166
177167 // for hopper MMAv3
178- if (mlir::isa <SharedEncodingAttr>(dstType.getEncoding ()) &&
179- mlir::isa <NvidiaMmaEncodingAttr>(srcType.getEncoding ()) &&
168+ if (mlir::isa_and_nonnull <SharedEncodingAttr>(dstType.getEncoding ()) &&
169+ mlir::isa_and_nonnull <NvidiaMmaEncodingAttr>(srcType.getEncoding ()) &&
180170 llvm::any_of (op.getResult ().getUsers (), [](Operation *dot) {
181171 return dot->hasTrait <OpTrait::DotLike>();
182172 })) {
183173 return failure ();
184174 }
185175
186176 Operation *arg = op.getSrc ().getDefiningOp ();
187- if (!arg)
188- return failure ();
177+ if (!arg) return failure ();
189178
190179 // cvt(reshape) -> reshape
191180 if (auto reshape = dyn_cast<ReshapeOp>(arg)) {
@@ -233,8 +222,7 @@ struct CanonicalizeConvertFromConvert
233222
234223 // cvt(cat) -> cat
235224 if (auto cat = dyn_cast<CatOp>(arg)) {
236- if (isExpensiveCat (cat, op.getType ().getEncoding ()))
237- return failure ();
225+ if (isExpensiveCat (cat, op.getType ().getEncoding ())) return failure ();
238226
239227 rewriter.replaceOpWithNewOp <CatOp>(op, op->getResult (0 ).getType (),
240228 cat.getOperands ());
@@ -291,15 +279,14 @@ LogicalResult UpcastMXFPOp::verify() {
291279
292280 auto xTy = getSrc ().getType ();
293281 auto scaleTy = getScale ().getType ();
294- Builder b (getContext ());
295- if (xTy.getElementType () != b.getBF16Type () &&
296- xTy.getElementType () != b.getF16Type () &&
297- xTy.getElementType () != b.getI8Type ()) {
298- return emitOpError (
299- " element type of the first operand must be bf16/fp16 or i8" );
282+
283+ if (xTy.getElementType () != BFloat16Type::get (getContext ()) &&
284+ xTy.getElementType () != Float16Type::get (getContext ()) &&
285+ xTy.getElementType () != IntegerType::get (getContext (), 8 )) {
286+ return emitOpError (" element type of the first operand must be bf16 or i8" );
300287 }
301288
302- if (scaleTy.getElementType () != b. getI8Type ( )) {
289+ if (scaleTy.getElementType () != IntegerType::get ( getContext (), 8 )) {
303290 return emitOpError (" element type of the second operand must be uint8" );
304291 }
305292
@@ -373,14 +360,12 @@ LogicalResult UpcastMXFPOp::verify() {
373360 return success ();
374361}
375362
376- RankedTensorType
377- UpcastMXFPOp::deduceOutputType (TypedValue<RankedTensorType> inputTensor,
378- ScaleDotElemType inputElemType,
379- Type outputElemType) {
363+ RankedTensorType UpcastMXFPOp::deduceOutputType (
364+ TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType,
365+ Type outputElemType) {
380366 MLIRContext *ctx = inputTensor.getContext ();
381367 auto xTy = inputTensor.getType ();
382- if (inputElemType != ScaleDotElemType::E2M1)
383- return xTy;
368+ if (inputElemType != ScaleDotElemType::E2M1) return xTy;
384369
385370 auto xShape = xTy.getShape ();
386371 auto newShape = llvm::to_vector (xShape);
@@ -466,17 +451,13 @@ void LocalAllocOp::getEffects(
466451}
467452
// Folds this LocalAllocOp to the source operand of the LocalLoadOp that
// produced its input. The fold applies only when all of the following hold:
//   - the result type does not report mutable memory,
//   - a source operand is present and is defined by a LocalLoadOp,
//   - that LocalLoadOp's source has exactly the same type as this op's result.
// In that case the LocalLoadOp's source is returned; otherwise an empty
// OpFoldResult is returned and nothing is folded.
468453OpFoldResult LocalAllocOp::fold (FoldAdaptor adaptor) {
469- if (getType ().getMutableMemory ())
470- return {};
454+ if (getType ().getMutableMemory ()) return {};
471455 auto src = getSrc ();
472- if (!src)
473- return {};
456+ if (!src) return {};
474457 auto localLoadOp = src.getDefiningOp <LocalLoadOp>();
475- if (!localLoadOp)
476- return {};
458+ if (!localLoadOp) return {};
477459 auto loadSrc = localLoadOp.getSrc ();
478- if (loadSrc.getType () != getType ())
479- return {};
460+ if (loadSrc.getType () != getType ()) return {};
480461 return loadSrc;
481462}
482463
0 commit comments