Commit da3ab2a
[codegen] Use Python identifier as prefix for IR SSA names (#7521)
Quite some time ago I implemented [`-mlir-use-nameloc-as-prefix` upstream](llvm/llvm-project#119996). I actually implemented this because I was pulling my hair out trying to debug Triton pipelines, but I never got around to plumbing it all the way through. So here's the plumbing. The way this works is like this:

```python
# demo.py
import triton
import triton.language as tl


@triton.jit
def _kernel(src, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE
    offsets = offset + tl.arange(0, BLOCK_SIZE)
    load_src_store_dst = src + offsets
    mask = offsets < N
    x_plus_1 = tl.load(load_src_store_dst, mask=mask) + 1
    tl.store(load_src_store_dst, x_plus_1, mask=mask)
```

```shell
MLIR_ENABLE_DUMP=1 python demo.py
```

will give you dumps like this:

```mlir
module {
  tt.func public @_kernel(%src: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %N: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<1.000000e+00> : tensor<16xf32>
    %c16_i32 = arith.constant 16 : i32
    %pid = tt.get_program_id x : i32
    %offset = arith.muli %pid, %c16_i32 : i32
    %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
    %offsets_0 = tt.splat %offset : i32 -> tensor<16xi32>
    %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32>
    %load_src_store_dst = tt.splat %src : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>>
    %load_src_store_dst_2 = tt.addptr %load_src_store_dst, %offsets_1 : tensor<16x!tt.ptr<f32>>, tensor<16xi32>
    %mask = tt.splat %N : i32 -> tensor<16xi32>
    %mask_3 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32>
    %x_plus_1 = tt.load %load_src_store_dst_2, %mask_3 : tensor<16x!tt.ptr<f32>>
    %x_plus_1_4 = arith.addf %x_plus_1, %cst : tensor<16xf32>
    tt.store %load_src_store_dst_2, %x_plus_1_4, %mask_3 : tensor<16x!tt.ptr<f32>>
    tt.return
  }
}
```

Notice that the SSA names (roughly) correspond to the Python identifiers, including the func args `%src` and `%N`. Note, the reason we have

```mlir
%offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
%offsets_0 = tt.splat %offset : i32 -> tensor<16xi32>
%offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32>
```

is that, the way it's plumbed, the "target" of the assignment determines (contextually) the SSA names of all of the intermediate values of the rhs. While this is subject to bike-shedding, tbh there's not really another way to do it (I tried...).
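Mechanically, the plumbing amounts to something like this (a schematic sketch, not the actual frontend code; `visit_assign` and `lower_rhs` are hypothetical stand-ins):

```python
# Schematic sketch only. `builder` is the pybind ir.builder; set_loc(str) is
# the string overload added to TritonOpBuilder in this PR (python/src/ir.cc).
def visit_assign(builder, target_name, rhs_node):
    # One NameLoc per assignment target, set before lowering the rhs.
    builder.set_loc(target_name)
    # Every op emitted while lowering the rhs (make_range, splat, addi, ...)
    # now carries loc("target_name"); the MLIR printer disambiguates repeats
    # as %target_name, %target_name_0, %target_name_1, etc.
    return lower_rhs(builder, rhs_node)  # hypothetical helper
```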
**Furthermore**, because `NameLoc` attributes are just `Location` attributes, these names will persist/be propagated through passes (assuming the passes correctly/adequately propagate):

```mlir
// -----// IR Dump After TritonAMDGPUConvertToBufferOps (tritonamdgpu-convert-buffer-ops) ('builtin.module' operation) //----- //
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @_kernel(%src: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %N: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<1.000000e+00> : tensor<16xf32, #blocked>
    %c16_i32 = arith.constant 16 : i32
    %pid = tt.get_program_id x : i32
    %offset = arith.muli %pid, %c16_i32 : i32
    %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked>
    %offsets_0 = tt.splat %offset : i32 -> tensor<16xi32, #blocked>
    %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32, #blocked>
    %load_src_store_dst = tt.addptr %src, %offset : !tt.ptr<f32>, i32
    %mask = tt.splat %N : i32 -> tensor<16xi32, #blocked>
    %mask_2 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32, #blocked>
    %x_plus_1 = amdgpu.buffer_load %load_src_store_dst[%offsets], %mask_2 : tensor<16xf32, #blocked>
    %x_plus_1_3 = arith.addf %x_plus_1, %cst : tensor<16xf32, #blocked>
    amdgpu.buffer_store %x_plus_1_3, %load_src_store_dst[%offsets], %mask_2 : tensor<16xf32, #blocked>
    tt.return
  }
}
```

Notice `%x_plus_1 = amdgpu.buffer_load` keeps the SSA name from `%x_plus_1 = tt.load`. This is the *real* value prop (at least for me). And on down:

```mlir
// -----// IR Dump After ConvertBuiltinFuncToLLVM (convert-builtin-func-to-llvm) ('builtin.module' operation) //----- //
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 0 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} {
  llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
  llvm.func @_kernel(%src: !llvm.ptr<1> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %arg2: !llvm.ptr<1>) attributes {noinline = false, nvvm.kernel = 1 : ui1, nvvm.reqntid = array<i32: 128>} {
    %0 = llvm.mlir.undef : vector<1xf32>
    %1 = llvm.mlir.constant(3 : i32) : i32
    %2 = llvm.mlir.constant(true) : i1
    %3 = llvm.mlir.constant(4 : i32) : i32
    %4 = llvm.mlir.constant(-2147483648 : i32) : i32
    %5 = llvm.mlir.constant(2147483646 : i32) : i32
    %6 = llvm.mlir.constant(822243328 : i32) : i32
    %7 = llvm.mlir.constant(0 : i16) : i16
    %8 = llvm.mlir.constant(15 : i32) : i32
    %9 = llvm.mlir.constant(5 : i32) : i32
    %10 = llvm.mlir.constant(0 : i32) : i32
    %11 = llvm.mlir.constant(32 : i32) : i32
    %12 = llvm.mlir.constant(127 : i32) : i32
    %13 = llvm.mlir.constant(0 : index) : i32
    %14 = llvm.mlir.constant(16 : i32) : i32
    %15 = llvm.mlir.constant(1.000000e+00 : f32) : f32
    %pid = rocdl.workgroup.id.x : i32
    %offset = llvm.mul %pid, %14 : i32
    %offsets = rocdl.workitem.id.x : i32
    %offsets_0 = llvm.and %offsets, %12 : i32
    %offsets_1 = llvm.urem %offsets_0, %11 : i32
    %offsets_2 = llvm.udiv %offsets_0, %11 : i32
    %offsets_3 = llvm.shl %offsets_1, %10 : i32
    %offsets_4 = llvm.or %10, %offsets_3 : i32
    %offsets_5 = llvm.shl %offsets_2, %9 : i32
    %offsets_6 = llvm.or %offsets_4, %offsets_5 : i32
    %offsets_7 = llvm.and %offsets_6, %8 : i32
    %offsets_8 = llvm.lshr %offsets_7, %10 : i32
    %offsets_9 = llvm.xor %10, %offsets_8 : i32
    %offsets_10 = llvm.xor %10, %offsets_9 : i32
    %offsets_11 = llvm.xor %offsets_10, %10 : i32
    %offsets_12 = llvm.add %offsets_11, %13 : i32
    %offsets_13 = llvm.add %offset, %offsets_12 : i32
    %load_src_store_dst = llvm.getelementptr %src[%offset] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
    %mask = llvm.icmp "slt" %offsets_13, %N : i32
    %x_plus_1 = rocdl.make.buffer.rsrc %load_src_store_dst, %7, %5, %6 : <1> to <8>
    %x_plus_1_14 = llvm.mul %offsets_12, %3 : i32
    %x_plus_1_15 = llvm.select %mask, %x_plus_1_14, %4 : i1, i32
    %x_plus_1_16 = rocdl.raw.ptr.buffer.load %x_plus_1, %x_plus_1_15, %10, %10 : f32
    %x_plus_1_17 = llvm.bitcast %x_plus_1_16 : f32 to vector<1xf32>
    %x_plus_1_18 = llvm.extractelement %x_plus_1_17[%13 : i32] : vector<1xf32>
    %x_plus_1_19 = llvm.fadd %x_plus_1_18, %15 : f32
    %16 = llvm.and %offsets_1, %14 : i32
    %17 = llvm.icmp "eq" %16, %10 : i32
    %18 = llvm.and %2, %17 : i1
    %19 = llvm.and %offsets_2, %1 : i32
    %20 = llvm.icmp "eq" %19, %10 : i32
    %21 = llvm.and %18, %20 : i1
    %22 = llvm.and %21, %mask : i1
    %23 = llvm.insertelement %x_plus_1_19, %0[%10 : i32] : vector<1xf32>
    %24 = llvm.bitcast %23 : vector<1xf32> to f32
    %25 = llvm.select %22, %x_plus_1_14, %4 : i1, i32
    rocdl.raw.ptr.buffer.store %24, %x_plus_1, %25, %10, %10 : f32
    llvm.return
  }
}
```

Note, the "explosion" in e.g. `%offsets_*` is due to the choices made in the passes themselves, not the flag/plumbing (i.e., location propagation).
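For a concrete check (a hypothetical test in the style of the updated tests in this PR, not part of the commit; assumes a GPU device and the `_kernel` from demo.py):

```python
# Hypothetical check: the Python identifiers show up as SSA prefixes in the
# ttir dump because printNameLocAsPrefix(true) is now set in getOpPrintingFlags.
import torch

src = torch.empty(16, device="cuda")
h = _kernel.warmup(src, 16, BLOCK_SIZE=16, grid=(1, ))
assert "%x_plus_1 = tt.load" in h.asm["ttir"]
assert "%load_src_store_dst" in h.asm["ttir"]
```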
1 parent 388149b commit da3ab2a

File tree

6 files changed: 297 additions & 54 deletions

python/src/ir.cc

Lines changed: 32 additions & 5 deletions
@@ -152,6 +152,7 @@ ReproducerStreamFactory makeConsoleReproducer() {
 OpPrintingFlags getOpPrintingFlags() {
   auto printingFlags = OpPrintingFlags();
   printingFlags.enableDebugInfo();
+  printingFlags.printNameLocAsPrefix(true);
   return printingFlags;
 }

@@ -372,11 +373,15 @@ void init_triton_ir(py::module &&m) {
             self.replaceAllUsesWith(newValue);
           })
       .def("get_type", &Value::getType)
-      .def("id", [](Value &self) {
-        // The Value is identified by and compared with
-        // other Values via the underlying ValueImpl
-        return (uint64_t)self.getImpl();
-      });
+      .def("id",
+           [](Value &self) {
+             // The Value is identified by and compared with
+             // other Values via the underlying ValueImpl
+             return (uint64_t)self.getImpl();
+           })
+      .def("set_loc",
+           [](Value &self, Location loc) { return self.setLoc(loc); })
+      .def("get_loc", [](Value &self) { return self.getLoc(); });

   py::class_<OpResult, Value>(m, "op_result", py::module_local());

@@ -929,6 +934,28 @@ void init_triton_ir(py::module &&m) {
       // locs
       .def("set_loc",
            [](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); })
+      .def("set_loc",
+           [](TritonOpBuilder &self, std::string name) {
+             auto nameAttr = StringAttr::get(self.getContext(), name);
+             auto loc = NameLoc::get(nameAttr);
+             self.setLastLoc(loc);
+           })
+      .def("create_loc",
+           [](TritonOpBuilder &self, const std::string &fileName, int line,
+              int column) -> Location {
+             return mlir::FileLineColLoc::get(self.getContext(), fileName, line,
+                                              column);
+           })
+      .def(
+          "create_name_loc",
+          [](TritonOpBuilder &self, std::string name,
+             std::optional<Location> childLoc) -> Location {
+            auto nameAttr = StringAttr::get(self.getContext(), name);
+            if (childLoc)
+              return NameLoc::get(nameAttr, *childLoc);
+            return NameLoc::get(nameAttr);
+          },
+          py::arg("name"), py::arg("child_loc") = py::none())
       .def("set_loc",
            [](TritonOpBuilder &self, const std::string &fileName, int line,
               int column) { self.setLastLoc(fileName, line, column); })
python/test/gluon/test_frontend.py

Lines changed: 41 additions & 36 deletions
@@ -165,8 +165,8 @@ def tensor_memory_kernel(layout: ttgl.constexpr, tmem_layout: ttgl.constexpr):
     slice2 = mem.slice(YBLOCK // 2, YBLOCK // 2)  # noqa: F841

     buffers = ttgl.nvidia.blackwell.allocate_tensor_memory(ttgl.float32, [2, XBLOCK, YBLOCK], tmem_layout)
-    for i in range(2):
-        buffers.index(i).load(layout)
+    for ivar in range(2):
+        buffers.index(ivar).load(layout)


 @pytest.mark.skipif(not is_blackwell(), reason="Requires blackwell tensor cores")
@@ -200,9 +200,9 @@ def test_tensor_memory(fresh_knobs):
     %3 = arith.bitcast %c2_i32 : i32 to i32 loc(#loc)
     %4 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc)
     %5 = ub.poison : i32 loc(#loc)
-    scf.for %arg0 = %2 to %3 step %4 : i32 {
+    scf.for %ivar = %2 to %3 step %4 : i32 {
       %c0_i32_4 = arith.constant 0 : i32 loc(#loc)
-      %6 = ttg.memdesc_subview %result_2[%arg0, %c0_i32_4, %c0_i32_4] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> loc(#loc)
+      %6 = ttg.memdesc_subview %result_2[%ivar, %c0_i32_4, %c0_i32_4] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> loc(#loc)
       %result_5 = ttng.tmem_load %6 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> loc(#loc)
     } loc(#loc)
     tt.return loc(#loc)
@@ -257,8 +257,8 @@ def test_shared_memory_subview(fresh_knobs):
 @gluon.jit
 def shared_memory_index_kernel(XBLOCK: ttgl.constexpr, layout: ttgl.constexpr, smem_layout: ttgl.constexpr):
     smem = ttgl.allocate_shared_memory(ttgl.int32, [4, XBLOCK], smem_layout)
-    for i in range(4):
-        smem.index(i).load(layout)
+    for ivar in range(4):
+        smem.index(ivar).load(layout)


 @pytest.mark.skipif(not is_cuda(), reason="Requires CUDA")
@@ -283,9 +283,9 @@ def test_shared_memory_index(fresh_knobs):
     %2 = arith.bitcast %c4_i32 : i32 to i32 loc(#loc)
     %3 = arith.bitcast %c1_i32 : i32 to i32 loc(#loc)
     %4 = ub.poison : i32 loc(#loc)
-    scf.for %arg0 = %1 to %2 step %3 : i32 {
+    scf.for %ivar = %1 to %2 step %3 : i32 {
       %c0_i32_0 = arith.constant 0 : i32 loc(#loc)
-      %5 = ttg.memdesc_subview %0[%arg0, %c0_i32_0] : !ttg.memdesc<4x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> loc(#loc)
+      %5 = ttg.memdesc_subview %0[%ivar, %c0_i32_0] : !ttg.memdesc<4x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> loc(#loc)
       %6 = ttg.local_load %5 : !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> -> tensor<256xi32, #blocked> loc(#loc)
     } loc(#loc)
     tt.return loc(#loc)
@@ -676,32 +676,33 @@ def test_async_tma(fresh_knobs):
     h = async_tma_kernel.warmup(input_desc, XBLOCK, grid=(1, ), num_warps=4)
     expecttest.assert_expected_inline(
         anonymize_ir(h.asm["source"]), """\
-#loc = loc(unknown)
+#loc1 = loc("input_desc")
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @async_tma_kernel(%arg0: !tt.tensordesc<tensor<128x128xf16, #shared>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
+  tt.func public @async_tma_kernel(%input_desc: !tt.tensordesc<tensor<128x128xf16, #shared>> loc("input_desc"), %input_desc_0: i32 loc("input_desc"), %input_desc_1: i32 loc("input_desc"), %input_desc_2: i64 loc("input_desc"), %input_desc_3: i64 loc("input_desc")) attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
-    %c0_i32_0 = arith.constant 0 : i32 loc(#loc)
+    %c0_i32_4 = arith.constant 0 : i32 loc(#loc)
     %true = arith.constant true loc(#loc)
-    ttng.async_tma_copy_global_to_local %arg0[%c0_i32, %c0_i32_0] %0, %1, %true : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
-    %true_1 = arith.constant true loc(#loc)
-    ttng.barrier_expect %1, 32768, %true_1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
-    %c0_i32_2 = arith.constant 0 : i32 loc(#loc)
-    %true_3 = arith.constant true loc(#loc)
-    ttng.wait_barrier %1, %c0_i32_2, %true_3 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    ttng.async_tma_copy_global_to_local %input_desc[%c0_i32, %c0_i32_4] %0, %1, %true : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %true_5 = arith.constant true loc(#loc)
+    ttng.barrier_expect %1, 32768, %true_5 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %c0_i32_6 = arith.constant 0 : i32 loc(#loc)
+    %true_7 = arith.constant true loc(#loc)
+    ttng.wait_barrier %1, %c0_i32_6, %true_7 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
-    %c0_i32_4 = arith.constant 0 : i32 loc(#loc)
-    %c0_i32_5 = arith.constant 0 : i32 loc(#loc)
-    ttng.async_tma_copy_local_to_global %arg0[%c0_i32_4, %c0_i32_5] %0 : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %c0_i32_8 = arith.constant 0 : i32 loc(#loc)
+    %c0_i32_9 = arith.constant 0 : i32 loc(#loc)
+    ttng.async_tma_copy_local_to_global %input_desc[%c0_i32_8, %c0_i32_9] %0 : !tt.tensordesc<tensor<128x128xf16, #shared>>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
     tt.return loc(#loc)
   } loc(#loc)
 } loc(#loc)
+#loc = loc(unknown)
 """)
@@ -736,31 +737,32 @@ def test_async_tma_blackwell(fresh_knobs):
     expecttest.assert_expected_inline(
         anonymize_ir(h.asm["source"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
-#loc = loc(unknown)
+#loc1 = loc("input_desc")
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @async_tma_blackwell_kernel(%arg0: !tt.tensordesc<tensor<1x128xf16, #shared>> loc(unknown), %arg1: i32 loc(unknown), %arg2: i32 loc(unknown), %arg3: i64 loc(unknown), %arg4: i64 loc(unknown)) attributes {noinline = false} {
+  tt.func public @async_tma_blackwell_kernel(%input_desc: !tt.tensordesc<tensor<1x128xf16, #shared>> loc("input_desc"), %input_desc_0: i32 loc("input_desc"), %input_desc_1: i32 loc("input_desc"), %input_desc_2: i64 loc("input_desc"), %input_desc_3: i64 loc("input_desc")) attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     ttng.init_barrier %1, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
     %true = arith.constant true loc(#loc)
     %c0_i32 = arith.constant 0 : i32 loc(#loc)
-    ttng.async_tma_gather %arg0[%2, %c0_i32] %0, %1, %true : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
-    %true_0 = arith.constant true loc(#loc)
-    ttng.barrier_expect %1, 32768, %true_0 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
-    %c0_i32_1 = arith.constant 0 : i32 loc(#loc)
-    %true_2 = arith.constant true loc(#loc)
-    ttng.wait_barrier %1, %c0_i32_1, %true_2 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    ttng.async_tma_gather %input_desc[%2, %c0_i32] %0, %1, %true : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, i1 loc(#loc)
+    %true_4 = arith.constant true loc(#loc)
+    ttng.barrier_expect %1, 32768, %true_4 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
+    %c0_i32_5 = arith.constant 0 : i32 loc(#loc)
+    %true_6 = arith.constant true loc(#loc)
+    ttng.wait_barrier %1, %c0_i32_5, %true_6 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
     ttng.inval_barrier %1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc)
-    %c0_i32_3 = arith.constant 0 : i32 loc(#loc)
-    ttng.async_tma_scatter %arg0[%2, %c0_i32_3] %0 : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %c0_i32_7 = arith.constant 0 : i32 loc(#loc)
+    ttng.async_tma_scatter %input_desc[%2, %c0_i32_7] %0 : !tt.tensordesc<tensor<1x128xf16, #shared>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
     ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc)
     tt.return loc(#loc)
   } loc(#loc)
 } loc(#loc)
+#loc = loc(unknown)
 """)
@@ -972,8 +974,9 @@ def test_reduce(fresh_knobs):
         anonymize_ir(h.asm["ttgir"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
 #loc = loc(unknown)
+#loc1 = loc("out")
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @reduce_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} {
+  tt.func public @reduce_kernel(%out: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("out")) attributes {noinline = false} {
     %cst = arith.constant dense<2.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
     %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
     %0 = "tt.reduce"(%cst_0) <{axis = 0 : i32}> ({
@@ -1003,7 +1006,7 @@ def test_reduce(fresh_knobs):
     %7 = arith.addf %6, %4#0 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
     %8 = arith.addf %7, %4#1 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
     %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
-    %10 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %10 = tt.splat %out : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
     %11 = tt.addptr %10, %9 : tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
     tt.store %11, %8 : tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
     tt.return loc(#loc)
@@ -1202,16 +1205,17 @@ def test_async_copy(fresh_knobs):
     expecttest.assert_expected_inline(
         anonymize_ir(h.asm["ttgir"]), """\
 #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
-#loc = loc(unknown)
+#loc1 = loc("inp")
+#loc2 = loc("xnumel")
 #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @async_copy_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc(unknown), %arg1: i32 loc(unknown)) attributes {noinline = false} {
+  tt.func public @async_copy_kernel(%inp: !tt.ptr<f16> {tt.divisibility = 16 : i32} loc("inp"), %xnumel: i32 loc("xnumel")) attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<128xf16, #shared, #smem, mutable> loc(#loc)
     %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked> loc(#loc)
-    %2 = tt.splat %arg1 : i32 -> tensor<128xi32, #blocked> loc(#loc)
+    %2 = tt.splat %xnumel : i32 -> tensor<128xi32, #blocked> loc(#loc)
     %3 = arith.cmpi slt, %1, %2 {tt.constancy = dense<2> : tensor<1xi32>} : tensor<128xi32, #blocked> loc(#loc)
-    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x!tt.ptr<f16>, #blocked> loc(#loc)
+    %4 = tt.splat %inp : !tt.ptr<f16> -> tensor<128x!tt.ptr<f16>, #blocked> loc(#loc)
     %5 = tt.addptr %4, %1 : tensor<128x!tt.ptr<f16>, #blocked>, tensor<128xi32, #blocked> loc(#loc)
     %6 = ttg.async_copy_global_to_local %5, %0 : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
     %7 = ttg.async_copy_global_to_local %5, %0 mask %3 cacheModifier = ca evictionPolicy = evict_last {isVolatile = true} : tensor<128x!tt.ptr<f16>, #blocked> -> <128xf16, #shared, #smem, mutable> loc(#loc)
@@ -1223,6 +1227,7 @@ def test_async_copy(fresh_knobs):
     tt.return loc(#loc)
   } loc(#loc)
 } loc(#loc)
+#loc = loc(unknown)
 """)
python/test/unit/cuda/test_tensor_descriptor.py

Lines changed: 2 additions & 1 deletion
@@ -16,4 +16,5 @@ def kernel(a, b):
     A = torch.randn(1024, device=device)
     desc = TensorDescriptor.from_tensor(A, [128])
     h = kernel.warmup(desc, 16, grid=(1, ))
-    assert ", %arg3: i32 {tt.divisibility = 16 : i32}" in h.asm["ttir"]
+    assert "%a: !tt.tensordesc<tensor<128xf32>>" in h.asm["ttir"]
+    assert "%b: i32 {tt.divisibility = 16 : i32}" in h.asm["ttir"]

python/test/unit/language/test_annotations.py

Lines changed: 9 additions & 9 deletions
@@ -32,8 +32,8 @@ def _kernel(X, v):
     h = _kernel[(1, )](torch.empty(1, device=device), 3)
     pfx = 'si' if signed else 'ui'
     if not signed and width < 64:
-        assert "arith.extui %arg1" in h.asm["ttir"]
-        assert f'%arg1: i{width}' in h.asm["ttir"]
+        assert "arith.extui %v" in h.asm["ttir"]
+        assert f'%v: i{width}' in h.asm["ttir"]
     assert f'arith.{pfx}tofp' in h.asm["ttir"]

@@ -73,13 +73,13 @@ def _kernel(ptr, val):

     # Check that the type is properly emitted in the IR
     if dtype == tl.float16:
-        assert "%arg1: f16" in h.asm["ttir"]
-        assert "arith.extf %arg1 : f16 to f32" in h.asm["ttir"]
+        assert "%val: f16" in h.asm["ttir"]
+        assert "arith.extf %val : f16 to f32" in h.asm["ttir"]
     elif dtype == tl.bfloat16:
-        assert "%arg1: bf16" in h.asm["ttir"]
-        assert "arith.extf %arg1 : bf16 to f32" in h.asm["ttir"]
+        assert "%val: bf16" in h.asm["ttir"]
+        assert "arith.extf %val : bf16 to f32" in h.asm["ttir"]
     elif dtype == tl.float32:
-        assert "%arg1: f32" in h.asm["ttir"]
+        assert "%val: f32" in h.asm["ttir"]
     elif dtype == tl.float64:
-        assert "%arg1: f64" in h.asm["ttir"]
-        assert "arith.truncf %arg1 : f64 to f32" in h.asm["ttir"]
+        assert "%val: f64" in h.asm["ttir"]
+        assert "arith.truncf %val : f64 to f32" in h.asm["ttir"]
