11### :section Customization
22### :title Generate text with multiple LoRA adapters
33### :order 5
4+
5+ import argparse
6+ from typing import Optional
7+
48from huggingface_hub import snapshot_download
59
610from tensorrt_llm import LLM
711from tensorrt_llm .executor import LoRARequest
812from tensorrt_llm .lora_helper import LoraConfig
913
1014
11- def main ():
15+ def main (chatbot_lora_dir : Optional [str ], mental_health_lora_dir : Optional [str ],
16+ tarot_lora_dir : Optional [str ]):
1217
13- # Download the LoRA adapters from huggingface hub.
14- lora_dir1 = snapshot_download (repo_id = "snshrivas10/sft-tiny-chatbot" )
15- lora_dir2 = snapshot_download (
16- repo_id = "givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational" )
17- lora_dir3 = snapshot_download (repo_id = "barissglc/tinyllama-tarot-v1" )
18+ # Download the LoRA adapters from huggingface hub, if not provided via command line args.
19+ if chatbot_lora_dir is None :
20+ chatbot_lora_dir = snapshot_download (
21+ repo_id = "snshrivas10/sft-tiny-chatbot" )
22+ if mental_health_lora_dir is None :
23+ mental_health_lora_dir = snapshot_download (
24+ repo_id =
25+ "givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational" )
26+ if tarot_lora_dir is None :
27+ tarot_lora_dir = snapshot_download (
28+ repo_id = "barissglc/tinyllama-tarot-v1" )
1829
1930 # Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.
2031 # This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.
21- lora_config = LoraConfig (lora_dir = [lora_dir1 ],
32+ lora_config = LoraConfig (lora_dir = [chatbot_lora_dir ],
2233 max_lora_rank = 64 ,
2334 max_loras = 3 ,
2435 max_cpu_loras = 3 )
@@ -39,10 +50,11 @@ def main():
3950 for output in llm .generate (prompts ,
4051 lora_request = [
4152 None ,
42- LoRARequest ("chatbot" , 1 , lora_dir1 ), None ,
43- LoRARequest ("mental-health" , 2 , lora_dir2 ),
53+ LoRARequest ("chatbot" , 1 , chatbot_lora_dir ),
4454 None ,
45- LoRARequest ("tarot" , 3 , lora_dir3 )
55+ LoRARequest ("mental-health" , 2 ,
56+ mental_health_lora_dir ), None ,
57+ LoRARequest ("tarot" , 3 , tarot_lora_dir )
4658 ]):
4759 prompt = output .prompt
4860 generated_text = output .outputs [0 ].text
@@ -58,4 +70,20 @@ def main():
5870
5971
6072if __name__ == '__main__' :
61- main ()
73+ parser = argparse .ArgumentParser (
74+ description = "Generate text with multiple LoRA adapters" )
75+ parser .add_argument ('--chatbot_lora_dir' ,
76+ type = str ,
77+ default = None ,
78+ help = 'Path to the chatbot LoRA directory' )
79+ parser .add_argument ('--mental_health_lora_dir' ,
80+ type = str ,
81+ default = None ,
82+ help = 'Path to the mental health LoRA directory' )
83+ parser .add_argument ('--tarot_lora_dir' ,
84+ type = str ,
85+ default = None ,
86+ help = 'Path to the tarot LoRA directory' )
87+ args = parser .parse_args ()
88+ main (args .chatbot_lora_dir , args .mental_health_lora_dir ,
89+ args .tarot_lora_dir )
0 commit comments