
Commit 8b03c85

Authored by yiliu30, ranzhejiang, and Yi
Speedup online convert (#1772)
Porting the remaining part of #1505:

- GLM-4.5-Air-FP8 (~100B): 600 s -> 60.50 s
- DS R1 (~600B): about 430 s with this fix

@czhu15 @yangulei Please help to review, thanks! cc @thuang6

Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: ranzhejiang <[email protected]>
Co-authored-by: Yi <[email protected]>
1 parent 0cd2bc6 commit 8b03c85

7 files changed (+86, -85 lines)
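Taken together, the change has two parts: the Gaudi FP8 rescaling helper `gaudi_weight_wrapper` moves into `vllm/model_executor/model_loader/weight_utils.py` so the quantization backends share a single copy, and a new `with_thread_limits` decorator caps the host OMP/torch thread counts only for the duration of `load_weights`. A minimal sketch of how the two pieces compose, using a hypothetical model class (not how vLLM wires them up internally; the helper names come from the diffs below):

import torch

from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, gaudi_weight_wrapper, with_thread_limits)


class DemoModel(torch.nn.Module):

    @with_thread_limits()  # caps apply only on HPU with VLLM_HPU_CONVERT_TO_FP8UZ
    def load_weights(self, weights):
        # Wrap the plain loader once; each fp8 tensor is rescaled in place
        # before the original loader copies it into the parameter.
        loader = gaudi_weight_wrapper(default_weight_loader)
        for name, loaded_weight in weights:
            loader(self.get_parameter(name), loaded_weight)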

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 1 addition & 1 deletion
@@ -33,8 +33,8 @@
     find_matched_target,
     is_activation_quantization_format,
     should_ignore_layer,
-    gaudi_weight_wrapper,
 )
+from vllm.model_executor.model_loader.weight_utils import gaudi_weight_wrapper
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,8 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import gaudi_weight_wrapper
+from vllm.model_executor.model_loader.weight_utils import gaudi_weight_wrapper
+
 logger = init_logger(__name__)

@@ -146,8 +147,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         # WEIGHTS
         if current_platform.is_hpu() and envs.VLLM_HPU_CONVERT_TO_FP8UZ:
             extra_weight_attrs["weight_loader"] = gaudi_weight_wrapper(
-                extra_weight_attrs.get("weight_loader")
-            )
+                extra_weight_attrs.get("weight_loader"))
         w13_weight = torch.nn.Parameter(torch.empty(
             num_experts,
             2 * intermediate_size_per_partition,
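Because `gaudi_weight_wrapper` only rewrites `args[1]` and forwards everything else through `*args, **kwargs`, it preserves the signature of whatever loader is stored in `extra_weight_attrs["weight_loader"]`, including MoE loaders that take extra keyword arguments. A small illustration with a made-up loader (not vLLM code; it assumes this commit is installed and a PyTorch build with `torch.float8_e4m3fn`):

import torch

from vllm.model_executor.model_loader.weight_utils import gaudi_weight_wrapper


def fake_moe_loader(param, loaded_weight, weight_name=None, shard_id=None,
                    expert_id=None):
    # Stand-in for a fused-MoE weight loader with extra keyword arguments.
    param.data.copy_(loaded_weight.to(param.dtype))


param = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
fp8_weight = torch.full((4,), 2.0).to(torch.float8_e4m3fn)

wrapped = gaudi_weight_wrapper(fake_moe_loader)
wrapped(param, fp8_weight, weight_name="w13_weight", shard_id="w1", expert_id=0)

# The fp8 tensor was divided by FP8_SCALE_FACTOR (2.0) in place before the
# wrapped loader copied it into the parameter.
assert torch.allclose(param.float(), torch.ones(4))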

vllm/model_executor/layers/quantization/compressed_tensors/utils.py

Lines changed: 0 additions & 21 deletions
@@ -4,7 +4,6 @@
 from collections.abc import Iterable, Mapping
 from types import MappingProxyType
 from typing import Optional
-import torch
 import regex as re
 from compressed_tensors import CompressionFormat
 from torch.nn import Module

@@ -213,23 +212,3 @@ def _match_fused_layer(
             unfused_matches.append(None)

     return unfused_matches[0] if all(unfused_matches) else None
-
-def gaudi_weight_wrapper(weight_loader):
-    """Wrapper for Gaudi weight conversion."""
-
-    FP8_SCALE_FACTOR = 2.0
-    def wrapper(*args, **kwargs):
-        # args[0] is parameter, args[1] is loaded_weight
-        # weights will be always in fp8, but scales will be in fp32,
-        # so we can detect it by dtype
-        loaded_weight = args[1]
-        if loaded_weight.dtype == torch.float8_e4m3fn:
-            loaded_weight.data = (
-                loaded_weight.data.float() / FP8_SCALE_FACTOR
-            ).to(torch.float8_e4m3fn)
-        else:
-            loaded_weight.data = (loaded_weight.data * FP8_SCALE_FACTOR)
-        args = (args[0], loaded_weight) + args[2:]
-        weight_loader(*args, **kwargs)
-
-    return wrapper
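The removed helper reappears, functionally unchanged, in weight_utils.py below. What it preserves is worth spelling out: fp8 weights are divided by `FP8_SCALE_FACTOR` while fp32 scales are multiplied by it, so the dequantized product `weight * scale` is unchanged; the factor of 2 lines up with the exponent-bias difference between E4M3FN and the FNUZ-style format that the `VLLM_HPU_CONVERT_TO_FP8UZ` name points at, though the commit does not state that explicitly. A quick numeric check of the invariant (an illustration, not code from this commit):

import torch

FP8_SCALE_FACTOR = 2.0

w = torch.tensor([1.5, -2.0]).to(torch.float8_e4m3fn)  # fp8 weight
s = torch.tensor(0.25)                                  # fp32 scale

w_conv = (w.float() / FP8_SCALE_FACTOR).to(torch.float8_e4m3fn)
s_conv = s * FP8_SCALE_FACTOR

# Dequantized values are identical before and after the conversion.
assert torch.allclose(w.float() * s, w_conv.float() * s_conv)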

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 3 additions & 40 deletions
@@ -39,7 +39,7 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-
+from vllm.model_executor.model_loader.weight_utils import gaudi_weight_wrapper
 if current_platform.is_hpu():
     import vllm_hpu_extension.ops as hpu_ops
     from vllm_hpu_extension.ops import scaled_fp8_quant

@@ -228,7 +228,7 @@ def create_weights(
         layer.orig_dtype = params_dtype
         layer.weight_block_size = None
         if current_platform.is_hpu() and envs.VLLM_HPU_CONVERT_TO_FP8UZ:
-            weight_loader = self._gaudi_weight_wrapper(weight_loader)
+            weight_loader = gaudi_weight_wrapper(weight_loader)

         if self.block_quant:
             tp_size = get_tensor_model_parallel_world_size()

@@ -312,25 +312,6 @@ def create_weights(
         else:
             layer.register_parameter("input_scale", None)

-    def _gaudi_weight_wrapper(self, weight_loader):
-        """Wrapper for Gaudi weight conversion."""
-
-        def wrapper(*args, **kwargs):
-            # args[0] is parameter, args[1] is loaded_weight
-            # weights will be always in fp8, but scales will be in fp32,
-            # so we can detect it by dtype
-            loaded_weight = args[1]
-            if loaded_weight.dtype == torch.float8_e4m3fn:
-                loaded_weight = (loaded_weight.float() * 0.5).to(
-                    torch.float8_e4m3fn)
-            else:
-                loaded_weight = (loaded_weight.data * 2.0)
-            args = (args[0], loaded_weight) + args[2:]
-
-            weight_loader(*args, **kwargs)
-
-        return wrapper
-
     def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
         # Pad the weight tensor. This is an optimization on ROCm platform, which
         # can benefit from tensors located far enough from one another in memory

@@ -541,7 +522,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
         layer.weight_block_size = None
         layer.weight_block_size = None
         if current_platform.is_hpu() and envs.VLLM_HPU_CONVERT_TO_FP8UZ:
-            extra_weight_attrs["weight_loader"] = self._gaudi_weight_wrapper(
+            extra_weight_attrs["weight_loader"] = gaudi_weight_wrapper(
                 extra_weight_attrs.get("weight_loader"))
         layer.quant_config = self.quant_config
         if self.quant_config.is_checkpoint_fp8_serialized:

@@ -662,24 +643,6 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
         layer.w13_input_scale = None
         layer.w2_input_scale = None

-    def _gaudi_weight_wrapper(self, weight_loader):
-        """Wrapper for Gaudi weight conversion."""
-
-        def wrapper(*args, **kwargs):
-            # args[0] is parameter, args[1] is loaded_weight
-            # weights will be always in fp8, but scales will be in fp32,
-            # so we can detect it by dtype
-            loaded_weight = args[1]
-            if loaded_weight.dtype == torch.float8_e4m3fn:
-                loaded_weight.data = (loaded_weight.data.float() * 0.5).to(
-                    torch.float8_e4m3fn)
-            else:
-                loaded_weight.data = (loaded_weight.data * 2.0)
-            args = (args[0], loaded_weight) + args[2:]
-            weight_loader(*args, **kwargs)
-
-        return wrapper
-
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 68 additions & 1 deletion
@@ -12,7 +12,7 @@
 from collections.abc import Generator
 from pathlib import Path
 from typing import Any, Callable, Optional, Union
-
+from functools import wraps
 import filelock
 import gguf
 import huggingface_hub.constants

@@ -29,6 +29,7 @@
     get_quantization_config)
 from vllm.platforms import current_platform
 from vllm.utils import PlaceholderModule
+import vllm.envs as envs

 try:
     from runai_model_streamer import SafetensorsStreamer

@@ -788,3 +789,69 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:

     # If there were no matches, return the untouched param name
     return name
+
+
+def gaudi_weight_wrapper(weight_loader):
+    """Wrapper for Gaudi weight conversion."""
+
+    FP8_SCALE_FACTOR = 2.0
+
+    def wrapper(*args, **kwargs):
+        # args[0] is parameter, args[1] is loaded_weight
+        # weights will be always in fp8, but scales will be in fp32,
+        # so we can detect it by dtype
+        loaded_weight = args[1]
+        if loaded_weight.dtype == torch.float8_e4m3fn:
+            loaded_weight.data = (loaded_weight.data.float() /
+                                  FP8_SCALE_FACTOR).to(torch.float8_e4m3fn)
+        else:
+            loaded_weight.data = (loaded_weight.data * FP8_SCALE_FACTOR)
+        args = (args[0], loaded_weight) + args[2:]
+        weight_loader(*args, **kwargs)
+
+    return wrapper
+
+
+def with_thread_limits(div_omp: int = 4, div_torch: int = 8):
+    """
+    Decorator to temporarily set OMP_NUM_THREADS and PyTorch threads,
+    and restore them after the function call.
+
+    Args:
+        div_omp: divide CPU cores by this for OMP_NUM_THREADS
+        div_torch: divide CPU cores by this for torch.set_num_threads
+    """
+
+    def decorator(func):
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if not (current_platform.is_hpu()
+                    and envs.VLLM_HPU_CONVERT_TO_FP8UZ):
+                return func(*args, **kwargs)
+
+            # Save original settings
+            old_omp = os.environ.get("OMP_NUM_THREADS", None)
+            old_torch = torch.get_num_threads()
+            num_cores = os.cpu_count() or 1
+
+            # Set new limits
+            os.environ["OMP_NUM_THREADS"] = str(max(1, num_cores // div_omp))
+            torch.set_num_threads(max(1, num_cores // div_torch))
+            logger.warning_once(
+                "Setting OMP_NUM_THREADS to %s and torch.set_num_threads to %s",
+                os.environ["OMP_NUM_THREADS"], torch.get_num_threads())
+            try:
+                # Call the actual function
+                return func(*args, **kwargs)
+            finally:
+                # Restore original settings
+                if old_omp is None:
+                    os.environ.pop("OMP_NUM_THREADS", None)
+                else:
+                    os.environ["OMP_NUM_THREADS"] = old_omp
+                torch.set_num_threads(old_torch)

+        return wrapper
+
+    return decorator
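Two details of `with_thread_limits` worth noting: it is a no-op unless running on HPU with `VLLM_HPU_CONVERT_TO_FP8UZ` set, and the previous `OMP_NUM_THREADS` value (or its absence) plus the torch thread count are restored in a `finally` block, so an exception during loading cannot leave the lowered limits behind. The decorator is not tied to `load_weights`; a hypothetical standalone use (the function name here is illustrative only):

from vllm.model_executor.model_loader.weight_utils import with_thread_limits


@with_thread_limits(div_omp=4, div_torch=8)
def convert_checkpoint_tensors(weights):
    # Heavy CPU-side conversion work; limiting the OMP/torch pools here keeps
    # the host from being oversubscribed while this function runs.
    return {name: tensor.float() for name, tensor in weights}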

vllm/model_executor/models/deepseek_v2.py

Lines changed: 2 additions & 11 deletions
@@ -47,7 +47,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, maybe_remap_kv_scale_name)
+    default_weight_loader, maybe_remap_kv_scale_name, with_thread_limits)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

@@ -781,6 +781,7 @@ def make_empty_intermediate_tensors(
                              device=device),
         })

+    @with_thread_limits()
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [

@@ -796,12 +797,6 @@ def load_weights(self, weights: Iterable[tuple[str,
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts)
-        if current_platform.is_hpu():
-            old_num_threads = torch.get_num_threads()
-            import os
-            num_cores = os.cpu_count()
-            os.environ["OMP_NUM_THREADS"] = str(max(1, num_cores // 4))
-            torch.set_num_threads(max(1, num_cores // 8))

         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()

@@ -873,10 +868,6 @@ def load_weights(self, weights: Iterable[tuple[str,
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
-        if current_platform.is_hpu():
-            # Restore the number of threads for HPU.
-            torch.set_num_threads(old_num_threads)
-            os.environ["OMP_NUM_THREADS"] = str(old_num_threads)
         return loaded_params

vllm/model_executor/models/glm4_moe.py

Lines changed: 9 additions & 8 deletions
@@ -32,7 +32,7 @@

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger

@@ -49,7 +49,10 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, maybe_remap_kv_scale_name)
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+    with_thread_limits,
+)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

@@ -130,11 +133,9 @@ def __init__(
             torch.empty(config.n_routed_experts, dtype=torch.float32))

         # Load balancing settings.
-        vllm_config = get_current_vllm_config()
-        parallel_config = vllm_config.parallel_config
         self.enable_eplb = enable_eplb

-        # Comment code below until we rebase to latset vllm
+        # Comment code below until we rebase to latest vllm
         # self.n_redundant_experts = parallel_config.num_redundant_experts
         self.n_redundant_experts = 0
         self.n_logical_experts = self.n_routed_experts

@@ -161,10 +162,9 @@ def __init__(
             prefix=f"{prefix}.experts",
             scoring_func="sigmoid",
             e_score_correction_bias=self.gate.e_score_correction_bias,
-            # Comment code below until we rebase to latset vllm
+            # Comment code below until we rebase to latest vllm
             # enable_eplb=self.enable_eplb,
             # num_redundant_experts=self.n_redundant_experts
-
         )

         if config.n_shared_experts is not None:

@@ -386,7 +386,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        # comment code below until we rebase to latset vllm
+        # comment code below until we rebase to latest vllm
         # enable_eplb = vllm_config.parallel_config.enable_eplb
         enable_eplb = False
         self.config = config

@@ -673,6 +673,7 @@ def compute_logits(
                                        sampling_metadata)
         return logits

+    @with_thread_limits()
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
