Skip to content

Commit 5f3ab63

Browse files
authored
mx: make CUDA kernel for dim1 cast in mxfp8_cublas recipe (#2661)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent c993d64 commit 5f3ab63

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torchao/prototype/mx_formats/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,10 @@ def from_recipe_name(
184184
if recipe_name is MXLinearRecipeName.MXFP8_EMULATED:
185185
return MXLinearConfig()
186186
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
187-
# TODO(future PR): default to CUDA dim1 kernel
188-
return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
187+
return MXLinearConfig(
188+
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
189+
mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
190+
)
189191
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL:
190192
return MXLinearConfig(
191193
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,

0 commit comments

Comments
 (0)