23 changes: 21 additions & 2 deletions vllm/lora/models.py
@@ -205,6 +205,7 @@ def from_local_checkpoint(
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
        tensorizer_config_dict: Optional[dict] = None,
        ignore_modules: list[str] = [],
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

@@ -231,6 +232,16 @@ def from_local_checkpoint(
        tensors: dict[str, torch.Tensor] = {}
        unexpected_modules: list[Union[list[str], str]] = []

        def filter_ignored_lora_module_keys(modules: dict) -> list[str]:
            # Keep only the keys whose parsed module name does not start
            # with one of the ignore_modules prefixes.
            supported_keys = []
            for lora_module in modules.keys():  # noqa
                module_name, _, _ = parse_fine_tuned_lora_name(
                    lora_module, weights_mapper)
                if not any(
                        module_name.startswith(prefix)
                        for prefix in ignore_modules):
                    supported_keys.append(lora_module)
            return supported_keys

        def check_unexpected_modules(modules: dict):
            for lora_module in modules.keys():  # noqa
                module_name, _, _ = parse_fine_tuned_lora_name(
@@ -260,6 +271,10 @@ def check_unexpected_modules(modules: dict):
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserialization_kwargs,
            )

            # Remove explicitly ignored modules.
            supported_keys = set(filter_ignored_lora_module_keys(tensors))
            tensors = {
                k: v
                for k, v in tensors.items() if k in supported_keys
            }
            # Check that tensors have only expected LoRA modules.
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_tensor_path):
@@ -271,10 +286,14 @@ def check_unexpected_modules(modules: dict):
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(
                    lora_tensor_path, framework="pt") as f:  # type: ignore
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)

            # Remove explicitly ignored modules.
            supported_keys = set(filter_ignored_lora_module_keys(tensors))
            tensors = {
                k: v
                for k, v in tensors.items() if k in supported_keys
            }
            # Check that tensors have only expected LoRA modules.
            check_unexpected_modules(tensors)

        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(
                lora_pt_file_path):
            # When a bin/pt file is provided, we rely on config to find
            # unexpected modules.
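For reference, here is a minimal, self-contained sketch of the prefix filtering this hunk adds. `strip_lora_name` and the checkpoint keys below are simplified stand-ins, not vLLM's actual `parse_fine_tuned_lora_name` or real adapter keys; only the filtering behavior is illustrated.

```python
# Toy stand-in for parse_fine_tuned_lora_name(): drop the PEFT wrapper
# prefix and the LoRA weight suffix to recover the module name.
def strip_lora_name(key: str) -> str:
    name = key.removeprefix("base_model.model.")
    for suffix in (".lora_A.weight", ".lora_B.weight"):
        name = name.removesuffix(suffix)
    return name


def filter_ignored(keys: list[str], ignore_modules: list[str]) -> list[str]:
    # Same idea as filter_ignored_lora_module_keys: a key survives only if
    # its module name does not start with any ignored prefix.
    return [
        k for k in keys
        if not any(
            strip_lora_name(k).startswith(p) for p in ignore_modules)
    ]


checkpoint_keys = [
    "base_model.model.language_model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.vision_tower.blocks.0.attn.qkv.lora_A.weight",
    "base_model.model.multi_modal_projector.linear_1.lora_A.weight",
]
kept = filter_ignored(
    checkpoint_keys,
    ignore_modules=["vision_tower", "multi_modal_projector"])
print(kept)  # only the language-model key survives
```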
10 changes: 10 additions & 0 deletions vllm/lora/worker_manager.py
@@ -17,6 +17,7 @@
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.module_mapping import MultiModelKeys

logger = init_logger(__name__)

@@ -113,6 +114,14 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
            model = self._adapter_manager.model
            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)

            # Ignore these modules even if they are found in the LoRA
            # checkpoint (e.g. vision/audio towers in multimodal models).
            ignore_modules = []
            if self._adapter_manager.supports_mm:
                mm: MultiModelKeys = model.get_mm_mapping()
                ignore_modules = mm.connector + mm.tower_model

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
@@ -126,6 +135,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
                embedding_padding_modules=self.embedding_padding_modules,
                tensorizer_config_dict=lora_request.tensorizer_config_dict,
                weights_mapper=hf_to_vllm_mapper,
                ignore_modules=ignore_modules,
            )

        except FileNotFoundError as e:
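To illustrate the worker-side change, here is a toy sketch of how the multimodal mapping becomes the ignore list. `FakeMultiModelKeys` mimics vLLM's `MultiModelKeys` (which does expose `language_model`, `connector`, and `tower_model` fields), but the prefix values are hypothetical LLaVA-style names, not taken from any real model.

```python
from dataclasses import dataclass, field


@dataclass
class FakeMultiModelKeys:
    # Stand-in for vllm.model_executor.models.module_mapping.MultiModelKeys.
    language_model: list[str] = field(default_factory=list)
    connector: list[str] = field(default_factory=list)
    tower_model: list[str] = field(default_factory=list)


mm = FakeMultiModelKeys(
    language_model=["language_model"],
    connector=["multi_modal_projector"],
    tower_model=["vision_tower"],
)

# Same composition as in _load_adapter above: every module that is not
# part of the language model is ignored when loading LoRA weights.
ignore_modules = mm.connector + mm.tower_model
print(ignore_modules)  # ['multi_modal_projector', 'vision_tower']
```

The design choice mirrors the intent of the PR: multimodal checkpoints fine-tuned with PEFT may carry LoRA weights for the vision/audio tower and projector, but vLLM only applies LoRA to the language model, so those extra keys are dropped instead of raising an unexpected-module error.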