Skip to content

Commit 39b8ead

Browse files
authored
[kernels] moved reinterpret to before tma creation (#7205)
1 parent a54f309 commit 39b8ead

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -716,6 +716,8 @@ def matmul_ogs(x, w, bias,
716716
expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
717717
expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
718718

719+
x = flex.lhs_data.reinterpret(x)
720+
w = flex.rhs_data.reinterpret(w)
719721
if opt_flags.is_persistent:
720722
x_tensor, w_tensor_and_transpose, mx_tensor_and_tranpose = _create_tma_descriptors(
721723
x=x, w=w, mx_tensor=mx_ctx.weight_scale,
@@ -736,10 +738,6 @@ def matmul_ogs(x, w, bias,
736738
x_tensor = x
737739
w_tensor, w_tma_transpose = w, False
738740
mx_tensor, mx_tma_transpose = mx_ctx.weight_scale, False
739-
if isinstance(x_tensor, torch.Tensor):
740-
x_tensor = flex.lhs_data.reinterpret(x)
741-
if isinstance(w_tensor, torch.Tensor):
742-
w_tensor = flex.rhs_data.reinterpret(w)
743741
(kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
744742
flex.out_data.reinterpret(memory["output"]),
745743
flex.out_data.reinterpret(out0), *out0.stride(), *out0_flex,

0 commit comments

Comments
 (0)