Skip to content

Commit 750d15b

Browse files
authored
[https://nvbugs/5503529][fix] Change test_llmapi_example_multilora to get adapters path from cmd line to avoid downloading from HF (#7740)
Signed-off-by: Amit Zuker <[email protected]>
1 parent 6eef192 commit 750d15b

File tree

2 files changed

+49
-12
lines changed

2 files changed

+49
-12
lines changed

examples/llm-api/llm_multilora.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
11
### :section Customization
22
### :title Generate text with multiple LoRA adapters
33
### :order 5
4+
5+
import argparse
6+
from typing import Optional
7+
48
from huggingface_hub import snapshot_download
59

610
from tensorrt_llm import LLM
711
from tensorrt_llm.executor import LoRARequest
812
from tensorrt_llm.lora_helper import LoraConfig
913

1014

11-
def main():
15+
def main(chatbot_lora_dir: Optional[str], mental_health_lora_dir: Optional[str],
16+
tarot_lora_dir: Optional[str]):
1217

13-
# Download the LoRA adapters from huggingface hub.
14-
lora_dir1 = snapshot_download(repo_id="snshrivas10/sft-tiny-chatbot")
15-
lora_dir2 = snapshot_download(
16-
repo_id="givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")
17-
lora_dir3 = snapshot_download(repo_id="barissglc/tinyllama-tarot-v1")
18+
# Download the LoRA adapters from huggingface hub, if not provided via command line args.
19+
if chatbot_lora_dir is None:
20+
chatbot_lora_dir = snapshot_download(
21+
repo_id="snshrivas10/sft-tiny-chatbot")
22+
if mental_health_lora_dir is None:
23+
mental_health_lora_dir = snapshot_download(
24+
repo_id=
25+
"givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")
26+
if tarot_lora_dir is None:
27+
tarot_lora_dir = snapshot_download(
28+
repo_id="barissglc/tinyllama-tarot-v1")
1829

1930
# Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.
2031
# This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.
21-
lora_config = LoraConfig(lora_dir=[lora_dir1],
32+
lora_config = LoraConfig(lora_dir=[chatbot_lora_dir],
2233
max_lora_rank=64,
2334
max_loras=3,
2435
max_cpu_loras=3)
@@ -39,10 +50,11 @@ def main():
3950
for output in llm.generate(prompts,
4051
lora_request=[
4152
None,
42-
LoRARequest("chatbot", 1, lora_dir1), None,
43-
LoRARequest("mental-health", 2, lora_dir2),
53+
LoRARequest("chatbot", 1, chatbot_lora_dir),
4454
None,
45-
LoRARequest("tarot", 3, lora_dir3)
55+
LoRARequest("mental-health", 2,
56+
mental_health_lora_dir), None,
57+
LoRARequest("tarot", 3, tarot_lora_dir)
4658
]):
4759
prompt = output.prompt
4860
generated_text = output.outputs[0].text
@@ -58,4 +70,20 @@ def main():
5870

5971

6072
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Generate text with multiple LoRA adapters")
    # All three adapter paths are optional: when a flag is omitted, main()
    # falls back to downloading that adapter from the HuggingFace hub.
    for flag, help_text in (
        ('--chatbot_lora_dir', 'Path to the chatbot LoRA directory'),
        ('--mental_health_lora_dir',
         'Path to the mental health LoRA directory'),
        ('--tarot_lora_dir', 'Path to the tarot LoRA directory'),
    ):
        parser.add_argument(flag, type=str, default=None, help=help_text)
    args = parser.parse_args()
    main(args.chatbot_lora_dir, args.mental_health_lora_dir,
         args.tarot_lora_dir)

tests/integration/defs/llmapi/test_llm_examples.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,16 @@ def test_llmapi_example_inference_async_streaming(llm_root, engine_dir,
110110

111111

112112
def test_llmapi_example_multilora(llm_root, engine_dir, llm_venv):
    """Run the llm_multilora.py example against locally cached LoRA adapters.

    Passes explicit adapter directories on the command line so the example
    does not download the adapters from the HuggingFace hub.
    """
    models_root = llm_models_root()
    cmd_line_args = []
    for flag, subdir in (
        ("--chatbot_lora_dir", "sft-tiny-chatbot"),
        ("--mental_health_lora_dir",
         "TinyLlama-1.1B-Chat-v1.0-mental-health-conversational"),
        ("--tarot_lora_dir", "tinyllama-tarot-v1"),
    ):
        cmd_line_args += [flag, f"{models_root}/llama-models-v2/{subdir}"]
    _run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_multilora.py",
                        *cmd_line_args)
114123

115124

116125
def test_llmapi_example_guided_decoding(llm_root, engine_dir, llm_venv):

0 commit comments

Comments
 (0)