     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import set_default_rope_theta
 
@@ -278,6 +281,7 @@ def __init__(
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             cache_config=cache_config,
+            quant_config=quant_config,
             prefix=f"{prefix}.attn",
         )
         self.key_multiplier = config.key_multiplier
@@ -360,7 +364,9 @@ def __init__(
         self.attention_in_multiplier = config.attention_in_multiplier
         self.attn_out_multiplier = config.attention_out_multiplier
 
-        self.feed_forward = FalconH1MLP(config, prefix=f"{prefix}.feed_forward")
+        self.feed_forward = FalconH1MLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
 
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
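For context, the two hunks above only thread the existing `quant_config` through to the submodules that can build quantized weights; with `quant_config=None` everything falls back to unquantized layers. A minimal sketch of that plumbing pattern, using invented class names (`ToyAttention`, `ToyMLP`, `ToyDecoderLayer`) rather than vLLM APIs:

```python
# Illustrative sketch only, not vLLM code: a decoder layer forwards one
# optional quantization config to every child that may create quantized
# weights, so the same constructor works with and without quantization.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyQuantConfig:
    method: str = "fp8"


class ToyAttention:
    def __init__(self, num_heads: int, quant_config: Optional[ToyQuantConfig] = None):
        # With a config, kv-cache scale parameters (k_scale / v_scale)
        # would additionally be registered for the weight loader to fill.
        self.quantized = quant_config is not None


class ToyMLP:
    def __init__(self, hidden_size: int, quant_config: Optional[ToyQuantConfig] = None):
        # A real implementation would choose quantized vs. plain linear layers here.
        self.quantized = quant_config is not None


class ToyDecoderLayer:
    def __init__(self, quant_config: Optional[ToyQuantConfig] = None):
        # Forward the same config to every child, as the diff does for
        # FalconH1's Attention and FalconH1MLP.
        self.attn = ToyAttention(num_heads=8, quant_config=quant_config)
        self.mlp = ToyMLP(hidden_size=512, quant_config=quant_config)


layer = ToyDecoderLayer(quant_config=ToyQuantConfig())
print(layer.attn.quantized, layer.mlp.quantized)  # True True
```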
@@ -647,6 +653,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             if "mamba" in name:
                 name = name.replace("mamba", "mamba.mamba")
 
+            if "scale" in name:
+                # Remapping the name of kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
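The new branch in `load_weights` follows the kv-scale handling pattern used by other vLLM models: a checkpoint scale name is remapped onto the parameter name registered by the attention module, and the weight is skipped when no such parameter exists (for example, when the kv-cache is not quantized). A simplified, self-contained sketch of that behavior; the helper below is a stand-in written for illustration, not the real `maybe_remap_kv_scale_name`:

```python
# Toy stand-in for the kv-scale remapping step in load_weights.
from typing import Optional


def remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    """Map checkpoint k/v-scale names onto the model's registered parameters."""
    for suffix in (".k_scale", ".v_scale"):
        if name.endswith(suffix):
            candidate = name.replace(suffix, f".attn{suffix}")
            # Only remap when the model actually registered the parameter;
            # returning None tells the caller to skip this weight.
            return candidate if candidate in params_dict else None
    return name  # not a scale tensor, leave the name unchanged


# Toy parameter dict, as produced by dict(model.named_parameters()).
params_dict = {"model.layers.0.self_attn.attn.k_scale": object()}

print(remap_kv_scale_name("model.layers.0.self_attn.k_scale", params_dict))
# -> model.layers.0.self_attn.attn.k_scale
print(remap_kv_scale_name("model.layers.0.self_attn.v_scale", params_dict))
# -> None (load_weights would `continue` past this weight)
```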