1 file changed, +4 −4 lines
vllm/model_executor/layers/quantization/utils

@@ -911,15 +911,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
     # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
     # requantize the weight and input to the specific scale
     # at the same time.
-    if is_deep_gemm_e8m0_used():
+    should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
+        layer.orig_dtype, layer.weight)
+    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
         block_sz = tuple(layer.weight_block_size)
         requant_weight_ue8m0_inplace(layer.weight.data,
                                      layer.weight_scale.data, block_sz)
     # SM90 Block FP8 CUTLASS requires row-major weight scales
     elif (current_platform.is_device_capability(90)
-          and cutlass_block_fp8_supported
-          and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
-                                                     layer.weight)):
+          and cutlass_block_fp8_supported and not should_use_deepgemm):
         layer.weight_scale = torch.nn.Parameter(
             layer.weight_scale.data.T.contiguous(), requires_grad=False)
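For readers skimming the diff: the snippet below restates the patched control flow as standalone Python. The helper names are taken from the diff, but their bodies here are simplified stand-ins of my own (assumptions), not vLLM's real implementations; the snippet only illustrates how the single should_use_deepgemm decision now gates both the DeepGEMM E8M0 requantization branch and the SM90 CUTLASS scale-transpose branch.

# Illustrative sketch only. is_deep_gemm_e8m0_used and
# should_use_deepgemm_for_fp8_linear are stubbed with made-up logic so this
# runs outside vLLM; the real helpers live in vLLM's quantization utils.
import types

import torch


def is_deep_gemm_e8m0_used() -> bool:
    # Stand-in: pretend the DeepGEMM E8M0 scale format is enabled.
    return True


def should_use_deepgemm_for_fp8_linear(orig_dtype: torch.dtype,
                                       weight: torch.Tensor) -> bool:
    # Stand-in: the real check also considers platform and shape constraints.
    return orig_dtype == torch.bfloat16 and weight.dim() == 2


def choose_post_processing(layer, cutlass_block_fp8_supported: bool,
                           is_sm90: bool) -> str:
    # Decide once whether DeepGEMM will actually be used for this layer and
    # reuse that decision in both branches, mirroring the patched function.
    should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
        layer.orig_dtype, layer.weight)
    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
        return "requantize weight/scales to UE8M0 for DeepGEMM"
    elif (is_sm90 and cutlass_block_fp8_supported
          and not should_use_deepgemm):
        return "transpose weight scales to row-major for SM90 CUTLASS"
    return "leave weight scales unchanged"


if __name__ == "__main__":
    layer = types.SimpleNamespace(orig_dtype=torch.bfloat16,
                                  weight=torch.zeros(128, 128))
    print(choose_post_processing(layer, cutlass_block_fp8_supported=True,
                                 is_sm90=True))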