66 changes: 65 additions & 1 deletion tests/common/vllm_test.py
@@ -1270,7 +1270,6 @@ def setUp(self):
self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
self.config.explorer.rollout_model.enable_openai_api = True
self.config.explorer.rollout_model.enable_lora = True
self.config.explorer.rollout_model.enable_runtime_lora_updating = True

self.config.check_and_update()
self.engines, self.auxiliary_engines = create_inference_models(self.config)
@@ -1345,3 +1344,68 @@ async def test_tinker_api(self):
self.assertEqual(response.sequences[0].stop_reason, "length")
self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
self.assertIsNone(response.topk_prompt_logprobs)

# Test adding and removing LoRA adapters
from vllm.lora.request import LoRARequest

# Create two dummy LoRA adapters; freshly initialized LoRA weights have a zero delta
lora_path_1 = os.path.join(self.config.checkpoint_job_dir, "adapter_1")
lora_path_2 = os.path.join(self.config.checkpoint_job_dir, "adapter_2")
_create_adapter(self.config.model.model_path, lora_path_1, "adapter_1")
_create_adapter(self.config.model.model_path, lora_path_2, "adapter_2")
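# Note: peft saves a non-"default" adapter into a subdirectory named after the
# adapter, so the LoRARequest paths below point at <lora_path>/<adapter_name>.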
lora_1 = LoRARequest(
lora_name="test_adapter_1",
lora_int_id=1,
lora_path=os.path.join(lora_path_1, "adapter_1"),
)
lora_2 = LoRARequest(
lora_name="test_adapter_2",
lora_int_id=2,
lora_path=os.path.join(lora_path_2, "adapter_2"),
)
response = await engine.sample.remote(
prompt=prompt,
num_samples=1,
sampling_params=types.SamplingParams(max_tokens=1),
include_prompt_logprobs=True,
lora_request=lora_1,
)
ids = await engine.list_lora_adapters.remote()
self.assertEqual(ids, [1])
self.assertEqual(len(response.sequences), 1)
self.assertEqual(response.sequences[0].stop_reason, "length")
self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
self.assertIsNone(response.topk_prompt_logprobs)
response = await engine.sample.remote(
prompt=prompt,
num_samples=1,
sampling_params=types.SamplingParams(max_tokens=1),
include_prompt_logprobs=True,
lora_request=lora_2,
)
self.assertEqual(len(response.sequences), 1)
self.assertEqual(response.sequences[0].stop_reason, "length")
self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
self.assertIsNone(response.topk_prompt_logprobs)
await engine.remove_lora_adapter.remote(lora_id=1)
await engine.remove_lora_adapter.remote(lora_id=2)
ids = await engine.list_lora_adapters.remote()
self.assertEqual(ids, [])


def _create_adapter(model_path: str, lora_path: str, name: str):
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="cpu",
)
lora_config = LoraConfig(
r=8,
lora_alpha=8,
target_modules=["gate_proj", "up_proj", "down_proj"],
lora_dropout=0.1,
)
lora_model = get_peft_model(model, lora_config, adapter_name=name)
lora_model.save_pretrained(lora_path)
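For reference, a quick way to sanity-check what `_create_adapter` writes to disk — a minimal sketch, not part of this PR; the directory below is a placeholder for `os.path.join(lora_path_1, "adapter_1")`:

```python
# Minimal sketch (placeholder path): confirm the helper produced a loadable PEFT adapter.
from peft import PeftConfig

adapter_dir = "/tmp/ckpt/adapter_1/adapter_1"  # placeholder for os.path.join(lora_path_1, "adapter_1")
cfg = PeftConfig.from_pretrained(adapter_dir)
# Expected to report the LoRA settings used in _create_adapter (r=8, MLP projections).
print(cfg.peft_type, cfg.r, sorted(cfg.target_modules))
```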
30 changes: 30 additions & 0 deletions trinity/common/models/vllm_model.py
@@ -403,6 +403,36 @@ async def logprobs(  # type: ignore [override]
dtype=torch.float32,
)

async def add_lora_adapter(self, lora_request: Any) -> int:
"""Add a LoRA adapter to the vLLM engine.

Args:
lora_request (LoRARequest): The LoRA request.

Returns:
lora_id (int): The LoRA adapter ID.
"""
lora_id = await self.async_llm.add_lora(lora_request)
return lora_id

async def remove_lora_adapter(self, lora_id: int) -> None:
"""Remove a LoRA adapter from the vLLM engine.

Args:
lora_id (int): The LoRA adapter ID.
"""
await self.async_llm.remove_lora(lora_id)

async def list_lora_adapters(self) -> Sequence[int]:
"""List all LoRA adapter IDs in the vLLM engine.

Returns:
lora_ids (List[int]): The list of LoRA adapter IDs.
"""
lora_ids = await self.async_llm.list_loras()
print("Get lora ids from vLLM:", lora_ids)
return list(lora_ids)

async def sample(
self,
prompt: Any,
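Taken together, a minimal usage sketch of the new engine-level LoRA API (assumptions: `engine` is a Ray actor handle for the vLLM rollout model, `prompt` and `sampling_params` are built as in the test above, and the helper name and adapter path are hypothetical):

```python
from vllm.lora.request import LoRARequest

async def sample_with_adapter(engine, prompt, sampling_params, lora_path: str):
    # Passing a LoRARequest to sample() loads the adapter into the engine,
    # as exercised by the test above.
    lora_request = LoRARequest(lora_name="demo_adapter", lora_int_id=1, lora_path=lora_path)
    response = await engine.sample.remote(
        prompt=prompt,
        num_samples=1,
        sampling_params=sampling_params,
        lora_request=lora_request,
    )
    # The adapter ID is now visible in the engine's registry ...
    assert lora_request.lora_int_id in await engine.list_lora_adapters.remote()
    # ... and can be dropped explicitly once it is no longer needed.
    await engine.remove_lora_adapter.remote(lora_id=lora_request.lora_int_id)
    return response
```

The diff also adds `add_lora_adapter` for registering an adapter explicitly, without going through `sample`.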