
Commit 3200f5a

[WS] Recognize warp-specialized nested loops in AssignLatencies (#8451)
`AssignLatencies` has WS-specific logic that changes the latency annotation for `tc_gen5_mma` in attention. It looks for MMA ops whose immediate parent `scf.for` op is annotated with `tt.warp_specialize`, so it misses nested loops where the outer loop carries `tt.warp_specialize` while the inner loop is the one intended to be SWP-ed. This PR walks the enclosing loops instead, so the inner loop of nested-loop persistent attention is SWP-ed identically to non-persistent attention.
1 parent f5d04d5 commit 3200f5a
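
For reference, the loop nesting involved is easiest to see with the attention details stripped away. The sketch below is schematic only (hypothetical function name and a placeholder op standing in for the real inner-loop body); it assumes the same annotation placement as the test added in this commit, with `tt.warp_specialize` on the outer persistent loop and the inner loop being the one the software pipeliner (SWP) should handle.

func.func @nested_loop_shape(%num_tiles: i32, %N: i32) {
  %c0 = arith.constant 0 : i32
  %c1 = arith.constant 1 : i32
  %c128 = arith.constant 128 : i32
  // Outer persistent loop: carries the tt.warp_specialize annotation.
  scf.for %tile = %c0 to %num_tiles step %c1 : i32 {
    // Inner loop: contains the tc_gen5_mma and is the loop that should be
    // software-pipelined. Before this commit, AssignLatencies inspected only
    // this loop's own attributes, so the outer annotation went unnoticed.
    scf.for %kv = %c0 to %N step %c128 : i32 {
      "mma_and_softmax_body"() : () -> ()  // placeholder for the real inner-loop ops
    }
  } {tt.num_stages = 3 : i32, tt.warp_specialize}
  return
}

With the new `isWarpSpecialized` helper, the check walks from the inner loop outward through enclosing `scf.for` ops, so this shape is treated the same as a single warp-specialized loop.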

File tree

2 files changed, +71 -1 lines

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
test/TritonGPU/pipeline-assign-latencies.mlir

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 12 additions & 1 deletion
@@ -194,7 +194,7 @@ class AssignMMALatencies {
     // overlap. WS does not have this problem because the MMA is placed in
     // a different partition than the MMA, so we can correctly set the
     // latency.
-    if (forOp->hasAttr(kWarpSpecializeAttrName)) {
+    if (isWarpSpecialized(forOp)) {
       if (ttng::hasAccReadModifyWrite(mma, forOp))
         opLatency.erase(&op); // can't pipeline the MMA
       else
@@ -217,6 +217,17 @@ class AssignMMALatencies {
     }
     return false;
   }
+
+  bool isWarpSpecialized(scf::ForOp forOp) {
+    scf::ForOp current = forOp;
+    do {
+      if (current->hasAttr(kWarpSpecializeAttrName)) {
+        return true;
+      }
+      current = current->getParentOfType<scf::ForOp>();
+    } while (current);
+    return false;
+  };
 };
 
 // Discover operations that should become async and assign latencies to them

test/TritonGPU/pipeline-assign-latencies.mlir

Lines changed: 59 additions & 0 deletions
@@ -1088,3 +1088,62 @@ tt.func public @attention_forward(
 }
 
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>
+#smem = #ttg.shared_memory
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+tt.func public @attention_persistent_inner_loop_kernel(%desc_q: !tt.tensordesc<tensor<128x128xf16, #shared>>, %desc_q_0: i32, %desc_q_1: i32, %desc_q_2: i64, %desc_q_3: i64, %desc_k: !tt.tensordesc<tensor<128x128xf16, #shared>>, %desc_k_4: i32, %desc_k_5: i32, %desc_k_6: i64, %desc_k_7: i64, %desc_v: !tt.tensordesc<tensor<128x128xf16, #shared>>, %desc_v_8: i32, %desc_v_9: i32, %desc_v_10: i64, %desc_v_11: i64, %desc_acc: !tt.tensordesc<tensor<128x128xf16, #shared>>, %desc_acc_12: i32, %desc_acc_13: i32, %desc_acc_14: i64, %desc_acc_15: i64, %l_i_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %m_i_ptr: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %qk_scale: f32) attributes {noinline = false} {
+  %false = arith.constant false
+  %true = arith.constant true
+  %c1_i32 = arith.constant 1 : i32
+  %c0_i32 = arith.constant 0 : i32
+  %c128_i32 = arith.constant 128 : i32
+  %cst = arith.constant dense<1.000000e+00> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  %cst_16 = arith.constant dense<0xFF800000> : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
+  %prog_id = tt.get_program_id x : i32
+  %num_sm = tt.get_num_programs x : i32
+  %num_tiles = arith.divsi %M, %c128_i32 : i32
+  %tiles_per_sm = arith.divsi %num_tiles, %num_sm : i32
+  %tile_idx = scf.for %_ = %c0_i32 to %tiles_per_sm step %c1_i32 iter_args(%tile_idx_20 = %prog_id) -> (i32) : i32 {
+    %off_m = arith.muli %tile_idx_20, %c128_i32 : i32
+    %q = tt.descriptor_load %desc_q[%off_m, %c0_i32] : !tt.tensordesc<tensor<128x128xf16, #shared>> -> tensor<128x128xf16, #blocked2>
+    %q_21 = ttg.local_alloc %q : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+    %qk_22, %qk_23 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %acc, %acc_24 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
+    %acc_26:4 = scf.for %acc_30 = %c0_i32 to %N step %c128_i32 iter_args(%arg28 = %cst_16, %arg29 = %cst, %qk_31 = %qk_23, %acc_32 = %acc_24) -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.async.token, !ttg.async.token) : i32 {
+      // CHECK: tt.descriptor_load {{.*}} {tt.latency = 2 : i32}
+      %k = tt.descriptor_load %desc_k[%acc_30, %c0_i32] : !tt.tensordesc<tensor<128x128xf16, #shared>> -> tensor<128x128xf16, #blocked2>
+      %k_33 = ttg.local_alloc %k : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+      %k_34 = ttg.memdesc_trans %k_33 {order = array<i32: 1, 0>} : !ttg.memdesc<128x128xf16, #shared, #smem> -> !ttg.memdesc<128x128xf16, #shared1, #smem>
+      // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {tt.latency = 2 : i32, tt.self_latency = 1 : i32}
+      %qk_35 = ttng.tc_gen5_mma %q_21, %k_34, %qk_22[%qk_31], %false, %true : !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      %qk_36, %qk_37 = ttng.tmem_load %qk_22[%qk_35] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+
+      %acc_47, %p, %next_l_i, %row_max = "softmax_work"(%qk_36, %arg29, %arg28) : (tensor<128x128xf32, #blocked>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) -> (tensor<128x128xf32, #blocked>, tensor<128x128xf16, #blocked>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>)
+
+      %acc_48, %acc_49 = ttng.tmem_load %acc[%acc_32] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
+      %acc_50 = arith.mulf %acc_48, %acc_47 : tensor<128x128xf32, #blocked>
+      %p_53 = ttg.local_alloc %p : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+      %acc_54 = ttng.tmem_store %acc_50, %acc[%acc_49], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      // CHECK: tt.descriptor_load {{.*}} {tt.latency = 2 : i32}
+      %v = tt.descriptor_load %desc_v[%acc_30, %c0_i32] : !tt.tensordesc<tensor<128x128xf16, #shared>> -> tensor<128x128xf16, #blocked2>
+      %v_51 = ttg.local_alloc %v : (tensor<128x128xf16, #blocked2>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
+
+      // CHECK: ttng.tc_gen5_mma {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {tt.self_latency = 1 : i32}
+      %acc_55 = ttng.tc_gen5_mma %p_53, %v_51, %acc[%acc_54], %true, %true : !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+
+      scf.yield %row_max, %next_l_i, %qk_37, %acc_55 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, !ttg.async.token, !ttg.async.token
+    }
+    %tile_idx_29 = arith.addi %tile_idx_20, %num_sm : i32
+    scf.yield %tile_idx_29 : i32
+  } {tt.num_stages = 3 : i32, tt.warp_specialize}
+  tt.return
+}
+}
