Skip to content

Commit 15d76f7

Browse files
Revert "[Misc] Enable weights loading tracking for quantized models" (vllm-project#35309)
1 parent 8fd6975 commit 15d76f7

File tree

1 file changed

+4
-15
lines changed

1 file changed

+4
-15
lines changed

vllm/model_executor/model_loader/default_loader.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from vllm.config import ModelConfig
1515
from vllm.config.load import LoadConfig
1616
from vllm.logger import init_logger
17-
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
1817
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
1918
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
2019
from vllm.model_executor.model_loader.weight_utils import (
@@ -287,6 +286,7 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
287286
):
288287
self.load_config.safetensors_load_strategy = "torchao"
289288

289+
weights_to_load = {name for name, _ in model.named_parameters()}
290290
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
291291

292292
self.counter_after_loading_weights = time.perf_counter()
@@ -295,20 +295,9 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
295295
self.counter_after_loading_weights - self.counter_before_loading_weights,
296296
scope="local",
297297
)
298-
self.track_weights_loading(model, loaded_weights)
299-
300-
def track_weights_loading(
301-
self, model: nn.Module, loaded_weights: set[str] | None
302-
) -> None:
303-
weights_to_load = {name for name, _ in model.named_parameters()}
304-
if loaded_weights is not None:
305-
for name, module in model.named_modules():
306-
quant_method = getattr(module, "quant_method", None)
307-
# ignore kv_cache scale, which can be missing in checkpoints
308-
if isinstance(quant_method, BaseKVCacheMethod):
309-
for param_name, _ in module.named_parameters():
310-
full_name = f"{name}.{param_name}" if name else param_name
311-
loaded_weights.add(full_name)
298+
# We only enable strict check for non-quantized models
299+
# that have loaded weights tracking currently.
300+
if model_config.quantization is None and loaded_weights is not None:
312301
weights_not_loaded = weights_to_load - loaded_weights
313302
if weights_not_loaded:
314303
raise ValueError(

0 commit comments

Comments
 (0)