Skip to content

Commit 00c133b

Browse files
peterbell10 authored and FindHao committed
[FRONTEND] Fix argument annotation after host side TensorDescriptor (triton-lang#6509)
`val_paths.index(attr_path)` computes the index into the Python values after flattening tuples, which is not necessarily the block argument index in the TTIR. Instead, we can use `cursor`, which corresponds to the block argument index in the TTIR.
1 parent 853e1da commit 00c133b

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

python/test/unit/cuda/test_tensor_descriptor.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1402,3 +1402,17 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B
14021402
# TODO: The use of stmatrix for Blackwell is currently not supported.
14031403
# Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
14041404
assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
1405+
1406+
1407+
@requires_tma
1408+
def test_specialization_after_host_tensordesc():
1409+
1410+
@triton.jit
1411+
def kernel(a, b):
1412+
pass
1413+
1414+
device = "cuda"
1415+
A = torch.randn(1024, device=device)
1416+
desc = TensorDescriptor.from_tensor(A, [128])
1417+
h = kernel.warmup(desc, 16, grid=(1, ))
1418+
assert ", %arg3: i32 {tt.divisibility = 16 : i32}" in h.asm["ttir"]

python/triton/compiler/code_generator.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -253,16 +253,16 @@ def make_template(ty):
253253
vals = make_template(self.arg_types)
254254
is_val = lambda path, _: path not in self.constants and _ is not None
255255
val_paths = list(find_paths_if(self.arg_types, is_val))
256-
# > set attributes
257-
for attr_path, attr_specs in self.attrs.items():
258-
for attr_name, attr_val in attr_specs:
259-
if attr_path in val_paths:
260-
fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val)
261256
# > add IR values to the template
262257
cursor = 0
263258
handles = [fn.args(i) for i in range(fn.get_num_args())]
264259
for path in val_paths:
265260
ty = get_iterable_path(self.arg_types, path)
261+
# > set attributes
262+
attr_specs = self.attrs.get(path, [])
263+
for attr_name, attr_val in attr_specs:
264+
fn.set_arg_attr(cursor, attr_name, attr_val)
265+
# > build frontend value
266266
val, cursor = ty._unflatten_ir(handles, cursor)
267267
set_iterable_path(vals, path, val)
268268
# > add constexpr values to the template

0 commit comments

Comments
 (0)