@@ -6,6 +6,7 @@
 import torch
 import triton
 from enum import Enum, auto
+import math
 # utilities
 from triton_kernels import target_info
 from triton_kernels.numerics import InFlexData, OutFlexData
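A note on the new `math` import: it is used later in this diff, where `math.prod` collapses the leading output dimensions into a single row count before the output tensor is wrapped. A minimal illustration (shapes are hypothetical):

```python
import math

# math.prod multiplies an iterable of ints; here it flattens the leading
# dims of a hypothetical (slices, batch, M, N) output.
shape = (1, 2, 16, 64)
rows_2d = math.prod(shape[:-1])   # 32 rows for the 2D (scatter) view
rows_3d = math.prod(shape[:-2])   # 2 leading rows for the 3D (dense) view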
@@ -427,6 +428,7 @@ def matmul_ogs(x, w, bias,
     if not isinstance(x, Tensor):
         x = Tensor(x, dtype=x.dtype)
     # determine shapes
+    is_ragged = routing_data.expt_hist is not None
     M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
     batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
     K, N = w.shape[-2:]
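`is_ragged` flags the mixture-of-experts path: when `routing_data.expt_hist` (a per-expert row histogram) is present, the M dimension is partitioned into variable-length per-expert segments rather than a uniform batch. A sketch of what such a histogram encodes, with made-up numbers:

```python
import torch

# Hypothetical histogram: expert 0 owns 3 rows, expert 1 none, expert 2 owns 5.
expt_hist = torch.tensor([3, 0, 5])
is_ragged = expt_hist is not None           # mirrors the check in the diff
segment_ends = torch.cumsum(expt_hist, 0)   # tensor([3, 3, 8]): ragged row boundaries
```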
@@ -457,6 +459,11 @@ def matmul_ogs(x, w, bias,
         opt_flags, preprocessing_features, postprocessing_features
     )
     memory = apply_allocation(allocation, y)
+    if batch_size * M * N == 0:
+        ret = memory["output"].squeeze(0)
+        if not is_input_batched:
+            ret = ret.squeeze(0)
+        return ret
     # TMA descriptors require a global memory allocation
     if opt_flags.is_persistent:
         triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
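The new early return handles degenerate problems (zero rows, columns, or batches): the pre-allocated output is returned as-is, skipping the TMA allocator setup and the kernel launch that follow. The two `squeeze(0)` calls undo the leading dimensions of the allocation; roughly, assuming a (slices, batch, M, N) allocation layout:

```python
import torch

# Assumed layout of memory["output"]: (n_slices, batch, M, N).
out = torch.empty(1, 1, 0, 64)   # M == 0: nothing to compute
ret = out.squeeze(0)             # drop the slice dim -> (1, 0, 64)
ret = ret.squeeze(0)             # unbatched input: drop batch too -> (0, 64)
```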
@@ -505,19 +512,31 @@ def matmul_ogs(x, w, bias,
     grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
     # canonicalize storage
     has_gather = gather_indx is not None
-    x_storage = _canonicalize_storage(x.storage, 2 if has_gather else 3, flex.lhs_data)
+    has_scatter = writeback_idxs is not None
+    has_gather_tma = has_gather and target_info.has_tma_gather()
+    has_scatter_tma = has_scatter and target_info.has_tma_gather()
+    y = wrap_torch_tensor(out0.view(math.prod(out0.shape[:-1]), out0.shape[-1]) if has_scatter else out0.view(math.prod(out0.shape[:-2]), *out0.shape[-2:]))
+    x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
     w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
+    y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
     # create tma descriptor for x
-    x_has_tma = ((not has_gather) or (has_gather and target_info.has_tma_gather())) and opt_flags.is_persistent
-    x_block_tma = ([1] if has_gather else [1, opt_flags.block_m]) + [opt_flags.block_k]
-    x_tensor_or_tma = x_storage.make_tma(x_block_tma) if x_has_tma else x_storage.data
+    x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
+    x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k]
+    x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense"
+    x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data
+    # create tma descriptor for y
+    y_has_tma = opt_flags.is_persistent and (has_scatter_tma or not has_scatter)
+    block_n = opt_flags.block_n // opt_flags.epilogue_subtile // fused_activation.reduction_n
+    y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n]
+    y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense"
+    y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data
     # create tma descriptor for w
     w_has_tma = opt_flags.is_persistent
-    w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n]) if w_has_tma else w_storage.data
+    w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
     # create tma descriptor for w_scale
     w_scale_tensor_or_tma = w_scale
     w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
-    w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k]) if w_scale_has_tma else w_scale
+    w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale
     # canonicalize strides
     x_strides = [0] * (3 - x_storage.data.ndim) + list(x_storage.data.stride())
     x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
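The hunk above replaces the single boolean TMA gate with a per-operand mode that the kernel receives as a constexpr (see the last hunk). The selection logic is identical for `x` and `y`; condensed into one pure function (a paraphrase of the diff, not library API):

```python
def tma_mode(is_persistent, has_indexed_access, has_indexed_tma, is_ragged):
    """has_indexed_access: gather (for x) or scatter (for y);
    has_indexed_tma: the target supports TMA gather/scatter."""
    has_tma = is_persistent and (has_indexed_tma or not has_indexed_access)
    if not has_tma:
        return None       # plain pointer path, no descriptor
    if is_ragged and not has_indexed_tma:
        return "ragged"   # descriptor addresses variable per-expert segments
    return "dense"        # regular blocked descriptor
```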
@@ -529,14 +548,13 @@ def matmul_ogs(x, w, bias,
     # launch kernel
     kernels = get_kernels(epilogue.specs, fused_activation.specs)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
-        flex.out_data.reinterpret(memory["output"]),
-        flex.out_data.reinterpret(out0), *out0.stride(),
+        y_tensor_or_tma, y_storage.data, *out0.stride(),
         *((None, out_scale, None) if out_has_mx else out0_flex),
         *out_scale_strides[-3:],
         x_tensor_or_tma, x_storage.data, *x_strides,
         flex.lhs_data.scale,
         None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
-        w_tensor_or_tma, *w_storage.data.stride(), w_storage.data.stride()[-1] != 1,
+        w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_storage.data.stride()[-1] != 1,
         flex.rhs_data.scale,
         w_scale_tensor_or_tma, *w_scale_strides,
         bias, bias_stride,
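Each TMA-capable operand now passes both the descriptor-or-tensor and the raw backing tensor (`y_storage.data`, `w_storage.data` is new here), so the kernel keeps access to strides and dtype even when loads go through a descriptor. The boolean after the `w` strides marks a transposed weight layout; a runnable illustration of that check:

```python
import torch

# The flag passed after the strides: True when the innermost dim of w is
# not contiguous, i.e. w is stored transposed.
w = torch.empty(8, 16, 32).transpose(-1, -2)   # shape (8, 32, 16), strides (512, 1, 32)
w_is_transposed = w.stride()[-1] != 1          # True
```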
@@ -574,7 +592,8 @@ def matmul_ogs(x, w, bias,
         num_stages=opt_flags.num_stages,
         arch=opt_flags.arch,
         UPCAST_INDICES=should_upcast_indices(x, w, out0),
-        DISABLE_Y_TMA=out0.stride(-2) * out0.dtype.itemsize % 16 != 0,
+        X_TMA_MODE=x_tma_mode,
+        Y_TMA_MODE=y_tma_mode,
         SWAP_XW=preprocessing_features.swap_xw,
         IS_EPILOGUE_DEQUANT_MXFP8=epilogue.specs.name == FnName.DEQUANTIZE_MXFP8.name,
         NUM_SMS=grid if opt_flags.is_persistent else 0,
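On the kernel side, the single `DISABLE_Y_TMA` boolean becomes two constexpr mode strings, so each specialization is compiled for exactly one load/store strategy per operand. Illustrative of how a constexpr string specializes a Triton kernel (a sketch, not the actual `_p_matmul_ogs` source):

```python
import triton
import triton.language as tl

@triton.jit
def _mode_sketch(out_ptr, MODE: tl.constexpr):
    # Comparisons on a constexpr are resolved at compile time, so each
    # MODE value yields a distinct kernel with the other branches pruned.
    if MODE == "dense":
        val = 1
    elif MODE == "ragged":
        val = 2
    else:
        val = 0
    tl.store(out_ptr, val)
```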