@@ -427,6 +427,7 @@ def matmul_ogs(x, w, bias,
     if not isinstance(x, Tensor):
         x = Tensor(x, dtype=x.dtype)
     # determine shapes
+    is_ragged = routing_data.expt_hist is not None
     M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
     batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
     K, N = w.shape[-2:]
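Note: `is_ragged` records whether routing produced a per-expert histogram, i.e. whether the rows of `x` are grouped into variable-length per-expert segments rather than a fixed dense batch. A minimal sketch of how such a ragged layout follows from the histogram (the `ragged_offsets` helper is hypothetical, not part of this change):

```python
import torch

def ragged_offsets(expt_hist: torch.Tensor) -> torch.Tensor:
    """Per-expert start offsets for a ragged [sum(hist), K] activation tensor.

    expt_hist[i] is the number of tokens routed to expert i; a ragged matmul
    treats rows [offs[i], offs[i + 1]) as expert i's contiguous slab.
    """
    offs = torch.zeros(expt_hist.numel() + 1, dtype=torch.int64)
    offs[1:] = torch.cumsum(expt_hist, dim=0)
    return offs

# e.g. expt_hist = [3, 0, 5] -> offsets [0, 3, 3, 8]; expert 1 owns no rows
print(ragged_offsets(torch.tensor([3, 0, 5])))
```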
@@ -505,19 +506,31 @@ def matmul_ogs(x, w, bias,
     grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
     # canonicalize storage
     has_gather = gather_indx is not None
-    x_storage = _canonicalize_storage(x.storage, 2 if has_gather else 3, flex.lhs_data)
+    has_scatter = writeback_idxs is not None
+    has_gather_tma = has_gather and target_info.has_tma_gather()
+    has_scatter_tma = has_scatter and target_info.has_tma_gather()
+    y = wrap_torch_tensor(out0.view(-1, out0.shape[-1]) if has_scatter else out0.view(-1, *out0.shape[-2:]))
+    x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
     w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
+    y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
     # create tma descriptor for x
-    x_has_tma = ((not has_gather) or (has_gather and target_info.has_tma_gather())) and opt_flags.is_persistent
-    x_block_tma = ([1] if has_gather else [1, opt_flags.block_m]) + [opt_flags.block_k]
-    x_tensor_or_tma = x_storage.make_tma(x_block_tma) if x_has_tma else x_storage.data
+    x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
+    x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k]
+    x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense"
+    x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data
+    # create tma descriptor for y
+    y_has_tma = opt_flags.is_persistent and (has_scatter_tma or not has_scatter)
+    block_n = opt_flags.block_n // opt_flags.epilogue_subtile // fused_activation.reduction_n
+    y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n]
+    y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense"
+    y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data
     # create tma descriptor for w
     w_has_tma = opt_flags.is_persistent
-    w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n]) if w_has_tma else w_storage.data
+    w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
     # create tma descriptor for w_scale
     w_scale_tensor_or_tma = w_scale
     w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
-    w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k]) if w_scale_has_tma else w_scale
+    w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale
     # canonicalize strides
     x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
     x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
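The x and y descriptors now each pick one of three modes: no TMA (raw pointer fallback), "dense" (gather/scatter TMA, or a plain dense batch), or "ragged" (3D blocks over variable-length expert slabs). Note also that the store tile width is `block_n // epilogue_subtile // reduction_n`, e.g. 128 // 2 // 2 = 32. A sketch restating the selection logic above as one function; `tma_mode` and the `has_side*` parameters are illustrative stand-ins for the gather pair on the x side and the scatter pair on the y side:

```python
def tma_mode(is_persistent: bool, has_side_tma: bool, has_side: bool, is_ragged: bool):
    """Restates the x/y descriptor-mode selection above (illustrative only)."""
    has_tma = is_persistent and (has_side_tma or not has_side)
    if not has_tma:
        return None            # no descriptor: kernel gets a raw pointer
    if is_ragged and not has_side_tma:
        return "ragged"        # 3D block over variable-length expert slabs
    return "dense"             # gather/scatter TMA, or a plain dense batch
```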
@@ -529,14 +542,13 @@ def matmul_ogs(x, w, bias,
     # launch kernel
     kernels = get_kernels(epilogue.specs, fused_activation.specs)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
-        flex.out_data.reinterpret(memory["output"]),
-        flex.out_data.reinterpret(out0), *out0.stride(),
+        y_tensor_or_tma, y_storage.data, *out0.stride(),
         *((None, out_scale, None) if out_has_mx else out0_flex),
         *out_scale_strides[-3:],
         x_tensor_or_tma, x_storage.data, *x_strides,
         flex.lhs_data.scale,
         None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
-        w_tensor_or_tma, *w_storage.data.stride(), w_storage.data.stride()[-1] != 1,
+        w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_storage.data.stride()[-1] != 1,
         flex.rhs_data.scale,
         w_scale_tensor_or_tma, *w_scale_strides,
         bias, bias_stride,
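Each operand is now passed to the kernel as a pair: the TMA descriptor (or the raw tensor when no descriptor was built) plus the underlying storage tensor, e.g. `y_tensor_or_tma, y_storage.data` and `w_tensor_or_tma, w_storage.data`. A rough sketch of the host-side convention this establishes; `pick_source` is hypothetical and assumes the kernel falls back to plain pointer arithmetic when no mode is set:

```python
def pick_source(tensor_or_tma, raw_data, mode):
    """Illustrates the (descriptor, raw tensor) argument pair (assumption)."""
    if mode is None:
        return raw_data        # non-TMA path: ordinary loads/stores
    return tensor_or_tma       # "dense"/"ragged": use the TMA descriptor
```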
@@ -574,7 +586,8 @@ def matmul_ogs(x, w, bias,
         num_stages=opt_flags.num_stages,
         arch=opt_flags.arch,
         UPCAST_INDICES=should_upcast_indices(x, w, out0),
-        DISABLE_Y_TMA=out0.stride(-2) * out0.dtype.itemsize % 16 != 0,
+        X_TMA_MODE=x_tma_mode,
+        Y_TMA_MODE=y_tma_mode,
         SWAP_XW=preprocessing_features.swap_xw,
         IS_EPILOGUE_DEQUANT_MXFP8=epilogue.specs.name == FnName.DEQUANTIZE_MXFP8.name,
         NUM_SMS=grid if opt_flags.is_persistent else 0,
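The removed `DISABLE_Y_TMA` flag is replaced by the host-computed `X_TMA_MODE` / `Y_TMA_MODE` constexprs; presumably any alignment fallback now happens where the descriptor is built. For reference, the old flag tested whether the output row stride met TMA's 16-byte stride alignment requirement; a worked evaluation of that removed expression (shapes illustrative):

```python
import torch

# The removed DISABLE_Y_TMA check: TMA descriptors need the row stride,
# measured in bytes, to be a multiple of 16.
out0 = torch.empty(16, 100, dtype=torch.float16)   # row stride: 100 elements
disable_y_tma = out0.stride(-2) * out0.dtype.itemsize % 16 != 0
print(disable_y_tma)   # True: 100 * 2 = 200 bytes is not a multiple of 16
```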