
Commit b238e41

fix: Fix fp8 after vllm v0.11.2 bump (#1660)

Signed-off-by: Guyue Huang <[email protected]>
Parent: fab6234

File tree: nemo_rl/models/generation (1 file changed, +29 -49 lines)

nemo_rl/models/generation/fp8.py (29 additions, 49 deletions)
--- a/nemo_rl/models/generation/fp8.py
+++ b/nemo_rl/models/generation/fp8.py
@@ -80,11 +80,11 @@ def my_run_engine_core(*args, **kwargs):
 
 def monkey_patch_vllm_ray_executor(fp8_config):
     if fp8_config.model_parallel_size > 1:
-        # we patch vllm's _run_workers so that before vllm initalizes the model on each rank, we execute
+        # we patch vllm's collective_rpc so that before vllm initalizes the model on each rank, we execute
         # a ray remote that patches each worker with the required fp8 vllm patches
         from vllm.v1.executor.ray_distributed_executor import RayDistributedExecutor
 
-        original_run_workers = RayDistributedExecutor._run_workers
+        original_run_workers = RayDistributedExecutor.collective_rpc
 
         def patched_run_workers(self, *args, **kwargs):
            global fp8_patches_applied
@@ -98,7 +98,7 @@ def patched_run_workers(self, *args, **kwargs):
 
             return original_run_workers(self, *args, **kwargs)
 
-        RayDistributedExecutor._run_workers = patched_run_workers
+        RayDistributedExecutor.collective_rpc = patched_run_workers
     else:
         # for single gpu there is no ray, so just call the patches
         apply_fp8_patches(None, fp8_config)
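
Note: the two hunks above re-point the wrapper from the removed `_run_workers` method to `collective_rpc`, which `RayDistributedExecutor` still exposes after the vLLM v0.11.2 bump. Below is a minimal, self-contained sketch of the same wrap-once pattern; `Executor`, `install_wrapper`, `apply_patches`, and `_patches_applied` are stand-in names, not actual vLLM or nemo_rl identifiers.

```python
# Minimal sketch of the wrap-and-patch-once pattern used above.
# Executor, install_wrapper, apply_patches, and _patches_applied are
# stand-in names, not real vLLM/nemo_rl identifiers.
_patches_applied = False


class Executor:
    def collective_rpc(self, method, *args, **kwargs):
        print(f"dispatching {method!r} to workers")


def install_wrapper(executor_cls, apply_patches):
    original = executor_cls.collective_rpc

    def wrapped(self, *args, **kwargs):
        global _patches_applied
        # Apply the FP8 patches exactly once, before the first RPC that
        # would initialize the model on each rank.
        if not _patches_applied:
            apply_patches(self)
            _patches_applied = True
        return original(self, *args, **kwargs)

    executor_cls.collective_rpc = wrapped


install_wrapper(Executor, apply_patches=lambda executor: print("applying fp8 patches"))
Executor().collective_rpc("load_model")
```

The key point is that the original bound function is captured before reassignment, so the wrapper can delegate to it after the FP8 patches have run exactly once.
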
@@ -225,8 +225,6 @@ def apply_fp8_patches(self, fp8_config):
     patcher4 = patch(func4_path, _per_token_group_quant_fp8_colmajor)
     fp8_state.vllm_patches.append(patcher2, patcher3, patcher4)
 
-    # Apply KV cache patches only when using FP8 KV cache (kv_cache_dtype=fp8)
-    if global_fp8_config.kv_cache_dtype == "fp8":
     # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates
     func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading"
     patcher5 = patch(func5_path, process_weights_after_loading_kv)
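
The `patch(func_path, replacement)` calls above register patcher objects keyed by a dotted import path, and this hunk drops the `kv_cache_dtype == "fp8"` gate so the KV-cache patch is always registered. Assuming the helper behaves like `unittest.mock.patch` (its definition is outside this diff, so that is an assumption), the lifecycle looks roughly like the sketch below; the target is a local dummy class so the example runs without vLLM installed.

```python
# Sketch of how patcher objects like `patcher5 = patch(func5_path, ...)`
# can work, assuming a unittest.mock.patch-style helper. The patched
# class here is a local dummy, not the real vLLM BaseKVCacheMethod.
from unittest.mock import patch


class BaseKVCacheMethod:
    def process_weights_after_loading(self, layer):
        return "original"


def process_weights_after_loading_kv(self, layer):
    # FP8 replacement: keep k_scale/v_scale so they can be updated manually.
    return "patched"


vllm_patches = []

patcher = patch(
    f"{__name__}.BaseKVCacheMethod.process_weights_after_loading",
    process_weights_after_loading_kv,
)
vllm_patches.append(patcher)

# .start() swaps in the replacement until .stop() restores the original.
for p in vllm_patches:
    p.start()

assert BaseKVCacheMethod().process_weights_after_loading(layer=None) == "patched"

for p in vllm_patches:
    p.stop()
```
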
@@ -593,7 +591,7 @@ def process_weights_after_loading(self, layer) -> None:
     layer.weight.data = weight.data
     if hasattr(layer, "weight_scale"):
         # Not the first time to call this function, just need to update the data
-        layer.weight_scale.data = weight_scale.data
+        layer.weight_scale.copy_(weight_scale.data)
     else:
         # The first time to call this function, create a new parameter and update the tp status
         layer.weight_scale = torch.nn.Parameter(weight_scale.data, requires_grad=False)
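
The one-line change above swaps rebinding `layer.weight_scale.data` for an in-place `copy_`. A plausible reason, consistent with the refit note in the next hunk's docstring, is that `copy_` writes into the parameter's existing storage, so anything that already holds a reference to that buffer keeps seeing the updated values, while `.data` rebinding silently points the parameter at a new tensor. A generic torch sketch (not nemo_rl code; `weight_loader` is an illustrative attribute):

```python
# Generic torch sketch contrasting `.data` rebinding with in-place copy_.
import torch

scale = torch.nn.Parameter(torch.ones(4), requires_grad=False)
scale.weight_loader = "custom loader"  # extra attribute to preserve
alias = scale.data                     # another reference to the same storage

new_values = torch.full((4,), 0.5)

# In-place copy: same Parameter, same storage, aliases observe the update.
scale.copy_(new_values)
assert alias[0].item() == 0.5
assert scale.weight_loader == "custom loader"

# Rebinding .data: same Parameter object, but it now points at new storage,
# so previously captured references keep the stale values.
scale.data = torch.full((4,), 2.0)
assert alias[0].item() == 0.5          # alias still sees the old buffer
```
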
@@ -609,68 +607,50 @@ def process_weights_after_loading_moe(self, layer) -> None:
     new torch.nn.Parameter objects, because that removes the weight_loader attribute which we need for refit.
     """
     # Lazy import to avoid importing triton too early.
-    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-        is_rocm_aiter_moe_enabled,
-    )
+    from vllm._aiter_ops import rocm_aiter_ops
     from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
         swap_w13_to_w31,
     )
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-        expert_weight_is_col_major,
-        requant_weight_ue8m0_inplace,
+        deepgemm_post_process_fp8_weight_block,
     )
     from vllm.utils.deep_gemm import (
-        get_col_major_tma_aligned_tensor,
         is_deep_gemm_e8m0_used,
     )
 
-    self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+    self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
 
     assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
     assert self.quant_config.activation_scheme == "dynamic"
 
     if self.flashinfer_moe_backend is not None:
-        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
-        layer.w13_weight_scale_inv.data = swap_w13_to_w31(
-            layer.w13_weight_scale_inv.data
-        )
+        w13_weight = swap_w13_to_w31(layer.w13_weight.data)
+        w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
+    else:
+        w13_weight = layer.w13_weight.data
+        w13_weight_scale_inv = layer.w13_weight_scale_inv.data
+    w2_weight = layer.w2_weight.data
+    w2_weight_scale_inv = layer.w2_weight_scale_inv.data
 
     # DeepGemm scales need to be transposed and aligned. We try to do
     # it ahead of time for performance reasons.
-    if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
-        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w13_weight_scale_inv
-            )
-        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w2_weight_scale_inv
-            )
-
-    if is_deep_gemm_e8m0_used():
-        assert layer.weight_block_size is not None
-        # Re-quantise the expert weights so their scales are UE8M0.
-        block_sz = tuple(layer.weight_block_size)
-        requant_weight_ue8m0_inplace(
-            layer.w13_weight.data,
-            layer.w13_weight_scale_inv.data,
-            block_sz,
+    if self.allow_deep_gemm:
+        w13_weight, w13_weight_scale_inv = deepgemm_post_process_fp8_weight_block(
+            wq=w13_weight,
+            ws=w13_weight_scale_inv,
+            quant_block_shape=tuple(layer.weight_block_size),
+            use_e8m0=is_deep_gemm_e8m0_used(),
         )
-        requant_weight_ue8m0_inplace(
-            layer.w2_weight.data,
-            layer.w2_weight_scale_inv.data,
-            block_sz,
+        w2_weight, w2_weight_scale_inv = deepgemm_post_process_fp8_weight_block(
+            wq=w2_weight,
+            ws=w2_weight_scale_inv,
+            quant_block_shape=tuple(layer.weight_block_size),
+            use_e8m0=is_deep_gemm_e8m0_used(),
         )
-
-    # Ensure column-major TMA alignment expected by DeepGEMM.
-    if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-        layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-            layer.w13_weight_scale_inv
-        )
-    if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-        layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-            layer.w2_weight_scale_inv
-        )
+    layer.w13_weight.copy_(w13_weight)
+    layer.w13_weight_scale_inv.copy_(w13_weight_scale_inv)
+    layer.w2_weight.copy_(w2_weight)
+    layer.w2_weight_scale_inv.copy_(w2_weight_scale_inv)
 
 
 def process_weights_after_loading_kv(self, layer) -> None:
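
The restructured MoE path above builds the post-processed weights and scales in local tensors, delegates DeepGEMM alignment and UE8M0 requantization to vLLM's consolidated `deepgemm_post_process_fp8_weight_block` helper, and then copies the results back into the existing parameters instead of rebinding them, keeping the `weight_loader` attribute needed for refit. A condensed, hypothetical sketch of that flow with a stand-in post-processing step:

```python
# Hypothetical condensed version of the "process out-of-place, then copy_
# back" flow; post_process stands in for vLLM's
# deepgemm_post_process_fp8_weight_block and does no real requantization.
import torch


def post_process(wq: torch.Tensor, ws: torch.Tensor):
    # Stand-in: a real implementation would realign / requantize the
    # block-quantized weight `wq` and its scales `ws`.
    return wq.clone(), ws.clone()


class Layer:
    def __init__(self):
        self.w13_weight = torch.nn.Parameter(torch.randn(8, 4), requires_grad=False)
        self.w13_weight_scale_inv = torch.nn.Parameter(torch.ones(8, 1), requires_grad=False)
        # Attribute that refit relies on; wrapping the result in a new
        # torch.nn.Parameter would silently drop it.
        self.w13_weight.weight_loader = lambda *args, **kwargs: None


layer = Layer()

# 1) Work on plain tensors derived from the parameters.
w13_weight, w13_scale = post_process(
    wq=layer.w13_weight.data, ws=layer.w13_weight_scale_inv.data
)

# 2) Copy the results back into the existing Parameter storage so the
#    Parameter objects (and their weight_loader attributes) survive.
layer.w13_weight.copy_(w13_weight)
layer.w13_weight_scale_inv.copy_(w13_scale)

assert hasattr(layer.w13_weight, "weight_loader")
```

Doing the work out-of-place and finishing with `copy_` keeps the registered Parameter objects stable even when the helper returns freshly allocated tensors.
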
