Skip to content

Commit 988d388

Browse files
authored
[Pipeliner] Skip async_wait when there is no async_cp op (#6681)
minor cleanup
1 parent 4c4f871 commit 988d388

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -600,11 +600,13 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
600600

601601
createTMABarrierAndWait(forOp, asyncLoads, loadGroups, schedule);
602602

603+
bool hasAsyncLoads = false;
603604
for (auto [op, asyncLoad] : asyncLoads) {
604605
auto [insertIdx, extractIdx, phase, _] = loadGroups[asyncLoad.stageDiff];
605606
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
606607
createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
607608
schedule);
609+
hasAsyncLoads = true;
608610
} else if (auto loadOp = dyn_cast<tt::DescriptorLoadOp>(op)) {
609611
createTMAAsyncLoad(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
610612
asyncLoad.barrier, asyncLoad.waitOp, schedule);
@@ -628,10 +630,12 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
628630
// correct stages.
629631
scheduleDependencies(forOp, schedule);
630632

631-
// Insert sync point for any possibly outstanding loads after the loop. This
632-
// can happen as we speculatively execute loads in the loop.
633-
builder.setInsertionPointAfter(forOp);
634-
builder.create<ttg::AsyncWaitOp>(loc, ValueRange({}), 0);
633+
if (hasAsyncLoads) {
634+
// Insert sync point for any possibly outstanding loads after the loop. This
635+
// can happen as we speculatively execute loads in the loop.
636+
builder.setInsertionPointAfter(forOp);
637+
builder.create<ttg::AsyncWaitOp>(loc, ValueRange({}), 0);
638+
}
635639

636640
// Make sure all ops have attributes.
637641
for (Operation &op : forOp.getBody()->without_terminator()) {

test/TritonGPU/samples/descriptor-matmul-pipeline.mlir

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,19 +112,18 @@
112112
// CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_81]], %[[VAL_84]], %[[VAL_72]], %[[VAL_74]] : tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32
113113
// CHECK: }
114114
// CHECK: %[[VAL_90:.*]] = ttng.warp_group_dot_wait %[[VAL_91:.*]]#0 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_1]]>
115-
// CHECK: %[[VAL_92:.*]] = ttg.async_wait {num = 0 : i32}
116-
// CHECK: %[[VAL_93:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3>
115+
// CHECK: %[[VAL_92:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3>
116+
// CHECK: ttng.inval_barrier %[[VAL_92]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3>
117+
// CHECK: %[[VAL_93:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3>
117118
// CHECK: ttng.inval_barrier %[[VAL_93]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3>
118-
// CHECK: %[[VAL_94:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3>
119+
// CHECK: %[[VAL_94:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_7]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3>
119120
// CHECK: ttng.inval_barrier %[[VAL_94]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3>
120-
// CHECK: %[[VAL_95:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_7]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3>
121-
// CHECK: ttng.inval_barrier %[[VAL_95]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3>
122121
// CHECK: ttg.local_dealloc %[[VAL_45]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable>
123122
// CHECK: ttg.local_dealloc %[[VAL_44]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable>
124123
// CHECK: ttg.local_dealloc %[[VAL_43]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable>
125-
// CHECK: %[[VAL_96:.*]] = arith.truncf %[[VAL_90]] : tensor<128x256xf32, #[[$ATTR_1]]> to tensor<128x256xf16, #[[$ATTR_1]]>
126-
// CHECK: %[[VAL_97:.*]] = ttg.convert_layout %[[VAL_96]] : tensor<128x256xf16, #[[$ATTR_1]]> -> tensor<128x256xf16, #[[$ATTR_0]]>
127-
// CHECK: tt.descriptor_store %[[VAL_38]]{{\[}}%[[VAL_39]], %[[VAL_40]]], %[[VAL_97]] : !tt.tensordesc<tensor<128x256xf16, #[[$ATTR_2]]>>, tensor<128x256xf16, #[[$ATTR_0]]>
124+
// CHECK: %[[VAL_95:.*]] = arith.truncf %[[VAL_90]] : tensor<128x256xf32, #[[$ATTR_1]]> to tensor<128x256xf16, #[[$ATTR_1]]>
125+
// CHECK: %[[VAL_96:.*]] = ttg.convert_layout %[[VAL_95]] : tensor<128x256xf16, #[[$ATTR_1]]> -> tensor<128x256xf16, #[[$ATTR_0]]>
126+
// CHECK: tt.descriptor_store %[[VAL_38]]{{\[}}%[[VAL_39]], %[[VAL_40]]], %[[VAL_96]] : !tt.tensordesc<tensor<128x256xf16, #[[$ATTR_2]]>>, tensor<128x256xf16, #[[$ATTR_0]]>
128127
// CHECK: tt.return
129128
// CHECK: }
130129
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {

test/TritonGPU/samples/simulated-grouped-gemm.mlir

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,13 +257,12 @@
257257
// CHECK: ttng.async_tma_store_wait {pendings = 0 : i32}
258258
// CHECK: ttg.local_dealloc %[[VAL_116]] : !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable>
259259
// CHECK: %[[VAL_213:.*]] = ttng.warp_group_dot_wait %[[VAL_214:.*]]#8 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>
260-
// CHECK: %[[VAL_215:.*]] = ttg.async_wait {num = 0 : i32}
261-
// CHECK: %[[VAL_216:.*]] = ttg.memdesc_subview %[[VAL_48]]{{\[}}%[[VAL_13]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3>
260+
// CHECK: %[[VAL_215:.*]] = ttg.memdesc_subview %[[VAL_48]]{{\[}}%[[VAL_13]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3>
261+
// CHECK: ttng.inval_barrier %[[VAL_215]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3>
262+
// CHECK: %[[VAL_216:.*]] = ttg.memdesc_subview %[[VAL_48]]{{\[}}%[[VAL_10]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3>
262263
// CHECK: ttng.inval_barrier %[[VAL_216]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3>
263-
// CHECK: %[[VAL_217:.*]] = ttg.memdesc_subview %[[VAL_48]]{{\[}}%[[VAL_10]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3>
264+
// CHECK: %[[VAL_217:.*]] = ttg.memdesc_subview %[[VAL_48]]{{\[}}%[[VAL_8]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3>
264265
// CHECK: ttng.inval_barrier %[[VAL_217]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3>
265-
// CHECK: %[[VAL_218:.*]] = ttg.memdesc_subview %[[VAL_48]]{{\[}}%[[VAL_8]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3>
266-
// CHECK: ttng.inval_barrier %[[VAL_218]] : !ttg.memdesc<1xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable, 3>
267266
// CHECK: ttg.local_dealloc %[[VAL_48]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #[[$ATTR_4]], mutable>
268267
// CHECK: ttg.local_dealloc %[[VAL_47]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable>
269268
// CHECK: ttg.local_dealloc %[[VAL_46]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #[[$ATTR_4]], mutable>

0 commit comments

Comments
 (0)