
Commit e9f568c

yiliu30 authored and czhu15 committed
Porting online convert for llm-compressor (vllm-project#1763)
## Usage

```
export VLLM_HPU_CONVERT_TO_FP8UZ=1
export VLLM_HPU_FORCE_CHANNEL_FP8=1
```

Original PR: HabanaAI#1505

@czhu15 @Wei-Lin-Intel @yangulei Please help review, thx!

---------

Signed-off-by: yiliu30 <[email protected]>
1 parent 2a3111b commit e9f568c
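
For reference, a minimal sketch of enabling the conversion from Python rather than the shell. This is not part of the commit: `LLM` is the standard vLLM entrypoint, and the checkpoint path is a placeholder for an llm-compressor FP8 model.

```python
import os

# Set before engine construction so vllm.envs picks the flags up
# (variable names come from this commit; the model path is a placeholder).
os.environ["VLLM_HPU_CONVERT_TO_FP8UZ"] = "1"
os.environ["VLLM_HPU_FORCE_CHANNEL_FP8"] = "1"

from vllm import LLM

llm = LLM(model="/path/to/llm-compressor-fp8-checkpoint")
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```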

File tree

4 files changed: +35 -9 lines changed

- vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
- vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
- vllm/model_executor/layers/quantization/compressed_tensors/utils.py
- vllm/model_executor/models/glm4_moe.py

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -5,14 +5,15 @@
 from typing import Any, Literal, Optional, cast
 
 import torch
+import vllm.envs as envs
 from compressed_tensors.config import (CompressionFormat,
                                        SparsityCompressionConfig,
                                        SparsityStructure)
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)
-from pydantic import BaseModel
 
+from pydantic import BaseModel
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@@ -29,8 +30,11 @@
     CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
     CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    find_matched_target, is_activation_quantization_format,
-    should_ignore_layer)
+    find_matched_target,
+    is_activation_quantization_format,
+    should_ignore_layer,
+    gaudi_weight_wrapper,
+)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
 
@@ -581,6 +585,8 @@ def create_weights(self, layer: torch.nn.Module,
         details
         """
         weight_loader = extra_weight_attrs.get("weight_loader")
+        if current_platform.is_hpu() and envs.VLLM_HPU_CONVERT_TO_FP8UZ:
+            weight_loader = gaudi_weight_wrapper(weight_loader)
         layer.scheme.create_weights(
             layer=layer,
             input_size=input_size,
```
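
To illustrate what the wrapping above does on the dense path, here is a toy sketch; it assumes the patched vllm from this commit is importable, and `fake_loader` is a stand-in for the real weight_loader taken from `extra_weight_attrs`, not vLLM's actual loader.

```python
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    gaudi_weight_wrapper)

def fake_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor):
    # Stand-in for the loader normally found in extra_weight_attrs.
    param.data.copy_(loaded_weight)

param = torch.nn.Parameter(torch.empty(4, dtype=torch.float8_e4m3fn),
                           requires_grad=False)
checkpoint_weight = torch.tensor([448.0, 2.0, -240.0, 1.0]).to(
    torch.float8_e4m3fn)

# With VLLM_HPU_CONVERT_TO_FP8UZ=1 on HPU, create_weights() swaps in this
# wrapped loader, so fp8 weights arrive halved: [224., 1., -120., 0.5].
loader = gaudi_weight_wrapper(fake_loader)
loader(param, checkpoint_weight)
print(param.float())
```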

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -29,7 +29,7 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
-
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import gaudi_weight_wrapper
 logger = init_logger(__name__)
 
 
@@ -144,6 +144,10 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
+        if current_platform.is_hpu() and envs.VLLM_HPU_CONVERT_TO_FP8UZ:
+            extra_weight_attrs["weight_loader"] = gaudi_weight_wrapper(
+                extra_weight_attrs.get("weight_loader")
+            )
         w13_weight = torch.nn.Parameter(torch.empty(
             num_experts,
             2 * intermediate_size_per_partition,
```
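
The same wrapper can be reused here because it is signature-agnostic: it only rescales `args[1]` and forwards everything else, so the extra positional arguments of a fused-MoE-style loader pass through untouched. A toy sketch (the loader below is a stand-in with a FusedMoE-like signature, not the real FusedMoE weight loader):

```python
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    gaudi_weight_wrapper)

def moe_loader(param, loaded_weight, weight_name, shard_id, expert_id):
    # Stand-in loader; only prints what it received.
    print(weight_name, shard_id, expert_id, loaded_weight.float().tolist())

loaded = torch.tensor([448.0, 2.0]).to(torch.float8_e4m3fn)
wrapped = gaudi_weight_wrapper(moe_loader)

# param is unused by the stand-in; the fp8 weight arrives halved and the
# trailing MoE arguments are forwarded unchanged.
wrapped(None, loaded, "w13_weight", "w1", 0)  # w13_weight w1 0 [224.0, 1.0]
```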

vllm/model_executor/layers/quantization/compressed_tensors/utils.py

Lines changed: 21 additions & 1 deletion
```diff
@@ -4,7 +4,7 @@
 from collections.abc import Iterable, Mapping
 from types import MappingProxyType
 from typing import Optional
-
+import torch
 import regex as re
 from compressed_tensors import CompressionFormat
 from torch.nn import Module
@@ -213,3 +213,23 @@ def _match_fused_layer(
         unfused_matches.append(None)
 
     return unfused_matches[0] if all(unfused_matches) else None
+
+def gaudi_weight_wrapper(weight_loader):
+    """Wrapper for Gaudi weight conversion."""
+
+    FP8_SCALE_FACTOR = 2.0
+    def wrapper(*args, **kwargs):
+        # args[0] is parameter, args[1] is loaded_weight
+        # weights will be always in fp8, but scales will be in fp32,
+        # so we can detect it by dtype
+        loaded_weight = args[1]
+        if loaded_weight.dtype == torch.float8_e4m3fn:
+            loaded_weight.data = (
+                loaded_weight.data.float() / FP8_SCALE_FACTOR
+            ).to(torch.float8_e4m3fn)
+        else:
+            loaded_weight.data = (loaded_weight.data * FP8_SCALE_FACTOR)
+        args = (args[0], loaded_weight) + args[2:]
+        weight_loader(*args, **kwargs)
+
+    return wrapper
```
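
Why halve the fp8 weights and double the fp32 scales: the dequantized product `weight * scale` is preserved while the stored fp8 values move into a narrower range, which is presumably what makes an OCP E4M3 checkpoint usable with Gaudi's FP8-UZ handling, as the env var name suggests. A small numeric sketch of that invariant (not part of the commit):

```python
import torch

FP8_SCALE_FACTOR = 2.0  # same constant as gaudi_weight_wrapper

weight = torch.tensor([448.0, -240.0, 1.5]).to(torch.float8_e4m3fn)
scale = torch.tensor(0.01, dtype=torch.float32)

# What the wrapper does: halve fp8 weights, double fp32 scales.
converted_weight = (weight.float() / FP8_SCALE_FACTOR).to(torch.float8_e4m3fn)
converted_scale = scale * FP8_SCALE_FACTOR

# Dequantized values are unchanged (up to fp8 rounding of the halved weights).
print(torch.allclose(weight.float() * scale,
                     converted_weight.float() * converted_scale))  # True
```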

vllm/model_executor/models/glm4_moe.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -390,7 +390,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # enable_eplb = vllm_config.parallel_config.enable_eplb
         enable_eplb = False
         self.config = config
-
         self.vocab_size = config.vocab_size
 
         if get_pp_group().is_first_rank:
@@ -511,7 +510,6 @@ def load_weights(self, weights: Iterable[tuple[str,
                     continue
                 if is_pp_missing_parameter(name, self):
                     continue
-
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -533,7 +531,6 @@ def load_weights(self, weights: Iterable[tuple[str,
 
                     if is_pp_missing_parameter(name_mapped, self):
                         continue
-
                     param = params_dict[name_mapped]
                     # We should ask the weight loader to return success or not
                     # here since otherwise we may skip experts with other
@@ -565,7 +562,6 @@ def load_weights(self, weights: Iterable[tuple[str,
 
                 if is_pp_missing_parameter(name, self):
                     continue
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
```
