Commit 7135f4f

[Pipeliner] Fix assign latencies when local_alloc src is a block argument (#8628)
1 parent e89e9b2 commit 7135f4f

2 files changed: +38 −2 lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/MMAv5PipelineUtility.cpp

Lines changed: 2 additions & 2 deletions
@@ -69,8 +69,8 @@ bool ttng::isOperandPipelineableBase(
     return true;
   }
   auto localAllocSrc = localAlloc.getSrc().getDefiningOp();
-  if (!isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(
-          localAllocSrc)) {
+  if (!isa_and_nonnull<tt::LoadOp, tt::DescriptorLoadOp,
+                       tt::DescriptorGatherOp>(localAllocSrc)) {
     return false;
   }
   foundDef = localAllocSrc;
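
The change matters because mlir::Value::getDefiningOp() returns nullptr when the value is a block argument (such as an scf.for iter_arg), and a plain llvm::isa<> asserts on a null pointer, whereas isa_and_nonnull<> returns false, so the operand is simply treated as not pipelineable. A minimal sketch of that pattern, not taken from the patch; the Triton dialect header path and the helper name isLoadLikeDef are assumptions made for illustration:

// Sketch only: why isa_and_nonnull<> is needed when the source may be a block argument.
#include "mlir/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "triton/Dialect/Triton/IR/Dialect.h" // assumed include for the tt ops

namespace tt = mlir::triton;

// Hypothetical helper name, for illustration only.
static bool isLoadLikeDef(mlir::Value src) {
  // nullptr when `src` is a block argument, e.g. an scf.for iter_arg.
  mlir::Operation *def = src.getDefiningOp();
  // isa<> would assert on nullptr; isa_and_nonnull<> returns false instead.
  return llvm::isa_and_nonnull<tt::LoadOp, tt::DescriptorLoadOp,
                               tt::DescriptorGatherOp>(def);
}

The new test below exercises exactly this case: the ttg.local_alloc sources are loop iter_args, so their defining op is null.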

test/TritonGPU/pipeline-assign-latencies.mlir

Lines changed: 36 additions & 0 deletions
@@ -1165,3 +1165,39 @@ module attributes {"ttg.num-warps" = 4 : i32} {
     tt.return %result : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
+  // CHECK-LABEL: @tc_gen5_mma_alloc_block_arg
+  tt.func @tc_gen5_mma_alloc_block_arg(%lb : index, %ub : index, %step : index,
+                  %A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[1, 16]> : tensor<2xi32>},
+                  %B_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[1, 16]> : tensor<2xi32>},
+                  %acc_init : tensor<128x128xf32, #blocked1>) -> () {
+    %true = arith.constant true
+    %acc_tm = ttng.tmem_alloc %acc_init : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    %zero = arith.constant dense<0.0> : tensor<128x128xf16, #blocked1>
+    // CHECK: ttng.tmem_alloc
+    // CHECK: scf.for
+    scf.for %iv = %lb to %ub step %step iter_args(%A = %zero, %B = %zero) -> (tensor<128x128xf16, #blocked1>, tensor<128x128xf16, #blocked1>) : index {
+      // Ensure this doesn't crash.
+      %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+      %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+      ttng.tmem_store %acc_init, %acc_tm, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      // CHECK: ttng.tc_gen5_mma
+      ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      // CHECK: ttng.tmem_load
+      %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
+      "use"(%acc_res) : (tensor<128x128xf32, #blocked1>) -> ()
+      %A_next = tt.load %A_ptr : tensor<128x128x!tt.ptr<f16>, #blocked1>
+      %B_next = tt.load %B_ptr : tensor<128x128x!tt.ptr<f16>, #blocked1>
+      scf.yield %A_next, %B_next : tensor<128x128xf16, #blocked1>, tensor<128x128xf16, #blocked1>
+    }
+    tt.return
+  }
+}
