Commit 78a47f8

Test Prompt Embeds/LoRA compatibility and Enable LoRA Support for OPT Models (vllm-project#25717)
Signed-off-by: Andrew Sansom <[email protected]>
1 parent 6a113d9 commit 78a47f8

5 files changed: +40 -11 lines

docs/features/README.md

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ th:not(:first-child) {
 | [mm](multimodal_inputs.md) ||| [🟠](gh-pr:4194)<sup>^</sup> |||||||||| | | |
 | best-of |||| [](gh-issue:6137) ||||||| [](gh-issue:7968) ||| | |
 | beam-search |||| [](gh-issue:6137) ||||||| [](gh-issue:7968) |||| |
-| [prompt-embeds](prompt_embeds.md) || [](gh-issue:25096) | ? ||||||| ? | ? || ? | ? ||
+| [prompt-embeds](prompt_embeds.md) || [](gh-issue:25096) | ||||||| | || | ||
 
 \* Chunked prefill and prefix caching are only applicable to last-token pooling.
 <sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
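This change moves the LoRA column of the prompt-embeds row from "untested" to supported. A minimal client-side sketch of combining the two features follows; it assumes a local vLLM OpenAI-compatible server started with --enable-prompt-embeds and --enable-lora, a LoRA adapter registered under the name "opt125m-lora" (as the test fixture in this commit does), and that prompt embeddings are transported as base64-encoded torch-saved tensors in extra_body, which is what the test file's use of base64, io, and torch suggests. The tensor values and server URL are placeholders.

import base64
import io

import openai
import torch

# Assumption: a vLLM server is listening locally with prompt embeds and LoRA enabled.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Placeholder embeddings with opt-125m's hidden size (768); real requests would
# use embeddings computed from the base model's embedding layer.
prompt_embeds = torch.randn(8, 768, dtype=torch.bfloat16)

# Serialize with torch.save and base64-encode before sending over HTTP.
buffer = io.BytesIO()
torch.save(prompt_embeds, buffer)
encoded_embeds = base64.b64encode(buffer.getvalue()).decode("utf-8")

completion = client.completions.create(
    model="opt125m-lora",  # or "facebook/opt-125m" to skip the adapter
    prompt="",  # ignored when prompt_embeds is provided
    max_tokens=5,
    extra_body={"prompt_embeds": encoded_embeds},
)
print(completion.choices[0].text)

Random embeddings will not produce meaningful text; the tests derive them from the base model's embedding layer, as sketched after the test diff below.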

docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -403,7 +403,7 @@ th {
 | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ |
 | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
-| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
+| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
 | `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
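With `OPTForCausalLM` now marked as supporting LoRA, offline inference with an OPT adapter could look roughly like the sketch below. The adapter repository is the dummy one the new tests download; max_lora_rank mirrors the test server's --max-lora-rank setting, and the adapter name, prompt, and sampling parameters are illustrative only.

from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Dummy adapter used by the tests in this commit; any LoRA adapter trained
# on facebook/opt-125m should work the same way.
lora_path = snapshot_download(repo_id="peft-internal-testing/opt-125m-dummy-lora")

llm = LLM(model="facebook/opt-125m", enable_lora=True, max_lora_rank=64)

outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(max_tokens=16),
    lora_request=LoRARequest("opt125m-lora", 1, lora_path),
)
print(outputs[0].outputs[0].text)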

tests/entrypoints/conftest.py

Lines changed: 8 additions & 0 deletions
@@ -208,3 +208,11 @@ def zephyr_lora_files():
     """Download zephyr LoRA files once per test session."""
     from huggingface_hub import snapshot_download
     return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
+
+
+@pytest.fixture(scope="session")
+def opt125_lora_files() -> str:
+    """Download opt-125m LoRA files once per test session."""
+    from huggingface_hub import snapshot_download
+    return snapshot_download(
+        repo_id="peft-internal-testing/opt-125m-dummy-lora")

tests/entrypoints/openai/test_completion_with_prompt_embeds.py

Lines changed: 28 additions & 6 deletions
@@ -3,6 +3,7 @@
 
 import base64
 import io
+import json
 
 import openai  # use the official client for correctness check
 import pytest
@@ -16,13 +17,15 @@
 
 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"
+LORA_SERVING_MODEL_NAME = "opt125m-lora"
 
 CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
 
 
-@pytest.fixture(scope="module")
-def default_server_args() -> list[str]:
-    return [
+@pytest.fixture(scope="module", params=["use-lora"])
+def default_server_args(request: pytest.FixtureRequest,
+                        opt125_lora_files: str) -> list[str]:
+    args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
@@ -35,6 +38,25 @@ def default_server_args() -> list[str]:
         "--enable-prompt-embeds",
     ]
 
+    if request.param == "use-lora":
+        lora_module_1 = {
+            "name": LORA_SERVING_MODEL_NAME,
+            "path": opt125_lora_files,
+            "base_model_name": MODEL_NAME
+        }
+
+        args.extend([
+            "--enable-lora",
+            "--lora-module",
+            json.dumps(lora_module_1),
+            "--max-lora-rank",
+            "64",
+            "--max-cpu-loras",
+            "2",
+        ])
+
+    return args
+
 
 EXAMPLE_PROMPTS = [
     "Hello, my name is",
@@ -74,7 +96,7 @@ async def client_with_prompt_embeds(server_with_prompt_embeds):
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
 async def test_completions_with_prompt_embeds(
     example_prompt_embeds,
     client_with_prompt_embeds: openai.AsyncOpenAI,
@@ -179,7 +201,7 @@ async def test_completions_with_prompt_embeds(
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
 async def test_completions_errors_with_prompt_embeds(
         client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
     # Test error case: invalid prompt_embeds
@@ -194,7 +216,7 @@ async def test_completions_errors_with_prompt_embeds(
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("logprobs_arg", [1, 0])
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
 async def test_completions_with_logprobs_and_prompt_embeds(
     example_prompt_embeds,
     client_with_prompt_embeds: openai.AsyncOpenAI,
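The example_prompt_embeds fixture these tests consume is defined elsewhere in the test suite and is not shown in this diff. A plausible sketch of how such embeddings can be built from the base model's input embedding layer follows; this is an assumption about the fixture, not its actual code. The resulting tensor would then be serialized and base64-encoded as in the client sketch above, and can be sent with model set to either MODEL_NAME or LORA_SERVING_MODEL_NAME.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "facebook/opt-125m" corresponds to MODEL_NAME in the test module.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
hf_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    # (1, num_tokens, hidden_size) -> (num_tokens, hidden_size)
    prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)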

vllm/model_executor/models/opt.py

Lines changed: 2 additions & 3 deletions
@@ -43,7 +43,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -352,10 +352,9 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class OPTForCausalLM(nn.Module, SupportsPP):
+class OPTForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
     }
 
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
