diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index edf34b483e9a..c010fbb8f829 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -205,6 +205,7 @@ def from_local_checkpoint(
         embedding_padding_modules: Optional[list[str]] = None,
         weights_mapper: Optional[WeightsMapper] = None,
         tensorizer_config_dict: Optional[dict] = None,
+        ignore_modules: list[str] = [],
     ) -> "LoRAModel":
         """Create a LoRAModel from a local checkpoint.
 
@@ -231,6 +232,16 @@ def from_local_checkpoint(
         tensors: dict[str, torch.Tensor] = {}
         unexpected_modules: list[Union[list[str], str]] = []
 
+        def filter_ignored_lora_module_keys(modules: dict) -> list[str]:
+            supported_keys = []
+            for lora_module in modules.keys():  # noqa
+                module_name, _, _ = parse_fine_tuned_lora_name(
+                    lora_module, weights_mapper)
+                if not any(module_name.startswith(prefix)
+                           for prefix in ignore_modules):
+                    supported_keys.append(lora_module)
+            return supported_keys
+
         def check_unexpected_modules(modules: dict):
             for lora_module in modules.keys():  # noqa
                 module_name, _, _ = parse_fine_tuned_lora_name(
@@ -260,6 +271,10 @@ def check_unexpected_modules(modules: dict):
                 dtype=tensorizer_config.dtype,
                 **tensorizer_args.deserialization_kwargs,
             )
+
+            # Remove explicitly ignored modules.
+            keep = set(filter_ignored_lora_module_keys(tensors))
+            tensors = {k: v for k, v in tensors.items() if k in keep}
             # Check that tensors have only expected LoRA modules.
             check_unexpected_modules(tensors)
         elif os.path.isfile(lora_tensor_path):
@@ -271,10 +286,14 @@ def check_unexpected_modules(modules: dict):
             # the target_modules of the adapter_config.json.
             unexpected_modules = []
             with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
-                # Load tensors if there are only expected modules.
-                check_unexpected_modules(f)
                 for module in f.keys():  # noqa
                     tensors[module] = f.get_tensor(module)
+
+            # Remove explicitly ignored modules.
+            keep = set(filter_ignored_lora_module_keys(tensors))
+            tensors = {k: v for k, v in tensors.items() if k in keep}
+            # Check that tensors have only expected LoRA modules.
+            check_unexpected_modules(tensors)
         elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
             # When a bin/pt file is provided, we rely on config to find
             # unexpected modules.
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 3ca819fb732c..c6065476d800 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -17,6 +17,7 @@
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 
 logger = init_logger(__name__)
 
@@ -113,6 +114,14 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
 
             model = self._adapter_manager.model
             hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
+
+            # Ignore these modules even if they are found in the LoRA
+            # checkpoint (e.g. vision/audio towers in MM models).
+            ignore_modules = []
+            if self._adapter_manager.supports_mm:
+                mm: MultiModelKeys = model.get_mm_mapping()
+                ignore_modules = mm.connector + mm.tower_model
+
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_path,
                 expected_lora_modules,
@@ -126,6 +135,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                 embedding_padding_modules=self.embedding_padding_modules,
                 tensorizer_config_dict=lora_request.tensorizer_config_dict,
                 weights_mapper=hf_to_vllm_mapper,
+                ignore_modules=ignore_modules,
             )
 
         except FileNotFoundError as e:
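
Note: the sketch below is a standalone illustration, not vLLM code, of the
prefix-based filtering that filter_ignored_lora_module_keys performs on the
tensor names in a LoRA checkpoint. The module names, the ignore_modules list,
and the parse_lora_module_name helper are invented for the example.

    # Hypothetical example: drop LoRA weights whose module name starts with an
    # ignored prefix (e.g. a vision tower), keep everything else.
    tensors = {
        "language_model.layers.0.self_attn.q_proj.lora_A.weight": None,
        "vision_tower.blocks.0.attn.qkv.lora_A.weight": None,
    }
    ignore_modules = ["vision_tower", "multi_modal_projector"]

    def parse_lora_module_name(tensor_name: str) -> str:
        # Strip the ".lora_A.weight"/".lora_B.weight" suffix to recover the
        # module name (stand-in for parse_fine_tuned_lora_name).
        return tensor_name.rsplit(".lora_", 1)[0]

    def filter_ignored(tensors: dict, ignore_modules: list[str]) -> dict:
        # Keep a tensor only if its module name matches none of the prefixes.
        return {
            name: weight
            for name, weight in tensors.items()
            if not any(
                parse_lora_module_name(name).startswith(prefix)
                for prefix in ignore_modules)
        }

    tensors = filter_ignored(tensors, ignore_modules)
    assert list(tensors) == [
        "language_model.layers.0.self_attn.q_proj.lora_A.weight"
    ]

In the PR itself, ignore_modules is derived from MultiModelKeys
(mm.connector + mm.tower_model), so LoRA weights for vision/audio towers that
happen to ship in a multimodal adapter are dropped before the
unexpected-modules check instead of tripping it.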