@@ -80,11 +80,11 @@ def my_run_engine_core(*args, **kwargs):

def monkey_patch_vllm_ray_executor(fp8_config):
    if fp8_config.model_parallel_size > 1:
-        # we patch vllm's _run_workers so that before vllm initializes the model on each rank, we execute
+        # we patch vllm's collective_rpc so that before vllm initializes the model on each rank, we execute
        # a ray remote that patches each worker with the required fp8 vllm patches
        from vllm.v1.executor.ray_distributed_executor import RayDistributedExecutor

-        original_run_workers = RayDistributedExecutor._run_workers
+        original_run_workers = RayDistributedExecutor.collective_rpc

        def patched_run_workers(self, *args, **kwargs):
            global fp8_patches_applied
@@ -98,7 +98,7 @@ def patched_run_workers(self, *args, **kwargs):

            return original_run_workers(self, *args, **kwargs)

-        RayDistributedExecutor._run_workers = patched_run_workers
+        RayDistributedExecutor.collective_rpc = patched_run_workers
    else:
        # for single gpu there is no ray, so just call the patches
        apply_fp8_patches(None, fp8_config)
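For reference, the hunk above only changes the hook point from `_run_workers` to `collective_rpc`; the pattern stays the same: grab the original method, install a wrapper that performs the FP8 setup exactly once before delegating, and use a module-level flag to keep it idempotent. Below is a minimal, standalone sketch of that wrap-once pattern; `Executor` and `install_hook` are illustrative stand-ins, not vLLM APIs.

class Executor:
    """Stand-in for RayDistributedExecutor; only the method name matters."""

    def collective_rpc(self, method, *args, **kwargs):
        return f"rpc:{method}"


_patches_applied = False


def install_hook():
    original = Executor.collective_rpc

    def wrapped(self, *args, **kwargs):
        global _patches_applied
        if not _patches_applied:
            # One-time setup runs before the first RPC reaches the workers,
            # i.e. before the model is initialized on each rank.
            print("applying fp8 patches once")
            _patches_applied = True
        return original(self, *args, **kwargs)

    Executor.collective_rpc = wrapped


install_hook()
print(Executor().collective_rpc("init_device"))  # setup runs, then the RPC
print(Executor().collective_rpc("load_model"))   # setup is skipped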
@@ -225,8 +225,6 @@ def apply_fp8_patches(self, fp8_config):
    patcher4 = patch(func4_path, _per_token_group_quant_fp8_colmajor)
    fp8_state.vllm_patches.extend([patcher2, patcher3, patcher4])

-    # Apply KV cache patches only when using FP8 KV cache (kv_cache_dtype=fp8)
-    if global_fp8_config.kv_cache_dtype == "fp8":
    # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates
    func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading"
    patcher5 = patch(func5_path, process_weights_after_loading_kv)
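The patcher objects collected above appear to come from `unittest.mock.patch` (the targets are dotted import paths), so they can be started now and stopped later to restore the original vLLM functions. A hedged sketch of that lifecycle, using `json.dumps` as an easy-to-patch stand-in target; `active_patches`, `apply_patches`, and `remove_patches` are illustrative names:

from unittest.mock import patch

active_patches = []  # plays the role of fp8_state.vllm_patches


def fake_dumps(obj, **kwargs):
    return "<patched>"


def apply_patches():
    patcher = patch("json.dumps", fake_dumps)
    patcher.start()                  # patch takes effect process-wide
    active_patches.append(patcher)


def remove_patches():
    while active_patches:
        active_patches.pop().stop()  # restore the original attribute


import json

apply_patches()
assert json.dumps({"a": 1}) == "<patched>"
remove_patches()
assert json.dumps({"a": 1}) == '{"a": 1}'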
@@ -593,7 +591,7 @@ def process_weights_after_loading(self, layer) -> None:
    layer.weight.data = weight.data
    if hasattr(layer, "weight_scale"):
        # Not the first time to call this function, just need to update the data
-        layer.weight_scale.data = weight_scale.data
+        layer.weight_scale.copy_(weight_scale.data)
    else:
        # The first time to call this function, create a new parameter and update the tp status
        layer.weight_scale = torch.nn.Parameter(weight_scale.data, requires_grad=False)
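The switch from rebinding `layer.weight_scale.data` to `layer.weight_scale.copy_(...)` keeps the parameter's existing storage and writes the new scale values into it, so any reference captured earlier from that storage (for example one held for refit or inside a CUDA graph) observes the update. A small PyTorch sketch of the difference; the tensors here are illustrative only:

import torch

param = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
alias = param.data  # reference to the original storage, captured earlier

new_vals = torch.arange(4, dtype=torch.float32)

# Rebinding .data swaps the underlying tensor; the captured alias goes stale.
param.data = new_vals.clone()
print(alias)    # tensor([0., 0., 0., 0.])  -- did not see the update

# In-place copy writes into the existing storage; the alias sees the update.
param2 = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
alias2 = param2.data
param2.copy_(new_vals)
print(alias2)   # tensor([0., 1., 2., 3.])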
@@ -609,68 +607,50 @@ def process_weights_after_loading_moe(self, layer) -> None:
    new torch.nn.Parameter objects, because that removes the weight_loader attribute which we need for refit.
    """
    # Lazy import to avoid importing triton too early.
-    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-        is_rocm_aiter_moe_enabled,
-    )
+    from vllm._aiter_ops import rocm_aiter_ops
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        swap_w13_to_w31,
    )
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-        expert_weight_is_col_major,
-        requant_weight_ue8m0_inplace,
+        deepgemm_post_process_fp8_weight_block,
    )
    from vllm.utils.deep_gemm import (
-        get_col_major_tma_aligned_tensor,
        is_deep_gemm_e8m0_used,
    )

-    self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+    self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

    assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
    assert self.quant_config.activation_scheme == "dynamic"

    if self.flashinfer_moe_backend is not None:
-        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
-        layer.w13_weight_scale_inv.data = swap_w13_to_w31(
-            layer.w13_weight_scale_inv.data
-        )
+        w13_weight = swap_w13_to_w31(layer.w13_weight.data)
+        w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
+    else:
+        w13_weight = layer.w13_weight.data
+        w13_weight_scale_inv = layer.w13_weight_scale_inv.data
+    w2_weight = layer.w2_weight.data
+    w2_weight_scale_inv = layer.w2_weight_scale_inv.data

    # DeepGemm scales need to be transposed and aligned. We try to do
    # it ahead of time for performance reasons.
-    if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
-        if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-            layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w13_weight_scale_inv
-            )
-        if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-            layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-                layer.w2_weight_scale_inv
-            )
-
-    if is_deep_gemm_e8m0_used():
-        assert layer.weight_block_size is not None
-        # Re-quantise the expert weights so their scales are UE8M0.
-        block_sz = tuple(layer.weight_block_size)
-        requant_weight_ue8m0_inplace(
-            layer.w13_weight.data,
-            layer.w13_weight_scale_inv.data,
-            block_sz,
+    if self.allow_deep_gemm:
+        w13_weight, w13_weight_scale_inv = deepgemm_post_process_fp8_weight_block(
+            wq=w13_weight,
+            ws=w13_weight_scale_inv,
+            quant_block_shape=tuple(layer.weight_block_size),
+            use_e8m0=is_deep_gemm_e8m0_used(),
        )
-        requant_weight_ue8m0_inplace(
-            layer.w2_weight.data,
-            layer.w2_weight_scale_inv.data,
-            block_sz,
+        w2_weight, w2_weight_scale_inv = deepgemm_post_process_fp8_weight_block(
+            wq=w2_weight,
+            ws=w2_weight_scale_inv,
+            quant_block_shape=tuple(layer.weight_block_size),
+            use_e8m0=is_deep_gemm_e8m0_used(),
        )
-
-    # Ensure column-major TMA alignment expected by DeepGEMM.
-    if expert_weight_is_col_major(layer.w13_weight_scale_inv):
-        layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
-            layer.w13_weight_scale_inv
-        )
-    if expert_weight_is_col_major(layer.w2_weight_scale_inv):
-        layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
-            layer.w2_weight_scale_inv
-        )
+    layer.w13_weight.copy_(w13_weight)
+    layer.w13_weight_scale_inv.copy_(w13_weight_scale_inv)
+    layer.w2_weight.copy_(w2_weight)
+    layer.w2_weight_scale_inv.copy_(w2_weight_scale_inv)


def process_weights_after_loading_kv(self, layer) -> None:
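The rewritten MoE path above computes the post-processed weights and scales into local temporaries (via the `deepgemm_post_process_fp8_weight_block` helper imported in the hunk) and then writes them back with `copy_`, so the original parameter objects and their `weight_loader` attributes survive, as the docstring requires. A standalone sketch of that compute-then-copy structure; `post_process` and the tensor shapes are toy stand-ins, not the vLLM helper:

import torch


def post_process(wq, ws, quant_block_shape, use_e8m0):
    # Toy transformation standing in for requantization / scale alignment.
    return wq * 2, ws / 2


layer_w13 = torch.nn.Parameter(torch.ones(8, 4), requires_grad=False)
layer_w13_scale = torch.nn.Parameter(torch.full((8, 1), 0.5), requires_grad=False)
layer_w13.weight_loader = "kept-for-refit"   # attribute that must survive

# Compute into temporaries...
w13, w13_scale = layer_w13.data, layer_w13_scale.data
w13, w13_scale = post_process(w13, w13_scale, quant_block_shape=(128, 128), use_e8m0=False)

# ...then copy back in place: same Parameter objects, same storage.
layer_w13.copy_(w13)
layer_w13_scale.copy_(w13_scale)
assert layer_w13.weight_loader == "kept-for-refit"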