Skip to content

Commit c6eae40

Browse files
authored
[BACKEND] Clear out warp_specialize attr when peeling out outer loop (triton-lang#8073)
When we are peeling out the outer loop remove the warp specialization flag so that the specialized loop doesn't get warp specialized
1 parent b285609 commit c6eae40

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,9 @@ static LogicalResult speculateInnerLoopLength(scf::ForOp outerLoop,
10061006
newInnerLoop.replaceAllUsesWith(newInnerLoop.getInits());
10071007
newInnerLoop.erase();
10081008

1009+
// Clear up the warp specialization attributes for the specialized loop.
1010+
newLoop->removeAttr(kWarpSpecializeAttrName);
1011+
10091012
// Move the loop nest into the `else` branch.
10101013
outerLoop.replaceAllUsesWith(ifOp.getResults());
10111014
Block *block = b.createBlock(&ifOp.getElseRegion());

test/TritonGPU/fuse-nested-loops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ tt.func @fuse_attr_speculate(%lb: i32, %ub: i32) {
473473
// CHECK: scf.if [[IS_ZERO]]
474474
// CHECK-NEXT: scf.for %{{.*}} = [[LB]] to [[UB]] step %c1_i32
475475
// CHECK-NEXT: "prologue"
476-
// CHECK-NXET: }
476+
// CHECK-NXET: } {tt.flatten}
477477

478478
// CHECK: else
479479
// CHECK-COUNT-1: scf.for
@@ -487,7 +487,7 @@ tt.func @fuse_attr_speculate(%lb: i32, %ub: i32) {
487487
"body"(%i, %j) : (i32, i32) -> ()
488488
scf.yield
489489
}
490-
} {tt.flatten}
490+
} {tt.flatten, tt.warp_specialize}
491491
tt.return
492492
}
493493

0 commit comments

Comments
 (0)