Commit 6fcbac9

[Warp Specialization] Fix WAR async+generic proxy for warp spec (#7278)
@ThomasRaoux fixed this for synchronous code, but warp specialization needs to generate the fence itself when it splits the loop.
1 parent 187ea27 commit 6fcbac9

File tree

2 files changed: +6 -0 lines changed


lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 2 additions & 0 deletions
@@ -442,6 +442,8 @@ LogicalResult PipelinedLoadGroup::lowerLoads(WarpSchedule &schedule,
     StageCluster userStageCluster = getStageCluster(loadBeforeOp);
     Value loaded = b.createInto<LocalLoadOp>(*partition, userStageCluster,
                                              load.type, view);
+    b.createInto<ttng::FenceAsyncSharedOp>(*partition, userStageCluster,
+                                           /*bCluster=*/false);
     for (OpOperand *use : uses)
       use->set(loaded);
   }
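
The effect on the IR generated for the load user's partition is that the fence now sits between the generic-proxy local_load and the arrive_barrier that hands the shared buffer back to the async (TMA) producer. A minimal schematic of that ordering, based on the FileCheck patterns below; the barrier/buffer placeholders (<load-ready>, <load-empty>, <buf>) are illustrative, not names from the source, and this is not verbatim IR:

    wait_barrier <load-ready>              {ttg.partition = 0}  // async producer has finished writing the shared buffer
    local_load <buf>                       {ttg.partition = 0}  // generic-proxy read of that buffer
    fence_async_shared {bCluster = false}  {ttg.partition = 0}  // new: orders the generic read before the next async write (WAR)
    arrive_barrier <load-empty>            {ttg.partition = 0}  // releases the buffer for the async proxy to refill

The test updates below check this op ordering in each affected case.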

test/TritonGPU/load-mma-specialization.mlir

Lines changed: 4 additions & 0 deletions
@@ -1024,6 +1024,7 @@ tt.func @specialize_load_only(%desc: !tt.tensordesc<tensor<128x64xf16, #shared>>
   scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 {
     // CHECK: wait_barrier {{.*}} {ttg.partition = 0 : i32}
     // CHECK-NEXT: local_load {{.*}} {ttg.partition = 0 : i32}
+    // CHECK-NEXT: fence_async_shared {{.*}}partition = 0
     // CHECK-NEXT: arrive_barrier {{.*}} {ttg.partition = 0 : i32}
     %val = tt.descriptor_load %desc[%i, %i] : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #oper_layout>
     "use"(%val) : (tensor<128x64xf16, #oper_layout>) -> ()

@@ -1078,6 +1079,7 @@ tt.func @specialize_mma_only(%rhs_desc: !tt.tensordesc<tensor<64x128xf16, #share
     // CHECK-NEXT: [[LOADED:%.*]], %{{.*}} = ttng.tmem_load [[ACC_TMEM:%.*]][]
     // CHECK: wait_barrier
     // CHECK-NEXT: local_load
+    // CHECK-NEXT: fence_async_shared {{.*}}partition = 0
     // CHECK-NEXT: arrive_barrier
     // CHECK-NEXT: [[RESULTS:%.*]]:2 = "some_producer"
     %rhs_reg, %next_acc = "some_producer"(%loaded, %acc) : (tensor<64x128xf16, #oper_layout>, tensor<128x128xf32, #acc_layout>) -> (tensor<64x128xf16, #oper_layout>, tensor<128x128xf32, #acc_layout>)

@@ -1187,6 +1189,7 @@ tt.func @store_mma_load(

     // CHECK-NEXT: wait_barrier [[LOAD_READY_BAR]], {{.*}}partition = 0
     // CHECK-NEXT: [[LHS:%.*]] = ttg.local_load [[LOAD_BUF]] {ttg.partition = 0 : i32}
+    // CHECK-NEXT: fence_async_shared {{.*}}partition = 0
     // CHECK-NEXT: arrive_barrier [[LOAD_EMPTY_BAR]], {{.*}}partition = 0
     // CHECK-NEXT: [[LHS_OP:%.*]] = arith.addf [[LHS]], [[LHS]] {ttg.partition = 0 : i32}
     // CHECK-NEXT: local_store [[LHS_OP]], [[LHS_SHARED]] {ttg.partition = 0 : i32}

@@ -1234,6 +1237,7 @@ tt.func @local_alloc_into_mma(

     // CHECK: wait_barrier [[LOAD_READY_BAR]], {{.*}}partition = 0
     // CHECK-NEXT: [[RHS_REG:%.*]] = ttg.local_load {{.*}}partition = 0
+    // CHECK-NEXT: fence_async_shared {{.*}}partition = 0
     // CHECK-NEXT: arrive_barrier
     // CHECK-NEXT: [[RHS_REG_MOD:%.*]] = arith.addf [[RHS_REG]], [[RHS_REG]] {ttg.partition = 0 : i32}
     // CHECK-NEXT: wait_barrier [[MMA_OPER_BAR:%.*]], %arg{{.*}}partition = 0
