Skip to content

Commit ac84d71

Browse files
authored
[AMD] Guard FoldTrueCmpI from tensors (#7281)
Check the result type of the `arith.cmpi` (it must be an integer or index scalar) to guard against cases where FoldTrueCmpI might otherwise fold tensors (e.g., masks).
1 parent d514243 commit ac84d71

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

test/TritonGPU/amd/amd-fold-true-cmpi.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
156156
// CHECK-NEXT: ttg.local_dealloc %[[VAL_24]] : !ttg.memdesc<1x32x128xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable>
157157
// CHECK-NEXT: tt.return %[[VAL_58]] : tensor<128x128xf32, #[[$ATTR_2]]>
158158
// CHECK-NEXT: }
159+
160+
// -----
161+
162+
module attributes {"ttg.num-warps" = 4 : i32} {
163+
tt.func @dontfoldtensor() -> tensor<128xi1> {
164+
%t0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
165+
%t1 = tt.make_range {end = 257 : i32, start = 129 : i32} : tensor<128xi32>
166+
%cmp = arith.cmpi sgt, %t1, %t0 : tensor<128xi32>
167+
tt.return %cmp: tensor<128xi1>
168+
}
169+
}
170+
171+
// CHECK-LABEL: tt.func @dontfoldtensor
172+
// CHECK-NOT: arith.constant dense<true>
173+
// CHECK: %[[VAL_0:.*]] = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
174+
// CHECK: %[[VAL_1:.*]] = tt.make_range {end = 257 : i32, start = 129 : i32} : tensor<128xi32>
175+
// CHECK: %[[VAL_2:.*]] = arith.cmpi sgt, %[[VAL_1]], %[[VAL_0]] : tensor<128xi32>
176+
// CHECK: tt.return %[[VAL_2]] : tensor<128xi1>
177+
// CHECK: }

third_party/amd/lib/Analysis/RangeAnalysis.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,8 @@ struct FoldTrueCmpIOp : OpRewritePattern<arith::CmpIOp> {
601601

602602
LogicalResult matchAndRewrite(arith::CmpIOp cmpOp,
603603
PatternRewriter &rewriter) const override {
604-
if (cmpIIsStaticallyTrue(*solver, cmpOp)) {
604+
if (llvm::isa<IntegerType, IndexType>(cmpOp.getType()) &&
605+
cmpIIsStaticallyTrue(*solver, cmpOp)) {
605606
if (failed(mlir::dataflow::maybeReplaceWithConstant(*solver, rewriter,
606607
cmpOp.getResult()))) {
607608
LDBG("failed to replace with constant op: " << cmpOp);

0 commit comments

Comments
 (0)