Skip to content

Commit e0c59f8

Browse files
agron911meta-codesync[bot]
authored andcommitted
[Cherry-pick] [BACKEND] Clear out warp_specialize attr when peeling out outer loop (#8073) (#579)
Summary: Cherry-picked from upstream OAI repository. Original Commit: c6eae40 Original Author: Thomas Raoux Original Date: 2025-09-04 13:56:49 -0700 Original commit message: ``` [BACKEND] Clear out warp_specialize attr when peeling out outer loop (#8073) When we are peeling out the outer loop remove the warp specialization flag so that the specialized loop doesn't get warp specialized ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. Pull Request resolved: #579 Reviewed By: htyu Differential Revision: D86344798 Pulled By: agron911 fbshipit-source-id: 0c6a6f48bd829a13f4363c9294b4b7bf7cc54d7b
1 parent 4a6c0d3 commit e0c59f8

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,9 @@ static LogicalResult speculateInnerLoopLength(scf::ForOp outerLoop,
10061006
newInnerLoop.replaceAllUsesWith(newInnerLoop.getInits());
10071007
newInnerLoop.erase();
10081008

1009+
// Clear up the warp specialization attributes for the specialized loop.
1010+
newLoop->removeAttr(kWarpSpecializeAttrName);
1011+
10091012
// Move the loop nest into the `else` branch.
10101013
outerLoop.replaceAllUsesWith(ifOp.getResults());
10111014
Block *block = b.createBlock(&ifOp.getElseRegion());

test/TritonGPU/fuse-nested-loops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ tt.func @fuse_attr_speculate(%lb: i32, %ub: i32) {
473473
// CHECK: scf.if [[IS_ZERO]]
474474
// CHECK-NEXT: scf.for %{{.*}} = [[LB]] to [[UB]] step %c1_i32
475475
// CHECK-NEXT: "prologue"
476-
// CHECK-NXET: }
476+
// CHECK-NXET: } {tt.flatten}
477477

478478
// CHECK: else
479479
// CHECK-COUNT-1: scf.for
@@ -487,7 +487,7 @@ tt.func @fuse_attr_speculate(%lb: i32, %ub: i32) {
487487
"body"(%i, %j) : (i32, i32) -> ()
488488
scf.yield
489489
}
490-
} {tt.flatten}
490+
} {tt.flatten, tt.warp_specialize}
491491
tt.return
492492
}
493493

0 commit comments

Comments
 (0)