-import time
+import argparse
 
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi import KvCacheConfig
 
 
-def main():
+def main(args):
 
     prompt_a = (
         "Given the following question and four candidate answers (A, B, C and D), choose the best answer."
@@ -20,79 +20,52 @@ def main():
 
     kv_cache_max_tokens = 256
     kv_cache_page_size = 16
+    kv_cache_host_size = 1024**3 if args.enable_offloading else 0
 
-    # Offloading Off
-    print("\n ====== Offloading Off ====== \n")
     llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
               max_batch_size=max_batch_size,
               max_seq_len=max_seq_len,
               kv_cache_config=KvCacheConfig(enable_block_reuse=True,
                                             max_tokens=kv_cache_max_tokens,
                                             tokens_per_block=kv_cache_page_size,
-                                            host_cache_size=0))
+                                            host_cache_size=kv_cache_host_size))
     # prompt_a occupies kv cache pool
     output_a = llm.generate(prompt_a)
     print(
         f"Prompt: {output_a.prompt!r}, Generated text: {output_a.outputs[0].text!r}"
     )
-
-    # since max_batch_size=1, prompt_b clears and update kv cache
-    output_b = llm.generate(prompt_b)
-    print(
-        f"Prompt: {output_b.prompt!r}, Generated text: {output_b.outputs[0].text!r}"
-    )
-
-    # prompt_a clears and update kv cache again
-    # no kv cache reuse happens
-    output_a = llm.generate(prompt_a)
-    print(
-        f"Prompt: {output_a.prompt!r}, Generated text: {output_a.outputs[0].text!r}"
-    )
-
-    # prompt_b clears and update kv cache again
-    # no kv cache reuse happens
-    output_b = llm.generate(prompt_b)
-    print(
-        f"Prompt: {output_b.prompt!r}, Generated text: {output_b.outputs[0].text!r}"
-    )
-
-    llm.shutdown()
-    time.sleep(5)
-
-    # Offloading On
-    print("\n ====== Offloading On ====== \n")
-    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-              max_batch_size=max_batch_size,
-              max_seq_len=max_seq_len,
-              kv_cache_config=KvCacheConfig(enable_block_reuse=True,
-                                            max_tokens=kv_cache_max_tokens,
-                                            tokens_per_block=kv_cache_page_size,
-                                            host_cache_size=1024**3))
-    # prompt_a occupies kv cache pool
-    output_a = llm.generate(prompt_a)
-    print(
-        f"Prompt: {output_a.prompt!r}, Generated text: {output_a.outputs[0].text!r}"
-    )
-
-    # since max_batch_size=1, and offloading is enabled,
-    # kv cache of prompt_a will be offloaded to host memory.
-    # kv cache of prompt_b keeps in device memory.
+    '''
+    since max_batch_size=1:
+    if enable_offloading=False:
+        prompt_b clears and updates the kv cache.
+    if enable_offloading=True:
+        the kv cache of prompt_a is offloaded to host memory;
+        the kv cache of prompt_b stays in device memory.
+    '''
     output_b = llm.generate(prompt_b)
     print(
         f"Prompt: {output_b.prompt!r}, Generated text: {output_b.outputs[0].text!r}"
     )
-
-    # kv cache of prompt_a will be onboarded to device memory,
-    # kv cache of prompt_b will be offloaded to host memory.
-    # kv cache of prompt_a will be reused.
+    '''
+    if not enable_offloading:
+        prompt_a clears and updates the kv cache again; no kv cache reuse happens.
+    else:
+        the kv cache of prompt_a is onboarded back to device memory and reused;
+        the kv cache of prompt_b is offloaded to host memory.
+    '''
     output_a = llm.generate(prompt_a)
     print(
         f"Prompt: {output_a.prompt!r}, Generated text: {output_a.outputs[0].text!r}"
     )
-
-    # kv cache of prompt_b will be onboarded to device memory,
-    # kv cache of prompt_a will be offloaded to host memory.
-    # kv cache of prompt_b will be reused.
+    '''
+    if not enable_offloading:
+        prompt_b clears and updates the kv cache again; no kv cache reuse happens.
+    else:
+        the kv cache of prompt_b is onboarded back to device memory and reused;
+        the kv cache of prompt_a is offloaded to host memory.
+    '''
+    #
+    #
     output_b = llm.generate(prompt_b)
     print(
         f"Prompt: {output_b.prompt!r}, Generated text: {output_b.outputs[0].text!r}"
@@ -102,4 +75,9 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--enable_offloading',
+                        default=False,
+                        action='store_true')
+    args = parser.parse_args()
+    main(args)
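
With this change the offloading-off and offloading-on runs are selected by a command-line flag instead of being executed back to back in one process. A likely way to exercise both paths (the script's filename is not shown in this diff, so `llm_kv_cache_offloading.py` below is only a placeholder):

    # host_cache_size=0: evicted kv cache blocks are dropped, so no reuse across prompts
    python llm_kv_cache_offloading.py

    # host_cache_size=1 GiB: evicted blocks are offloaded to host memory, then onboarded and reused
    python llm_kv_cache_offloading.py --enable_offloading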