@@ -584,8 +584,7 @@ def apply_allocation(allocation: MatmulAllocation, output):
 
 def _create_tma_descriptors(
     x: torch.Tensor,
-    x_tensor: torch.Tensor,
-    w_tensor: torch.Tensor,
+    w: torch.Tensor,
     mx_tensor: Optional[torch.Tensor],
     routing_data: RoutingData,
     mx_ctx: MicroscalingCtx,
@@ -611,21 +610,21 @@ def _create_tma_descriptors(
     if (use_host_tma_descriptors):
         if USE_GATHER_TMA or X_USE_LOAD_TMA:
             x_desc = TensorDescriptorBuilder.create_input_descriptor(
-                x_tensor, K, x.stride(1), x.stride(2),
+                x, K, x.stride(1), x.stride(2),
                 opt_flags.block_k, opt_flags.block_m,
                 USE_GATHER_TMA, X_USE_LOAD_TMA
             )
             descriptors.append(x_desc)
         if (expt_data is not None and len(expt_data.block_pid_map) > 0):
             w_desc = TensorDescriptorBuilder.create_weight_descriptor(
-                w_tensor, opt_flags.block_k, opt_flags.block_n, w_transpose
+                w, opt_flags.block_k, opt_flags.block_n, w_transpose
             )
-            is_microscaled_format = (mx_ctx.weight_scale is not None) and (w_tensor.dtype == torch.uint8)
+            is_microscaled_format = (mx_ctx.weight_scale is not None) and (w.dtype == torch.uint8)
             if is_microscaled_format:
                 # Pad the inner shape to 128 for mxfp4 weights
                 # for mixed precision fp8 x mxfp4 compute
                 pad = 128
-                dim_to_pad = -1 if w_transpose else -2
+                dim_to_pad = -1
                 old_size = w_desc.shape[dim_to_pad]
                 padded_size = math.ceil(old_size / pad) * pad
                 if padded_size != old_size:
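
Note: with dim_to_pad fixed at -1, the padding above always rounds the descriptor's innermost dimension up to the next multiple of 128, whether or not the weight is transposed. A minimal standalone sketch of that arithmetic (hypothetical sizes, not part of the patch):

    import math

    pad = 128                                 # required multiple for mxfp4 inner dim
    old_size = 200                            # hypothetical inner-dim size
    padded_size = math.ceil(old_size / pad) * pad
    assert padded_size == 256                 # rounded up to the next multiple of 128
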
@@ -645,7 +644,7 @@ def _create_tma_descriptors(
     # TODO: Currently all or none, instead should support a mixture
     # of host and device descriptors
     if None in descriptors or len(descriptors) == 0:
-        descriptors = [x_tensor, w_tensor, mx_tensor]
+        descriptors = [x, w, mx_tensor]
         use_host_tma_descriptors = False
     if opt_flags.is_persistent:
         opt_flags.target_kernel_kwargs["USE_HOST_TMA_DESCRIPTORS"] = use_host_tma_descriptors
@@ -759,9 +758,7 @@ def matmul_ogs(x, w, bias,
     USE_GATHER_TMA = HAS_TMA_GS and gather_indx is not None
     X_USE_LOAD_TMA = gather_indx is None and not USE_GATHER_TMA
     _, x_tensor, w_tensor, mx_tensor = _create_tma_descriptors(
-        x=x,
-        x_tensor=flex.lhs_data.reinterpret(x),
-        w_tensor=flex.rhs_data.reinterpret(w),
+        x=x, w=w,
         mx_tensor=mx_ctx.weight_scale,
         routing_data=routing_data,
         mx_ctx=mx_ctx,
@@ -777,7 +774,10 @@ def matmul_ogs(x, w, bias,
         w_transpose=w.stride(2) != 1,
         mx_transpose=mx_scale_stride_n != 1,
     )
-
+    if isinstance(x_tensor, torch.Tensor):
+        x_tensor = flex.lhs_data.reinterpret(x)
+    if isinstance(w_tensor, torch.Tensor):
+        w_tensor = flex.rhs_data.reinterpret(w)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
         flex.out_data.reinterpret(memory["output"]),
         flex.out_data.reinterpret(out0), *out0.stride(),
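
Note: the last two hunks move the flex reinterpretation out of _create_tma_descriptors and into the caller. The helper now receives the raw x and w, and when it falls back to returning plain tensors (the all-or-none path above), matmul_ogs reinterprets them itself; host TMA descriptors pass through untouched. A minimal standalone sketch of that pattern, with HostTmaDescriptor and create_descriptors as hypothetical stand-ins:

    import torch

    class HostTmaDescriptor:                  # stand-in for the real descriptor type
        def __init__(self, t: torch.Tensor):
            self.t = t

    def create_descriptors(x, w, use_host_tma):
        """Return host TMA descriptors when possible, else the raw tensors."""
        if use_host_tma:
            return HostTmaDescriptor(x), HostTmaDescriptor(w)
        return x, w                           # all-or-none fallback

    x_t, w_t = create_descriptors(torch.randn(4, 8), torch.randn(8, 16), use_host_tma=False)

    # Only raw tensors still need the reinterpret step (identity here);
    # descriptor objects are skipped by the isinstance check.
    if isinstance(x_t, torch.Tensor):
        x_t = x_t.view(x_t.shape)             # placeholder for flex reinterpret
    if isinstance(w_t, torch.Tensor):
        w_t = w_t.view(w_t.shape)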