@@ -1270,7 +1270,6 @@ def setUp(self):
         self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
         self.config.explorer.rollout_model.enable_openai_api = True
         self.config.explorer.rollout_model.enable_lora = True
-        self.config.explorer.rollout_model.enable_runtime_lora_updating = True

         self.config.check_and_update()
         self.engines, self.auxiliary_engines = create_inference_models(self.config)
@@ -1345,3 +1344,72 @@ async def test_tinker_api(self):
         self.assertEqual(response.sequences[0].stop_reason, "length")
         self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
         self.assertIsNone(response.topk_prompt_logprobs)
+
+        # test adding and removing LoRA adapters at runtime
+        from vllm.lora.request import LoRARequest
+
+        # create two dummy LoRA adapters; lora_B is initialized to zero, so they leave outputs unchanged
+        lora_path_1 = os.path.join(self.config.checkpoint_job_dir, "adapter_1")
+        lora_path_2 = os.path.join(self.config.checkpoint_job_dir, "adapter_2")
+        _create_adapter(self.config.model.model_path, lora_path_1, "adapter_1")
+        _create_adapter(self.config.model.model_path, lora_path_2, "adapter_2")
+        lora_1 = LoRARequest(
+            lora_name="test_adapter_1",
+            lora_int_id=1,
+            lora_path=os.path.join(lora_path_1, "adapter_1"),
+        )
+        lora_2 = LoRARequest(
+            lora_name="test_adapter_2",
+            lora_int_id=2,
+            lora_path=os.path.join(lora_path_2, "adapter_2"),
+        )
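+        # sampling with a lora_request should load the adapter into the engine on first use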
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+            lora_request=lora_1,
+        )
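+        # only adapter 1 has been used so far, so it should be the only one listed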
+        ids = await engine.list_lora_adapters.remote()
+        self.assertEqual(ids, [1])
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
+        response = await engine.sample.remote(
+            prompt=prompt,
+            num_samples=1,
+            sampling_params=types.SamplingParams(max_tokens=1),
+            include_prompt_logprobs=True,
+            lora_request=lora_2,
+        )
+        self.assertEqual(len(response.sequences), 1)
+        self.assertEqual(response.sequences[0].stop_reason, "length")
+        self.assertEqual(len(prompt.to_ints()), len(response.prompt_logprobs))
+        self.assertIsNone(response.topk_prompt_logprobs)
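+        # unload both adapters and confirm none remain loaded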
+        await engine.remove_lora_adapter.remote(lora_id=1)
+        await engine.remove_lora_adapter.remote(lora_id=2)
+        ids = await engine.list_lora_adapters.remote()
+        self.assertEqual(ids, [])
+
+
+def _create_adapter(model_path: str, lora_path: str, name: str):
+    from peft import LoraConfig, get_peft_model
+    from transformers import AutoModelForCausalLM
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        device_map="cpu",
+    )
+    lora_config = LoraConfig(
+        r=8,
+        lora_alpha=8,
+        target_modules=["gate_proj", "up_proj", "down_proj"],
+        lora_dropout=0.1,
+    )
+    lora_model = get_peft_model(model, lora_config, adapter_name=name)
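+    # peft writes a non-default named adapter to the "<name>/" subdirectory of lora_path,
+    # which is why the LoRARequest paths above join lora_path with the adapter name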
+    lora_model.save_pretrained(lora_path)