Skip to content

Commit 975446e

Browse files
authored
[BACKEND] Fix convert to triton gpu for CF ops (#6909)
Functions with unstructured control flow passing tensors were not working as the function conversion pattern was converting all the blocks and creating illegal IR.
1 parent 1cbcf9f commit 975446e

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,15 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
101101
return true;
102102
return false;
103103
});
104+
addDynamicallyLegalOp<triton::FuncOp>([](triton::FuncOp funcOp) -> bool {
105+
for (auto arg : funcOp.getArguments()) {
106+
if (auto tensor = dyn_cast<RankedTensorType>(arg.getType())) {
107+
if (!tensor.getEncoding())
108+
return false;
109+
}
110+
}
111+
return true;
112+
});
104113
}
105114

106115
bool TritonGPUConversionTarget::isDynamicallyLegal(

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,14 +481,17 @@ class TritonFuncOpPattern : public OpConversionPattern<triton::FuncOp> {
481481
matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
482482
ConversionPatternRewriter &rewriter) const override {
483483
auto converter = getTypeConverter();
484+
TypeConverter::SignatureConversion result(op.getNumArguments());
484485
auto newOp = rewriter.replaceOpWithNewOp<triton::FuncOp>(
485486
op, op.getName(), op.getFunctionType());
486487
addNamedAttrs(newOp, adaptor.getAttributes());
487488
rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
488489
newOp.getBody().end());
489-
if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter)))
490-
return failure();
491-
490+
// Convert just the entry block. The remaining unstructured control flow is
491+
// converted by br patterns.
492+
if (!newOp.getBody().empty())
493+
rewriter.applySignatureConversion(&newOp.getBody().front(), result,
494+
converter);
492495
return success();
493496
}
494497
};

test/Conversion/triton_to_tritongpu.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,16 @@ tt.func @ub_poison() {
155155
%0 = ub.poison : tensor<128x64xf16>
156156
tt.return
157157
}
158+
159+
// -----
160+
161+
// CHECK-LABEL: @cf_br
162+
tt.func @cf_br(%ptr: !tt.ptr<i32>) {
163+
%cst = arith.constant dense<1> : tensor<128xi32>
164+
// cf.br ^bb1(%{{.+}} : tensor<128xi32, #{{.+}}>)
165+
cf.br ^bb1(%cst : tensor<128xi32>)
166+
^bb1(%arg0: tensor<128xi32>):
167+
%ptrs = tt.splat %ptr : !tt.ptr<i32> -> tensor<128x!tt.ptr<i32>>
168+
tt.store %ptrs, %arg0 : tensor<128x!tt.ptr<i32>>
169+
tt.return
170+
}

0 commit comments

Comments
 (0)