|
@@ -12,11 +12,9 @@
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from vllm import envs
-from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
+from vllm.config import LoadConfig, LoadFormat, ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
-from vllm.model_executor.model_loader.utils import (
-    initialize_model, process_weights_after_loading, set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
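The dropped `vllm.model_executor.model_loader.utils` imports (`initialize_model`, `process_weights_after_loading`, `set_default_torch_dtype`) and the dropped `VllmConfig` are no longer needed in this file because the model-construction logic that used them leaves this loader, as the next hunk shows.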
@@ -264,32 +262,20 @@ def download_model(self, model_config: ModelConfig) -> None: |
                               fall_back_to_pt=True,
                               allow_patterns_overrides=None)

-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
-        device_config = vllm_config.device_config
-        target_device = torch.device(device_config.device)
-        with set_default_torch_dtype(model_config.dtype):
-            with target_device:
-                model = initialize_model(vllm_config=vllm_config,
-                                         model_config=model_config)
-
-        weights_to_load = {name for name, _ in model.named_parameters()}
-        loaded_weights = model.load_weights(
-            self.get_all_weights(model_config, model))
-        self.counter_after_loading_weights = time.perf_counter()
-        logger.info(
-            "Loading weights took %.2f seconds",
-            self.counter_after_loading_weights -
-            self.counter_before_loading_weights)
-        # We only enable strict check for non-quantized models
-        # that have loaded weights tracking currently.
-        if model_config.quantization is None and loaded_weights is not None:
-            weights_not_loaded = weights_to_load - loaded_weights
-            if weights_not_loaded:
-                raise ValueError(
-                    "Following weights were not initialized from "
-                    f"checkpoint: {weights_not_loaded}")
-
-        process_weights_after_loading(model, model_config, target_device)
-
-        return model.eval()
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        loaded_weights = model.load_weights(
+            self.get_all_weights(model_config, model))
+        self.counter_after_loading_weights = time.perf_counter()
+        logger.info(
+            "Loading weights took %.2f seconds",
+            self.counter_after_loading_weights -
+            self.counter_before_loading_weights)
+        # We only enable strict check for non-quantized models
+        # that have loaded weights tracking currently.
+        if model_config.quantization is None and loaded_weights is not None:
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError("Following weights were not initialized from "
+                                 f"checkpoint: {weights_not_loaded}")
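This hunk narrows the loader's responsibility from a full `load_model` to a `load_weights` hook: model construction, dtype/device handling, and post-processing no longer happen here. Below is a minimal sketch of the shared driver this refactor implies, assuming it now lives once on `BaseModelLoader` (an assumption inferred from the removed imports; this is not code shown in the PR):

```python
# Sketch only: a plausible shape for the shared load_model() driver on
# BaseModelLoader after per-loader load_weights() hooks are extracted.
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.model_loader.utils import (
    initialize_model, process_weights_after_loading, set_default_torch_dtype)


class BaseModelLoader(ABC):

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        ...

    @abstractmethod
    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Loader-specific weight loading, e.g. the hook added above."""
        ...

    def load_model(self, vllm_config: VllmConfig,
                   model_config: ModelConfig) -> nn.Module:
        """Shared driver, mirroring the body removed from this loader."""
        device_config = vllm_config.device_config
        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config,
                                         model_config=model_config)
        # Delegate to the loader-specific hook introduced in this diff.
        self.load_weights(model, model_config)
        process_weights_after_loading(model, model_config, target_device)
        return model.eval()
```

With this template-method shape, each concrete loader implements only `download_model` and `load_weights`, so the construction/dtype/device/post-processing sequence is written once rather than duplicated across loaders.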