 import pytest
 import torch
 from typing import Union
+import triton
 # routing utilities
 from triton_kernels.routing import routing
 # matmul utilities
@@ -243,6 +244,9 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         # Automatic padding not implemented for Hopper swizzle
         pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
 
+    # launch metadata for batched / mx types may not work yet.
+    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)
+
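+    # (When enabled, the launch hook installed further below verifies the byte
+    # counts that the matmul_ogs kernel reports in its launch metadata.)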
     torch.manual_seed(0)
 
     block_k = None
@@ -314,8 +318,48 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
 
     if w_tri.shape[0] == 1:
         # Test the case when weight has dim 2, i.e., shape (K, N).
-        w_tri = w_tri.squeeze(0).detach().requires_grad_()
-        w_ref = w_ref.squeeze(0).detach().requires_grad_()
+        w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
+        w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)
+
+    if test_launch_metadata:
+
+        def _clobber(t, used_mask):
+            # Fill the unread part of the tensor with garbage, to be sure that
+            # we don't actually read from that part.
+            if len(used_mask) == 1:
+                return
+            elif t.element_size() == 1:
+                t.view(torch.int8)[~used_mask] = 127
+            else:
+                t[~used_mask] = torch.inf
+
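+        # Expected traffic for the weight tensor: only experts that actually
+        # receive tokens should have their weights read by the kernel.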
+        if rdata is not None:
+            n_tokens = rdata.expt_hist.sum().item()
+            used_expts = (rdata.expt_hist > 0)
+            _clobber(w_tri, used_expts)
+            n_w_bytes = used_expts.sum().item() * n * k * w_tri.element_size()
+        else:
+            n_tokens = m
+            n_w_bytes = w_tri.numel() * w_tri.element_size()
+
+        if gindx is not None:
+            used_x_rows = (gindx.dst_indx.view(-1, n_expts_act) != -1).any(dim=1)
+            _clobber(x_tri, used_x_rows)
+            n_x_bytes = used_x_rows.sum().item() * k * x_tri.element_size()
+        elif rdata is not None:
+            n_x_bytes = n_tokens * k * x_tri.element_size()
+        else:
+            n_x_bytes = x_tri.numel() * x_tri.element_size()
+
+        nbytes = None
+
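+        # Triton invokes this hook on every kernel launch; record the byte
+        # count that the matmul_ogs kernel reports via its launch metadata.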
+        def _hook(launch_metadata):
+            nonlocal nbytes
+            metadata = launch_metadata.get()
+            if "matmul_ogs" in metadata["name"]:
+                nbytes = metadata["bytes"]
+
+        triton.knobs.runtime.launch_enter_hook = _hook
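+        # (Reset to None after the assertion below so later launches are unaffected.)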
 
     if mode == "batched":
         rdata, gindx, sindx = None, None, None
@@ -327,6 +371,16 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
     y_scale = flex.out_data.expected_scale if act_is_float8 else 1
 
+    if test_launch_metadata:
+        if gindx is not None:
+            n_y_bytes = (gindx.src_indx != -1).sum().item() * n * tri_y.element_size()
+        elif rdata is not None:
+            n_y_bytes = n_tokens * n * tri_y.element_size()
+        else:
+            n_y_bytes = tri_y.numel() * tri_y.element_size()
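+        # The kernel-reported byte count must match the expected x/w/y traffic exactly.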
+        assert nbytes == n_x_bytes + n_y_bytes + n_w_bytes
+        triton.knobs.runtime.launch_enter_hook = None
+
     def round_x(x, idx):
         return x.to(act_dtype).to(torch.float32) if sep_gather else x
 