Skip to content

Commit e671c0f

Browse files
authored
[BACKEND] Move lowering of CF as the last step of conversion to LLVM (#7213)
This prevents problem with analysis picking up argument from blocks that have been removed.
1 parent 886a8fe commit e671c0f

File tree

4 files changed

+32
-9
lines changed

4 files changed

+32
-9
lines changed

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
8484
// Split a block after the call.
8585
Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator());
8686
rewriter.setInsertionPointToEnd(ifBlock);
87-
rewriter.create<cf::BranchOp>(loc, thenBlock);
87+
rewriter.create<LLVM::BrOp>(loc, thenBlock);
8888
rewriter.setInsertionPointToEnd(prevBlock);
89-
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
89+
rewriter.create<LLVM::CondBrOp>(loc, condition, ifBlock, thenBlock);
9090
rewriter.setInsertionPointToStart(thenBlock);
9191
}
9292

lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,12 @@ inline SmallVector<Value> applyCombineOp(Location loc,
9797
thenBlockArgs.push_back(undef);
9898
thenBlock->addArgument(ty, loc);
9999
}
100-
rewriter.create<cf::CondBranchOp>(loc, pred, &newCombine, combineArgs,
101-
thenBlock, thenBlockArgs);
100+
rewriter.create<LLVM::CondBrOp>(loc, pred, &newCombine, combineArgs,
101+
thenBlock, thenBlockArgs);
102102

103103
// Split a block after the call.
104104
rewriter.setInsertionPointToEnd(&newCombine);
105-
rewriter.replaceOpWithNewOp<cf::BranchOp>(returnOp, thenBlock, results);
105+
rewriter.replaceOpWithNewOp<LLVM::BrOp>(returnOp, results, thenBlock);
106106
rewriter.setInsertionPointToStart(thenBlock);
107107
return SmallVector<Value>(thenBlock->getArguments());
108108
}

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2397,3 +2397,18 @@ tt.func private @memdesc_reinterpret(%arg0: !ttg.memdesc<4x1024xi8, #shared0, #t
23972397
}
23982398

23992399
}
2400+
2401+
// -----
2402+
2403+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
2404+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
2405+
// CHECK-LABEL: load_br
2406+
tt.func @load_br(%arg0: tensor<16x4x!tt.ptr<i8>, #blocked>) {
2407+
// CHECK: llvm.br
2408+
cf.br ^bb1(%arg0 : tensor<16x4x!tt.ptr<i8>, #blocked>)
2409+
^bb1(%arg1: tensor<16x4x!tt.ptr<i8>, #blocked>):
2410+
// CHECK: ld.global.b8
2411+
%0 = tt.load %arg1 : tensor<16x4x!tt.ptr<i8>, #blocked>
2412+
tt.return
2413+
}
2414+
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ struct ConvertTritonGPUToLLVM
9595
RewritePatternSet funcPatterns(context);
9696
mlir::triton::populateFuncOpConversionPattern(
9797
typeConverter, funcPatterns, targetInfo, patternBenefitDefault);
98-
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
99-
funcPatterns);
10098
if (failed(
10199
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
102100
return signalPassFailure();
@@ -152,8 +150,6 @@ struct ConvertTritonGPUToLLVM
152150
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
153151
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
154152
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
155-
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
156-
patterns);
157153
mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
158154
mlir::triton::populateViewOpToLLVMPatterns(typeConverter, patterns,
159155
benefit);
@@ -173,6 +169,18 @@ struct ConvertTritonGPUToLLVM
173169
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
174170
return signalPassFailure();
175171

172+
// Lower CF ops separately to avoid breaking analysis.
173+
TritonLLVMFunctionConversionTarget cfTarget(*context);
174+
cfTarget.markUnknownOpDynamicallyLegal([&](Operation *op) {
175+
return op->getDialect() !=
176+
context->getLoadedDialect<cf::ControlFlowDialect>();
177+
});
178+
RewritePatternSet cfPatterns(context);
179+
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
180+
cfPatterns);
181+
if (failed(applyPartialConversion(mod, cfTarget, std::move(cfPatterns))))
182+
return signalPassFailure();
183+
176184
// Fold CTAId when there is only 1 CTA.
177185
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
178186
if (numCTAs == 1) {

0 commit comments

Comments
 (0)