@@ -716,6 +716,8 @@ def matmul_ogs(x, w, bias,
     expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
     expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]

+    x = flex.lhs_data.reinterpret(x)
+    w = flex.rhs_data.reinterpret(w)
     if opt_flags.is_persistent:
         x_tensor, w_tensor_and_transpose, mx_tensor_and_tranpose = _create_tma_descriptors(
             x=x, w=w, mx_tensor=mx_ctx.weight_scale,
@@ -736,10 +738,6 @@ def matmul_ogs(x, w, bias,
         x_tensor = x
         w_tensor, w_tma_transpose = w, False
         mx_tensor, mx_tma_transpose = mx_ctx.weight_scale, False
-    if isinstance(x_tensor, torch.Tensor):
-        x_tensor = flex.lhs_data.reinterpret(x)
-    if isinstance(w_tensor, torch.Tensor):
-        w_tensor = flex.rhs_data.reinterpret(w)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
         flex.out_data.reinterpret(memory["output"]),
         flex.out_data.reinterpret(out0), *out0.stride(), *out0_flex,
0 commit comments