We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c993d64 commit 5f3ab63 Copy full SHA for 5f3ab63
torchao/prototype/mx_formats/config.py
@@ -184,8 +184,10 @@ def from_recipe_name(
184
if recipe_name is MXLinearRecipeName.MXFP8_EMULATED:
185
return MXLinearConfig()
186
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
187
- # TODO(future PR): default to CUDA dim1 kernel
188
- return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
+ return MXLinearConfig(
+ gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
189
+ mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
190
+ )
191
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL:
192
return MXLinearConfig(
193
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
0 commit comments