
Commit 6917a7f

[TUTORIAL] Remove device-pointer tma kernel parameters (#6294)
1 parent c85dd4e commit 6917a7f

File tree

1 file changed: +20 −45 lines changed

python/tutorials/09-persistent-matmul.py

Lines changed: 20 additions & 45 deletions
@@ -61,12 +61,8 @@ def _matmul_launch_metadata(grid, kernel, args):
     return ret


-HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
-
-if HAS_TMA_DESC:
-    print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", )
-else:
-    print("TMA benchmarks will be running without grid constant TMA descriptor.", )
+HAS_TMA_DESC = supports_tma() and hasattr(tl, "nv_tma_desc_type")
+HAS_TENSOR_DESC = supports_tma() and hasattr(tl, "make_tensor_descriptor")


 # TmaAutoTuneHelper used in htyu's PR #5622
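For context, the two capability flags above combine a hardware check with a feature probe on triton.language. Below is a minimal, self-contained sketch of that pattern; the helper name supports_tma_sketch and the sm_90 capability test are assumptions for illustration, since the tutorial's own supports_tma() helper is defined elsewhere in the file.

    import torch
    import triton.language as tl

    def supports_tma_sketch():
        # Assumed check: TMA needs a CUDA device of Hopper class (compute capability 9.x) or newer.
        return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9

    # Gate on both the hardware and the presence of the corresponding Triton language features.
    HAS_TMA_DESC = supports_tma_sketch() and hasattr(tl, "nv_tma_desc_type")
    HAS_TENSOR_DESC = supports_tma_sketch() and hasattr(tl, "make_tensor_descriptor")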
@@ -86,49 +82,27 @@ def tma_desc_cpu_ptr(self):
     def __init__(self):
         self.fill_1d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_1d_tma_descriptor)
         self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor)
-        if HAS_TMA_DESC:
-            self.descriptors = {}
-        else:
-            self.cuda_descriptors = {}
+        self.descriptors = {}

     # Call this method outside of the lambda function for grid size
     def init_tma_descriptor(self, name):
-        if HAS_TMA_DESC:
-            self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8)
-        else:
-            self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8)
+        self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8)

     # Call this method inside the lambda function for grid size
     def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
-        if HAS_TMA_DESC:
-            desc_x = self.descriptors[name]
-            assert desc_x.data_ptr() % 64 == 0
-            self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr())
-        else:
-            desc_x = self.cuda_descriptors[name]
-            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr())
-            desc_x.copy_(buf_x, non_blocking=True)
+        desc_x = self.descriptors[name]
+        assert desc_x.data_ptr() % 64 == 0
+        self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr())

     # Call this method inside the lambda function for grid size
     def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
-        if HAS_TMA_DESC:
-            desc_x = self.descriptors[name]
-            assert desc_x.data_ptr() % 64 == 0
-            self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr())
-        else:
-            desc_x = self.cuda_descriptors[name]
-            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr())
-            desc_x.copy_(buf_x, non_blocking=True)
+        desc_x = self.descriptors[name]
+        assert desc_x.data_ptr() % 64 == 0
+        self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr())

     def get_tma_descriptor_kernel_param(self, name):
-        if HAS_TMA_DESC:
-            assert self.descriptors[name] is not None
-            return self.KernelParamWrapper(self.descriptors[name])
-        else:
-            assert self.cuda_descriptors[name] is not None
-            return self.cuda_descriptors[name]
+        assert self.descriptors[name] is not None
+        return self.KernelParamWrapper(self.descriptors[name])


 def matmul_get_configs():
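The simplified helper keeps only the host-side (grid-constant) descriptor path. A rough usage sketch follows, assuming TmaAutoTuneHelper, matmul_kernel_tma_ws, and a TMA-capable GPU are available as in this tutorial; the shapes, block sizes, and the single "a" descriptor are illustrative only.

    import torch
    import triton

    M, N, K = 4096, 4096, 512
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)

    desc_helper = TmaAutoTuneHelper()
    desc_helper.init_tma_descriptor("a")  # allocate the TMA_SIZE-byte CPU descriptor buffer once

    def grid(META):
        # Refill the descriptor inside the grid lambda so every autotune config gets matching block sizes.
        desc_helper.fill_2d_tma_descriptor("a", a.data_ptr(), M, K,
                                           META["BLOCK_SIZE_M"], META["BLOCK_SIZE_K"],
                                           a.element_size())
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )

    # The returned KernelParamWrapper is what gets passed as a_desc_ptr to matmul_kernel_tma_ws.
    desc_a = desc_helper.get_tma_descriptor_kernel_param("a")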
@@ -228,7 +202,7 @@ def matmul(a, b):
     key=["M", "N", "K"],
 )
 @triton.jit(launch_metadata=_matmul_launch_metadata)
-def matmul_tma_ws_kernel(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
+def matmul_kernel_tma_ws(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
                          M, N, K, #
                          BLOCK_SIZE_M: tl.constexpr, #
                          BLOCK_SIZE_N: tl.constexpr, #
@@ -321,7 +295,7 @@ def grid(META):
     desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
     desc_c = desc_helper.get_tma_descriptor_kernel_param("c")

-    matmul_tma_ws_kernel[grid](
+    matmul_kernel_tma_ws[grid](
         desc_a, desc_b, desc_c, #
         M, N, K, #
     )
@@ -726,10 +700,11 @@ def bench(K, dtype, reps=1000, warmup_reps=10000):
     bench_fn(reps, warmup_reps, torch_matmul, a, b)
     bench_fn(reps, warmup_reps, matmul, a, b.T)
     bench_fn(reps, warmup_reps, matmul_persistent, a, b.T)
-    if supports_tma():
-        bench_fn(reps, warmup_reps, matmul_tma_ws, a, b)
+    if HAS_TMA_DESC:
         bench_fn(reps, warmup_reps, matmul_tma_persistent, a, b)
+    if HAS_TENSOR_DESC:
         bench_fn(reps, warmup_reps, matmul_descriptor_persistent, a, b)
+        bench_fn(reps, warmup_reps, matmul_tma_ws, a, b)


 def validate(M, N, K, dtype):
@@ -740,10 +715,10 @@ def validate(M, N, K, dtype):
     torch_result = torch_matmul(a, b) if dtype == torch.float16 else None
     cublas_result = cublas_matmul(a, b) if cublas is not None else None
     naive_result = matmul(a, b.T)
-    tma_ws_result = matmul_tma_ws(a, b) if supports_tma() else None
+    tma_ws_result = matmul_tma_ws(a, b) if HAS_TENSOR_DESC else None
     persistent_result = matmul_persistent(a, b.T)
-    tma_persistent_result = matmul_tma_persistent(a, b) if supports_tma() else None
-    descriptor_persistent_result = matmul_descriptor_persistent(a, b) if supports_tma() else None
+    tma_persistent_result = matmul_tma_persistent(a, b) if HAS_TMA_DESC else None
+    descriptor_persistent_result = matmul_descriptor_persistent(a, b) if HAS_TENSOR_DESC else None

     if tma_ws_result is not None:
         naive_vs_tma_ws = "✅" if torch.allclose(naive_result.to(torch.float16), tma_ws_result.to(torch.float16),
