@@ -584,8 +584,7 @@ def apply_allocation(allocation: MatmulAllocation, output):
584
584
585
585
def _create_tma_descriptors (
586
586
x : torch .Tensor ,
587
- x_tensor : torch .Tensor ,
588
- w_tensor : torch .Tensor ,
587
+ w : torch .Tensor ,
589
588
mx_tensor : Optional [torch .Tensor ],
590
589
routing_data : RoutingData ,
591
590
mx_ctx : MicroscalingCtx ,
@@ -611,21 +610,21 @@ def _create_tma_descriptors(
611
610
if (use_host_tma_descriptors ):
612
611
if USE_GATHER_TMA or X_USE_LOAD_TMA :
613
612
x_desc = TensorDescriptorBuilder .create_input_descriptor (
614
- x_tensor , K , x .stride (1 ), x .stride (2 ),
613
+ x , K , x .stride (1 ), x .stride (2 ),
615
614
opt_flags .block_k , opt_flags .block_m ,
616
615
USE_GATHER_TMA , X_USE_LOAD_TMA
617
616
)
618
617
descriptors .append (x_desc )
619
618
if (expt_data is not None and len (expt_data .block_pid_map ) > 0 ):
620
619
w_desc = TensorDescriptorBuilder .create_weight_descriptor (
621
- w_tensor , opt_flags .block_k , opt_flags .block_n , w_transpose
620
+ w , opt_flags .block_k , opt_flags .block_n , w_transpose
622
621
)
623
- is_microscaled_format = (mx_ctx .weight_scale is not None ) and (w_tensor .dtype == torch .uint8 )
622
+ is_microscaled_format = (mx_ctx .weight_scale is not None ) and (w .dtype == torch .uint8 )
624
623
if is_microscaled_format :
625
624
# Pad the inner shape to 128 for mxfp4 weights
626
625
# for mixed precision fp8 x mxfp4 compute
627
626
pad = 128
628
- dim_to_pad = - 1 if w_transpose else - 2
627
+ dim_to_pad = - 1
629
628
old_size = w_desc .shape [dim_to_pad ]
630
629
padded_size = math .ceil (old_size / pad ) * pad
631
630
if padded_size != old_size :
@@ -645,7 +644,7 @@ def _create_tma_descriptors(
645
644
# TODO: Currently all or none, instead should support a mixture
646
645
# of host and device descriptors
647
646
if None in descriptors or len (descriptors ) == 0 :
648
- descriptors = [x_tensor , w_tensor , mx_tensor ]
647
+ descriptors = [x , w , mx_tensor ]
649
648
use_host_tma_descriptors = False
650
649
if opt_flags .is_persistent :
651
650
opt_flags .target_kernel_kwargs ["USE_HOST_TMA_DESCRIPTORS" ] = use_host_tma_descriptors
@@ -759,9 +758,7 @@ def matmul_ogs(x, w, bias,
759
758
USE_GATHER_TMA = HAS_TMA_GS and gather_indx is not None
760
759
X_USE_LOAD_TMA = gather_indx is None and not USE_GATHER_TMA
761
760
_ , x_tensor , w_tensor , mx_tensor = _create_tma_descriptors (
762
- x = x ,
763
- x_tensor = flex .lhs_data .reinterpret (x ),
764
- w_tensor = flex .rhs_data .reinterpret (w ),
761
+ x = x , w = w ,
765
762
mx_tensor = mx_ctx .weight_scale ,
766
763
routing_data = routing_data ,
767
764
mx_ctx = mx_ctx ,
@@ -777,7 +774,10 @@ def matmul_ogs(x, w, bias,
777
774
w_transpose = w .stride (2 ) != 1 ,
778
775
mx_transpose = mx_scale_stride_n != 1 ,
779
776
)
780
-
777
+ if isinstance (x_tensor , torch .Tensor ):
778
+ x_tensor = flex .lhs_data .reinterpret (x )
779
+ if isinstance (w_tensor , torch .Tensor ):
780
+ w_tensor = flex .rhs_data .reinterpret (w )
781
781
(kernels ._p_matmul_ogs if opt_flags .is_persistent else kernels ._matmul_ogs )[(n_cta ,)](
782
782
flex .out_data .reinterpret (memory ["output" ]),
783
783
flex .out_data .reinterpret (out0 ), * out0 .stride (),
0 commit comments