2626# Global torch config, set the torch compile cache to fix up to llama 405B
2727torch ._dynamo .config .cache_size_limit = 20
2828
29+ # simple string, TRT-LLM style text-only prompt or full-scale HF message template
30+ PromptInput = Union [str , Dict , List [Dict ]]
31+
2932
3033class PromptConfig (BaseModel ):
3134 """Prompt configuration.
@@ -35,17 +38,27 @@ class PromptConfig(BaseModel):
3538 """
3639
3740 batch_size : int = Field (default = 2 , description = "Number of queries" )
38- queries : Union [str , List [str ]] = Field (
41+ queries : Union [PromptInput , List [PromptInput ]] = Field (
3942 default_factory = lambda : [
43+ # OPTION 1: simple text prompt
4044 "How big is the universe? " ,
41- "In simple words and in a single sentence, explain the concept of gravity: " ,
42- "How to fix slicing in golf? " ,
43- "Where is the capital of Iceland? " ,
44- "How big is the universe? " ,
45- "In simple words and in a single sentence, explain the concept of gravity: " ,
46- "How to fix slicing in golf? " ,
47- "Where is the capital of Iceland? " ,
48- ]
45+ # OPTION 2: wrapped text prompt for TRT-LLM
46+ {"prompt" : "In simple words and a single sentence, explain the concept of gravity: " },
47+ # OPTION 3: a full-scale HF message template (this one works for text-only models!)
48+ # Learn more about chat templates: https://huggingface.co/docs/transformers/en/chat_templating
49+ # and multi-modal templates: https://huggingface.co/docs/transformers/en/chat_templating_multimodal
50+ [
51+ {
52+ "role" : "user" ,
53+ "content" : "How to fix slicing in golf?" ,
54+ }
55+ ],
56+ # More prompts...
57+ {"prompt" : "Where is the capital of Iceland? " },
58+ ],
59+ description = "Example queries to prompt the model with. We support both TRT-LLM text-only "
60+ "queries via the 'prompt' key and full-scale HF message template called via "
61+ "apply_chat_template." ,
4962 )
5063 sp_kwargs : Dict [str , Any ] = Field (
5164 default_factory = lambda : {"max_tokens" : 100 , "top_k" : 200 , "temperature" : 1.0 },
@@ -59,10 +72,28 @@ def model_post_init(self, __context: Any):
5972 NOTE (lucaslie): has to be done with model_post_init to ensure it's always run. field
6073 validators are only run if a value is provided.
6174 """
62- queries = [ self .queries ] if isinstance (self .queries , str ) else self .queries
75+ queries = self .queries if isinstance (self .queries , list ) else [ self .queries ]
6376 batch_size = self .batch_size
6477 queries = queries * (batch_size // len (queries ) + 1 )
65- self .queries = queries [:batch_size ]
78+ queries = queries [:batch_size ]
79+
80+ # now let's standardize the queries for the LLM api to understand them
81+ queries_processed = []
82+ for query in queries :
83+ if isinstance (query , str ):
84+ queries_processed .append ({"prompt" : query })
85+ elif isinstance (query , dict ):
86+ queries_processed .append (query )
87+ elif isinstance (query , list ):
88+ queries_processed .append (
89+ {
90+ "prompt" : "Fake prompt. Check out messages field for the HF chat template." ,
91+ "messages" : query , # contains the actual HF chat template
92+ }
93+ )
94+ else :
95+ raise ValueError (f"Invalid query type: { type (query )} " )
96+ self .queries = queries_processed
6697
6798 @field_validator ("sp_kwargs" , mode = "after" )
6899 @classmethod
@@ -237,56 +268,13 @@ def main(config: Optional[ExperimentConfig] = None):
237268
238269 llm = build_llm_from_config (config )
239270
240- # just run config.prompt.queries with our special token sequence including special image tokens
241- # fmt: off
242- input_ids = [[
243- 200000 , 200005 , 1556 , 200006 , 368 , 200080 , 200090 , 200092 , 200092 ,
244- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
245- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
246- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
247- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
248- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
249- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
250- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
251- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
252- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
253- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
254- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
255- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
256- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
257- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
258- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
259- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200081 , 200080 ,
260- 200090 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
261- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
262- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
263- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
264- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
265- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
266- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
267- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
268- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
269- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
270- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
271- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
272- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
273- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
274- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
275- 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 , 200092 ,
276- 200092 , 200081 , 51212 , 1780 , 650 , 2556 , 310 , 290 , 1472 ,
277- 8392 , 341 , 1357 , 13492 , 26 , 200008 , 200005 , 140680 , 200006 ,
278- 368
279- ] for _ in range (2 )]
280- # fmt: on
281-
282271 # prompt the model and print its output
283272 ad_logger .info ("Running example prompts..." )
284273
285274 # now let's try piping through multimodal data
286275
287276 outs = llm .generate (
288- input_ids ,
289- # config.prompt.queries,
277+ config .prompt .queries ,
290278 sampling_params = SamplingParams (** config .prompt .sp_kwargs ),
291279 )
292280 results = {"prompts_and_outputs" : print_outputs (outs )}
0 commit comments