Commit 543750e

add/remove/list lora
1 parent c5effdb commit 543750e

2 files changed: +95 −1 lines changed

tests/common/vllm_test.py

Lines changed: 65 additions & 1 deletion
@@ -1270,7 +1270,6 @@ def setUp(self):
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.explorer.rollout_model.enable_openai_api = True
         self.config.explorer.rollout_model.enable_lora = True
-        self.config.explorer.rollout_model.enable_runtime_lora_updating = True

         self.config.check_and_update()
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
@@ -1345,3 +1344,68 @@ async def test_tinker_api(self):
         self.assertEqual(response.sequences[0].stop_reason, "length")
         self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
         self.assertIsNone(response.topk_prompt_logprobs)
+
+        # test add/remove/list LoRA adapters
+        from vllm.lora.request import LoRARequest
+
+        # create two dummy LoRA adapters; freshly initialized, their weight deltas are zero
+        lora_path_1 = os.path.join(self.config.checkpoint_job_dir, "adapter_1")
+        lora_path_2 = os.path.join(self.config.checkpoint_job_dir, "adapter_2")
+        _create_adapter(self.config.model.model_path, lora_path_1, "adapter_1")
+        _create_adapter(self.config.model.model_path, lora_path_2, "adapter_2")
+        lora_1 = LoRARequest(
+            lora_name="test_adapter_1",
+            lora_int_id=1,
+            lora_path=os.path.join(lora_path_1, "adapter_1"),
+        )
+        lora_2 = LoRARequest(
+            lora_name="test_adapter_2",
+            lora_int_id=2,
+            lora_path=os.path.join(lora_path_2, "adapter_2"),
+        )
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+            lora_request=lora_1,
+        )
+        ids = await engine.list_lora_adapters.remote()
+        self.assertEqual(ids, [1])
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+            lora_request=lora_2,
+        )
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
+        await engine.remove_lora_adapter.remote(lora_id=1)
+        await engine.remove_lora_adapter.remote(lora_id=2)
+        ids = await engine.list_lora_adapters.remote()
+        self.assertEqual(ids, [])
+
+
+def _create_adapter(model_path: str, lora_path: str, name: str):
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        device_map="cpu",
+    )
+    lora_config = LoraConfig(
+        r=8,
+        lora_alpha=8,
+        target_modules=["gate_proj", "up_proj", "down_proj"],
+        lora_dropout=0.1,
+    )
+    lora_model = get_peft_model(model, lora_config, adapter_name=name)
+    lora_model.save_pretrained(lora_path)
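
Why each LoRARequest points one level below the save directory: when save_pretrained is called on a model whose adapter has a non-default adapter_name, peft nests that adapter's files in a subfolder named after it, so the loadable directory is lora_path_1/adapter_1 rather than lora_path_1. A quick sanity-check sketch (not part of the commit; model_path and lora_path_1 as in the test above):

import os

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the base model, then attach the adapter from the nested folder
# that _create_adapter/save_pretrained produced.
base = AutoModelForCausalLM.from_pretrained(model_path, device_map="cpu")
peft_model = PeftModel.from_pretrained(base, os.path.join(lora_path_1, "adapter_1"))
assert os.path.exists(os.path.join(lora_path_1, "adapter_1", "adapter_config.json"))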

trinity/common/models/vllm_model.py

Lines changed: 30 additions & 0 deletions
@@ -403,6 +403,36 @@ async def logprobs( # type: ignore [override]
             dtype=torch.float32,
         )

+    async def add_lora_adapter(self, lora_request: Any) -> int:
+        """Add a LoRA adapter to the vLLM engine.
+
+        Args:
+            lora_request (LoRARequest): The LoRA request.
+
+        Returns:
+            lora_id (int): The LoRA adapter ID.
+        """
+        lora_id = await self.async_llm.add_lora(lora_request)
+        return lora_id
+
+    async def remove_lora_adapter(self, lora_id: int) -> None:
+        """Remove a LoRA adapter from the vLLM engine.
+
+        Args:
+            lora_id (int): The LoRA adapter ID.
+        """
+        await self.async_llm.remove_lora(lora_id)
+
+    async def list_lora_adapters(self) -> Sequence[int]:
+        """List all LoRA adapter IDs in the vLLM engine.
+
+        Returns:
+            lora_ids (List[int]): The list of LoRA adapter IDs.
+        """
+        lora_ids = await self.async_llm.list_loras()
+        print("Get lora ids from vLLM:", lora_ids)
+        return list(lora_ids)
+
     async def sample(
         self,
         prompt: Any,
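
Taken together, the three wrapper methods expose vLLM's runtime LoRA management (AsyncLLM.add_lora / remove_lora / list_loras) on the model actor. A minimal usage sketch, assuming engine is a Ray actor handle to this class and lora_request is a vllm.lora.request.LoRARequest as constructed in the test above:

# Explicit registration path (sketch; engine and lora_request assumed set up).
lora_id = await engine.add_lora_adapter.remote(lora_request)
assert lora_id in await engine.list_lora_adapters.remote()
await engine.remove_lora_adapter.remote(lora_id=lora_id)
assert await engine.list_lora_adapters.remote() == []

Note that the test never calls add_lora_adapter directly: passing lora_request to sample() registers the adapter implicitly, which is why list_lora_adapters returns [1] after the first sample call.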
