
Commit e0cb184

alexsamardzic authored and pytorchmergebot committed
Use TMA loads always for Triton grouped MM kernel (pytorch#164256)
Pull Request resolved: pytorch#164256 Approved by: https://github.com/ngimel
1 parent a4110fe commit e0cb184
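
In short: the Triton grouped-MM template previously fell back to non-TMA loads whenever either operand was not K-major (its K dimension not contiguous). This commit removes that restriction: new A_IS_K_MAJOR/B_IS_K_MAJOR template flags describe each operand in its storage order, and the tl.dot operands are transposed accordingly, so TMA loads are used for every layout. A minimal plain-PyTorch sketch of the layout check (the tensors here are illustrative, not from the kernel):

import torch

def describe_layouts(mat_a: torch.Tensor, mat_b: torch.Tensor) -> tuple[bool, bool]:
    # Mirrors the checks added in _tuned_grouped_mm_common below: A is [..., M, K],
    # B is [..., K, N]; an operand is K-major when the stride of its K dim is 1.
    a_is_k_major = mat_a.stride(-1) == 1
    b_is_k_major = mat_b.stride(-2) == 1
    return a_is_k_major, b_is_k_major

a = torch.empty(128, 64)                          # [M, K] row-major: K-major
b = torch.empty(256, 64).t()                      # [K, N] transposed view, strides (1, 64): K-major
print(describe_layouts(a, b))                     # (True, True)
print(describe_layouts(a, torch.empty(64, 256)))  # (True, False): B's K stride is 256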

File tree

1 file changed: +100 -28 lines changed


torch/_inductor/kernel/mm_grouped.py

Lines changed: 100 additions & 28 deletions
@@ -65,8 +65,7 @@ def grouped_mm_configs():
     return _NV_CONFIGS


-def early_config_prune(g, m, configs, named_args):
-    dtsize = 1
+def early_config_prune(g, m, dtsize, configs, named_args):
     pruned_configs = []
     for config in configs:
         kw = config.kwargs
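
The hunk above also threads the operand element size through the pruning helper instead of hardcoding dtsize = 1; the call site at the bottom of this diff passes mat_a.dtype.itemsize. A hypothetical illustration of why that matters for size-based pruning (the function name and shared-memory formula below are assumptions for illustration, not the kernel's actual pruning logic):

# Hypothetical: with the real element size, a config's shared-memory footprint
# can be estimated per dtype instead of assuming 1 byte per element.
def fits_in_shared_memory(block_m, block_n, block_k, num_stages, dtsize,
                          smem_limit=228 * 1024):
    # Pipelined A and B tiles per stage, in bytes.
    tile_bytes = (block_m * block_k + block_n * block_k) * dtsize
    return num_stages * tile_bytes <= smem_limit

print(fits_in_shared_memory(128, 128, 64, 4, dtsize=2))  # fp16/bf16: True
print(fits_in_shared_memory(128, 128, 64, 4, dtsize=4))  # fp32: False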
@@ -186,15 +185,29 @@ def early_config_prune(g, m, configs, named_args):
 {%- endif %}
         a_ptr,
 {%- if A_IS_2D %}
+{%- if A_IS_K_MAJOR %}
         shape=[M, K],
         # fixme: strides=[A_STRIDE_M, A_STRIDE_K],
         strides=[{{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}],
         block_shape=[BLOCK_M, BLOCK_K],
 {%- else %}
+        shape=[K, M],
+        # fixme: strides=[A_STRIDE_K, A_STRIDE_M],
+        strides=[{{stride("a_ptr", -1)}}, {{stride("a_ptr", -2)}}],
+        block_shape=[BLOCK_K, BLOCK_M],
+{%- endif %}
+{%- else %}
+{%- if A_IS_K_MAJOR %}
         shape=[G, M, K],
         # fixme: strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K],
         strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}],
         block_shape=[1, BLOCK_M, BLOCK_K],
+{%- else %}
+        shape=[G, K, M],
+        # fixme: strides=[A_STRIDE_G, A_STRIDE_K, A_STRIDE_M],
+        strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -1)}}, {{stride("a_ptr", -2)}}],
+        block_shape=[1, BLOCK_K, BLOCK_M],
+{%- endif %}
 {%- endif %}
     )

@@ -205,15 +218,29 @@ def early_config_prune(g, m, configs, named_args):
 {%- endif %}
         b_ptr,
 {%- if B_IS_2D %}
+{%- if B_IS_K_MAJOR %}
         shape=[N, K],
         # fixme: strides=[B_STRIDE_N, B_STRIDE_K],
         strides=[{{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}],
         block_shape=[BLOCK_N, BLOCK_K],
 {%- else %}
+        shape=[K, N],
+        # fixme: strides=[B_STRIDE_K, B_STRIDE_N],
+        strides=[{{stride("b_ptr", -2)}}, {{stride("b_ptr", -1)}}],
+        block_shape=[BLOCK_K, BLOCK_N],
+{%- endif %}
+{%- else %}
+{%- if B_IS_K_MAJOR %}
         shape=[G, N, K],
         # fixme: strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K],
         strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}],
         block_shape=[1, BLOCK_N, BLOCK_K],
+{%- else %}
+        shape=[G, K, N],
+        # fixme: strides=[B_STRIDE_G, B_STRIDE_K, B_STRIDE_N],
+        strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -2)}}, {{stride("b_ptr", -1)}}],
+        block_shape=[1, BLOCK_K, BLOCK_N],
+{%- endif %}
 {%- endif %}
     )
 {%- endif %}
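
The two descriptor hunks above follow one rule: the TMA descriptor is declared in the operand's storage order, so the contiguous (stride-1) dimension is always last in shape, strides, and block_shape, which is what Triton's tensor descriptors expect. A compact sketch of the 2D selection for A (illustrative names; the 3D case just prepends the group dimension G):

def a_descriptor_layout(a_is_k_major: bool, M: int, K: int,
                        BLOCK_M: int, BLOCK_K: int) -> dict:
    # K-major A is described as [M, K]; otherwise it is described as its
    # [K, M] storage order so the contiguous dimension stays last.
    if a_is_k_major:
        return {"shape": [M, K], "block_shape": [BLOCK_M, BLOCK_K]}
    return {"shape": [K, M], "block_shape": [BLOCK_K, BLOCK_M]}

print(a_descriptor_layout(True, 1024, 512, 128, 64))
print(a_descriptor_layout(False, 1024, 512, 128, 64))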
@@ -286,39 +313,82 @@ def early_config_prune(g, m, configs, named_args):
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

 {%- if USE_TMA_LOAD %}
-    m_offset = (m_start_offset + tile_m_idx * BLOCK_M).to(tl.int32)
-    n_offset = (n_start_offset + tile_n_idx * BLOCK_N).to(tl.int32)
+    m_tile_offset = tile_m_idx * BLOCK_M
+    n_tile_offset = tile_n_idx * BLOCK_N
+    m_offset = (m_start_offset + m_tile_offset).to(tl.int32)
+    n_offset = (n_start_offset + n_tile_offset).to(tl.int32)

-    for k_offset in range(0, k_size, BLOCK_K):
+    for k_block_offset in range(0, k_size, BLOCK_K):
 {%- if A_IS_2D %}
-        a = a_desc.load([m_offset, k_start_offset + k_offset])
+{%- if A_IS_K_MAJOR %}
+        a = a_desc.load([m_offset, k_start_offset + k_block_offset])
+{%- else %}
+        a = a_desc.load([k_start_offset + k_block_offset, m_offset])
+{%- endif %}
+{%- else %}
+{%- if A_IS_K_MAJOR %}
+        a = a_desc.load([g, m_offset, k_start_offset + k_block_offset]).reshape(BLOCK_M, BLOCK_K)
 {%- else %}
-        a = a_desc.load([g, m_offset, k_start_offset + k_offset]).reshape(BLOCK_M, BLOCK_K)
+        a = a_desc.load([g, k_start_offset + k_block_offset, m_offset]).reshape(BLOCK_K, BLOCK_M)
+{%- endif %}
 {%- endif %}
 {%- if B_IS_2D %}
-        b = b_desc.load([n_offset, k_start_offset + k_offset])
+{%- if B_IS_K_MAJOR %}
+        b = b_desc.load([n_offset, k_start_offset + k_block_offset])
 {%- else %}
-        b = b_desc.load([g, n_offset, k_start_offset + k_offset]).reshape(BLOCK_N, BLOCK_K)
+        b = b_desc.load([k_start_offset + k_block_offset, n_offset])
+{%- endif %}
+{%- else %}
+{%- if B_IS_K_MAJOR %}
+        b = b_desc.load([g, n_offset, k_start_offset + k_block_offset]).reshape(BLOCK_N, BLOCK_K)
+{%- else %}
+        b = b_desc.load([g, k_start_offset + k_block_offset, n_offset]).reshape(BLOCK_K, BLOCK_N)
+{%- endif %}
 {%- endif %}

 {%- if K_IS_VARYING %}
-        if k_offset + BLOCK_K > k_size:
-            group_offs_k = k_offset + tl.arange(0, BLOCK_K)
-            a = tl.where(group_offs_k < k_size, a, 0)
-            b = tl.where(group_offs_k < k_size, b, 0)
+        if k_block_offset + BLOCK_K > k_size:
+            group_offs = k_block_offset + tl.arange(0, BLOCK_K)
+            k_mask = group_offs < k_size
+{%- if A_IS_K_MAJOR %}
+            a = tl.where(k_mask[None, :], a, 0)
+{%- else %}
+            a = tl.where(k_mask[:, None], a, 0)
+{%- endif %}
+{%- if B_IS_K_MAJOR %}
+            b = tl.where(k_mask[None, :], b, 0)
+{%- else %}
+            b = tl.where(k_mask[:, None], b, 0)
+{%- endif %}
 {%- endif %}

 {%- if USE_FAST_ACCUM %}
+{%- if A_IS_K_MAJOR and B_IS_K_MAJOR %}
         accumulator = tl.dot(a, b.T, accumulator)
+{%- elif A_IS_K_MAJOR and not B_IS_K_MAJOR %}
+        accumulator = tl.dot(a, b, accumulator)
+{%- elif not A_IS_K_MAJOR and B_IS_K_MAJOR %}
+        accumulator = tl.dot(a.T, b.T, accumulator)
 {%- else %}
+        accumulator = tl.dot(a.T, b, accumulator)
+{%- endif %}
+{%- else %}
+{%- if A_IS_K_MAJOR and B_IS_K_MAJOR %}
         accumulator += tl.dot(a, b.T)
+{%- elif A_IS_K_MAJOR and not B_IS_K_MAJOR %}
+        accumulator += tl.dot(a, b)
+{%- elif not A_IS_K_MAJOR and B_IS_K_MAJOR %}
+        accumulator += tl.dot(a.T, b.T)
+{%- else %}
+        accumulator += tl.dot(a.T, b)
+{%- endif %}
 {%- endif %}
 {%- else %}
     offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)
-    for k_offset in range(0, k_size, BLOCK_K):
-        group_offs_k = k_offset + tl.arange(0, BLOCK_K)
-        offs_k = group_offs_k + k_start_offset
+    for k_block_offset in range(0, k_size, BLOCK_K):
+        block_offs_k = k_block_offset + tl.arange(0, BLOCK_K)
+        offs_k = block_offs_k + k_start_offset
         a_ptrs = (
             a_ptr
 {%- if not A_IS_2D %}
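
Because tiles are now loaded in storage order, a non-K-major A tile arrives as [BLOCK_K, BLOCK_M] and a K-major B tile as [BLOCK_N, BLOCK_K]; the four tl.dot variants above transpose them back into an [M, K] @ [K, N] product (and the K-tail masking broadcasts k_mask along whichever axis is K). A plain-torch sketch checking that all four forms agree (illustrative shapes, not kernel code):

import torch

M, N, K = 16, 8, 4
a_mk = torch.randn(M, K, dtype=torch.float64)
b_kn = torch.randn(K, N, dtype=torch.float64)
ref = a_mk @ b_kn

a_km = a_mk.T.contiguous()  # A as loaded when not K-major: [BLOCK_K, BLOCK_M]
b_nk = b_kn.T.contiguous()  # B as loaded when K-major:     [BLOCK_N, BLOCK_K]

assert torch.allclose(a_mk @ b_nk.T, ref)    # A K-major,     B K-major
assert torch.allclose(a_mk @ b_kn, ref)      # A K-major,     B not K-major
assert torch.allclose(a_km.T @ b_nk.T, ref)  # A not K-major, B K-major
assert torch.allclose(a_km.T @ b_kn, ref)    # A not K-major, B not K-major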
@@ -335,10 +405,10 @@ def early_config_prune(g, m, configs, named_args):
             + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N
             + offs_k[None, :] * B_STRIDE_K
         )
-        a_mask = (offs_am[:, None] < m_size) & (group_offs_k[None, :] < k_size)
-        b_mask = (offs_bn[:, None] < n_size) & (group_offs_k[None, :] < k_size)
-        a = tl.load(a_ptrs, mask=a_mask, other=0)
-        b = tl.load(b_ptrs, mask=b_mask, other=0)
+        a_mask = (offs_am[:, None] < m_size) & (block_offs_k[None, :] < k_size)
+        b_mask = (offs_bn[:, None] < n_size) & (block_offs_k[None, :] < k_size)
+        a = tl.load(a_ptrs, mask=a_mask, other=tl.zeros((), dtype=a_ptrs.dtype.element_ty))
+        b = tl.load(b_ptrs, mask=b_mask, other=tl.zeros((), dtype=b_ptrs.dtype.element_ty))
 {%- if USE_FAST_ACCUM %}
         accumulator = tl.dot(a, b.T, accumulator)
 {%- else %}
@@ -360,7 +430,7 @@ def early_config_prune(g, m, configs, named_args):
 {%- endif %}
         + offs_am[:, None],
         mask=offs_am[:, None] < m_size,
-        other=0,
+        other=tl.zeros((), dtype=scale_a_ptr.dtype.element_ty),
     )
     scale_b = tl.load(
         scale_b_ptr
@@ -371,7 +441,7 @@ def early_config_prune(g, m, configs, named_args):
 {%- endif %}
         + offs_bn[None, :],
         mask=offs_bn[None, :] < n_size,
-        other=0,
+        other=tl.zeros((), dtype=scale_b_ptr.dtype.element_ty),
     )
     c = accumulator.to(tl.float32) * scale_a * scale_b
 {%- else %}
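
The other= changes in the last three hunks replace the Python literal 0 with a scalar zero of the pointer's element type, so the fill value for masked-out lanes always matches the loaded dtype. A minimal standalone Triton sketch of the same pattern (kernel name and shapes are illustrative):

import triton
import triton.language as tl

@triton.jit
def masked_copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # Fill masked-out lanes with a zero of src's element type (e.g. bf16),
    # rather than relying on conversion from the Python int 0.
    x = tl.load(src_ptr + offs, mask=mask,
                other=tl.zeros((), dtype=src_ptr.dtype.element_ty))
    tl.store(dst_ptr + offs, x, mask=mask)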
@@ -648,6 +718,9 @@ def _tuned_grouped_mm_common(
     V.graph.sizevars.check_equals(k1, k2)
     a_is_2d, b_is_2d = False, False

+    a_is_k_major = mat_a.get_stride()[-1] == 1
+    b_is_k_major = mat_b.get_stride()[-2] == 1
+
     triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor")
     triton_has_experimental_make_tensor_descriptor = hasattr(
         tl, "_experimental_make_tensor_descriptor"
@@ -656,22 +729,21 @@ def _tuned_grouped_mm_common(
         triton_has_make_tensor_descriptor
         or triton_has_experimental_make_tensor_descriptor
     )
-    # The make_tensor_descriptor imposes this additional limitation.
-    use_tma_load = use_tma_load and (
-        mat_a.get_stride()[-1] == 1 and mat_b.get_stride()[-2] == 1
-    )
-
     kwargs = {
         "SCALED": scaled,
         "A_IS_2D": a_is_2d,
         "B_IS_2D": b_is_2d,
+        "A_IS_K_MAJOR": a_is_k_major,
+        "B_IS_K_MAJOR": b_is_k_major,
         "USE_FAST_ACCUM": use_fast_accum,
         "NUM_SMS": get_num_sms(),
         "USE_TMA_LOAD": use_tma_load,
         "USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR": triton_has_experimental_make_tensor_descriptor,
     }

-    for config in early_config_prune(g, m, grouped_mm_configs(), kwargs):
+    for config in early_config_prune(
+        g, m, mat_a.dtype.itemsize, grouped_mm_configs(), kwargs
+    ):
         kernel_template.maybe_append_choice(
             choices,
             input_nodes=input_nodes,
