Commit 5d5146e

[CI/Build] Conditionally register cutlass_fp4_group_mm to fix building on Hopper (vllm-project#26138)
Signed-off-by: mgoin <[email protected]>
1 parent 2aaa423 commit 5d5146e

2 files changed, 7 additions and 1 deletion
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#include "core/registration.h"
+
 #include <torch/all.h>
 #include <cutlass/arch/arch.h>
 
@@ -418,3 +420,7 @@ void cutlass_fp4_group_mm(
       "12.8 or above.");
 #endif
 }
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
+  m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
+}
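
The change above moves the implementation registration into the kernel's own source file, next to the function definition. A plausible reading (an assumption, since the commit does not spell it out) is that a central registration referencing `&cutlass_fp4_group_mm` needs that symbol at link time even when the kernel file is left out of the build, whereas a registration that lives in the kernel file simply disappears along with it. The following is a minimal plain-C++ sketch of that pattern, not vLLM or PyTorch code: `registry`, `RegisterImpl`, `has_impl`, the op names, and the `BUILD_OPTIONAL_KERNEL` macro are all hypothetical stand-ins.

```cpp
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-in for a dispatcher table: op name -> implementation.
using OpFn = std::function<int(int)>;

inline std::map<std::string, OpFn>& registry() {
  // Function-local static, so the table exists before any registrar uses it.
  static std::map<std::string, OpFn> table;
  return table;
}

// Registers an implementation when a static instance is constructed,
// i.e. during static initialization of the translation unit that defines it.
struct RegisterImpl {
  RegisterImpl(const std::string& name, OpFn fn) {
    registry()[name] = std::move(fn);
  }
};

// An op that is always part of the build.
int always_built_op(int x) { return x + 1; }
static RegisterImpl reg_always("always_built_op", always_built_op);

// An op that is only built for some targets. In a real project this would be
// a separate source file that the build system includes conditionally; the
// macro here only imitates that. Because the registration sits next to the
// definition, leaving the file out removes both together and nothing else
// refers to the missing symbol.
#ifdef BUILD_OPTIONAL_KERNEL
int optional_kernel(int x) { return x * 2; }
static RegisterImpl reg_optional("optional_kernel", optional_kernel);
#endif

// Callers can check for an implementation at run time instead of relying on
// the symbol being present at link time.
bool has_impl(const std::string& name) {
  return registry().count(name) != 0;
}
```

Compiled without `-DBUILD_OPTIONAL_KERNEL`, the program still links, and `has_impl("optional_kernel")` returns false; compiled with it, the op registers itself without any change to the central code.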

csrc/torch_bindings.cpp

Lines changed: 1 addition & 1 deletion
@@ -397,7 +397,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
       " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()",
       {stride_tag});
-  ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
+  // conditionally compiled so impl registration is in source file
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias