Skip to content

Commit 83a1f00

Browse files
authored
[Warp Spec] Always use 1 buffer for SSA partition dependencies (#7686)
There are realistic cases where this should be more than 1, but let's fix it to 1 for now since that's what happens in practice.
1 parent 39fbca8 commit 83a1f00

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct UseInfo {
4848
int UseInfo::getMaxUseDistance(const Partition &partition) {
4949
int maxDistance = 0;
5050
for (auto [usePartition, distance] : llvm::make_first_range(consumers)) {
51-
int dist = 2 + distance;
51+
int dist = 1 + distance;
5252
maxDistance = std::max(maxDistance, dist);
5353
}
5454
return maxDistance;

test/TritonGPU/rewrite-partition-dependencies.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ module attributes {"ttg.num-warps" = 4 : i32} {
1010
// CHECK-LABEL: @two_consumers
1111
tt.func @two_consumers(%lb: i32, %ub: i32, %step: i32) {
1212
// CHECK: [[C0:%.*]] = arith.constant 0 : i32
13-
// CHECK-NEXT: [[ABUF:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x1xi32, {{.*}}>
13+
// CHECK-NEXT: [[ABUF:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x1xi32, {{.*}}>
1414
// CHECK-NEXT: [[AREF:%.*]] = nvws.aref.create [[ABUF]]
1515
scf.for %i = %lb to %ub step %step iter_args() -> () : i32 {
1616
%0 = "op_a"() {ttg.partition = 0} : () -> !ty
@@ -40,7 +40,7 @@ tt.func @two_consumers(%lb: i32, %ub: i32, %step: i32) {
4040
// CHECK-LABEL: @distance_one
4141
tt.func @distance_one(%lb: i32, %ub: i32, %step: i32) {
4242
// CHECK: [[C0:%.*]] = arith.constant 0 : i32
43-
// CHECK: [[ABUF:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x1xi32, {{.*}}>
43+
// CHECK: [[ABUF:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x1xi32, {{.*}}>
4444
// CHECK-NEXT: [[AREF:%.*]] = nvws.aref.create [[ABUF]]
4545
%cst = arith.constant dense<0> : !ty
4646
// CHECK: scf.for [[IV:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[K:%.*]] = {{.*}})
@@ -63,9 +63,9 @@ tt.func @distance_one(%lb: i32, %ub: i32, %step: i32) {
6363
}
6464

6565
tt.func @complex_case(%lb: i32, %ub: i32, %step: i32) {
66-
// CHECK: [[ABUF1:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x1xi32, {{.*}}>
66+
// CHECK: [[ABUF1:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x1xi32, {{.*}}>
6767
// CHECK-NEXT: [[AREF1:%.*]] = nvws.aref.create [[ABUF1]]
68-
// CHECK-NEXT: [[ABUF2:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x1xi32, {{.*}}>
68+
// CHECK-NEXT: [[ABUF2:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x1xi32, {{.*}}>
6969
// CHECK-NEXT: [[AREF2:%.*]] = nvws.aref.create [[ABUF2]]
7070
%cst = arith.constant dense<0> : !ty
7171
// CHECK: scf.for [[IV:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[K:%.*]] = {{.*}}, [[L:%.*]] = {{.*}})

0 commit comments

Comments
 (0)