Commit 89e4050

[Bug] Fix Weight Loading for Block FP8 Cutlass SM90 (vllm-project#25909)
Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: Wentao Ye <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 78a47f8 commit 89e4050

File tree

1 file changed (+4, -4 lines)

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 4 additions & 4 deletions
@@ -911,15 +911,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
     # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
     # requantize the weight and input to the specific scale
     # at the same time.
-    if is_deep_gemm_e8m0_used():
+    should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
+        layer.orig_dtype, layer.weight)
+    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
         block_sz = tuple(layer.weight_block_size)
         requant_weight_ue8m0_inplace(layer.weight.data,
                                      layer.weight_scale.data, block_sz)
     # SM90 Block FP8 CUTLASS requires row-major weight scales
     elif (current_platform.is_device_capability(90)
-          and cutlass_block_fp8_supported
-          and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
-                                                     layer.weight)):
+          and cutlass_block_fp8_supported and not should_use_deepgemm):
         layer.weight_scale = torch.nn.Parameter(
             layer.weight_scale.data.T.contiguous(), requires_grad=False)

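In this hunk, the DeepGEMM decision is computed once as `should_use_deepgemm` from `layer.orig_dtype` instead of a hard-coded `torch.bfloat16`, and the same flag gates both the UE8M0 requantization branch and the SM90 CUTLASS branch. Below is a minimal, self-contained sketch (plain PyTorch, not vLLM code) of what the SM90 CUTLASS branch does to the block-FP8 weight scales; the tensor shapes and the 128x128 block size are hypothetical.

import torch

# Hypothetical example: a (512, 384) weight quantized in 128x128 blocks has a
# 4x3 grid of per-block scales. Assume the scales were loaded transposed
# (block-column-major) relative to what the SM90 CUTLASS kernel expects.
n_blocks, k_blocks = 512 // 128, 384 // 128
weight_scale = torch.nn.Parameter(torch.rand(k_blocks, n_blocks),
                                  requires_grad=False)

# What the elif branch above does: re-layout the scales row-major
# (transpose + contiguous) and re-wrap them as a non-trainable Parameter.
weight_scale = torch.nn.Parameter(weight_scale.data.T.contiguous(),
                                  requires_grad=False)

assert weight_scale.shape == (n_blocks, k_blocks)
assert weight_scale.is_contiguous()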