Commit f1cb9b5

Fix quantized Falcon-H1 model loading issues (vllm-project#32728)
Signed-off-by: Shengliang Xu <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 4c4b6f7 commit f1cb9b5
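
For context, a minimal sketch of the path this commit fixes: loading a quantized Falcon-H1 checkpoint through vLLM's offline `LLM` API. The model name and quantization method below are placeholders, not taken from the commit.

```python
# Minimal sketch (assumed model name and quantization method, not from this commit).
# Before this fix, quantized Falcon-H1 checkpoints could fail to load because
# quant_config was not passed to the attention/MLP layers and kv-cache scale
# names in the checkpoint were not remapped.
from vllm import LLM, SamplingParams

llm = LLM(
    model="tiiuae/Falcon-H1-7B-Instruct",  # hypothetical quantized variant/path
    quantization="fp8",  # assumed method; use whatever the checkpoint was quantized with
)
outputs = llm.generate(["Hello, Falcon-H1!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```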


vllm/model_executor/models/falcon_h1.py

Lines changed: 14 additions & 2 deletions
@@ -35,7 +35,10 @@
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import set_default_rope_theta

@@ -278,6 +281,7 @@ def __init__(
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             cache_config=cache_config,
+            quant_config=quant_config,
             prefix=f"{prefix}.attn",
         )
         self.key_multiplier = config.key_multiplier
@@ -360,7 +364,9 @@ def __init__(
         self.attention_in_multiplier = config.attention_in_multiplier
         self.attn_out_multiplier = config.attention_out_multiplier

-        self.feed_forward = FalconH1MLP(config, prefix=f"{prefix}.feed_forward")
+        self.feed_forward = FalconH1MLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )

         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
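
The two hunks above share one intent: the decoder layer's `quant_config` now reaches both the attention layer and the MLP, so quantized methods and kv-cache scales can actually be applied instead of silently falling back to unquantized layers that no longer match the checkpoint. A minimal, generic sketch of that pattern follows; the `Toy*` class names are illustrative, not vLLM's actual Falcon-H1 code.

```python
# Illustrative sketch of threading quant_config from a top-level model config
# down into submodules; the Toy* names below are hypothetical, not vLLM classes.
from typing import Optional


class ToyAttention:
    def __init__(self, hidden_size: int, quant_config: Optional[object] = None):
        # A real implementation would select a quantized attention/linear method
        # based on quant_config; here we only record it.
        self.quant_config = quant_config


class ToyMLP:
    def __init__(self, hidden_size: int, quant_config: Optional[object] = None):
        self.quant_config = quant_config


class ToyDecoderLayer:
    def __init__(self, hidden_size: int, quant_config: Optional[object] = None):
        # Forgetting to forward quant_config (the kind of bug this commit fixes
        # for Falcon-H1) leaves submodules unquantized, so their parameters no
        # longer line up with a quantized checkpoint.
        self.attn = ToyAttention(hidden_size, quant_config=quant_config)
        self.mlp = ToyMLP(hidden_size, quant_config=quant_config)
```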
@@ -647,6 +653,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             if "mamba" in name:
                 name = name.replace("mamba", "mamba.mamba")

+            if "scale" in name:
+                # Remapping the name of kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
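
The `load_weights` change follows the remap-and-skip pattern used elsewhere in vLLM: checkpoint names that carry kv-cache scales are passed through `maybe_remap_kv_scale_name`, which either returns the registered parameter name or `None` for scales with no target, in which case the entry is skipped. Below is a standalone sketch of that control flow with a simplified stand-in helper; it is not the real `maybe_remap_kv_scale_name` from `vllm.model_executor.model_loader.weight_utils`, and the example suffix mapping is an assumption.

```python
# Standalone sketch of the remap-and-skip pattern; remap_kv_scale_name below is
# a simplified stand-in, not vLLM's maybe_remap_kv_scale_name.
from typing import Optional


def remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    """Return the registered parameter name for a checkpoint kv-scale, or None."""
    # Assumed example mapping: "...self_attn.k_scale" -> "...self_attn.attn.k_scale".
    for suffix in (".k_scale", ".v_scale"):
        if name.endswith(suffix):
            candidate = name[: -len(suffix)] + ".attn" + suffix
            return candidate if candidate in params_dict else None
    return name


def load_weights(weights, params_dict) -> set:
    loaded = set()
    for name, tensor in weights:
        if "scale" in name:
            name = remap_kv_scale_name(name, params_dict)
            if name is None:
                # Scale has no matching parameter; skip it rather than erroring out.
                continue
        param = params_dict.get(name)
        if param is None:
            continue
        # A real loader would call the parameter's weight_loader(param, tensor) here.
        loaded.add(name)
    return loaded
```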
