@@ -57,15 +57,15 @@ More concretely, key-value cache acts as a memory bank for these generative mode
 >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

 >>> model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
->>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
+>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
 >>> tokenizer = AutoTokenizer.from_pretrained(model_id)

 >>> past_key_values = DynamicCache()
 >>> messages = [{"role": "user", "content": "Hello, what's your name."}]
->>> inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0")
+>>> inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)

 >>> generated_ids = inputs.input_ids
->>> cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device="cuda:0")
+>>> cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=model.device)
 >>> max_new_tokens = 10

 >>> for _ in range(max_new_tokens):
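The hunk stops at the top of the manual decoding loop; the loop body lies outside the diff context. For reference, it typically continues along these lines (a sketch of the surrounding example, not part of this change, reusing the variables defined above):

...     outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
...     # greedily pick the highest-scoring next token
...     next_token_ids = outputs.logits[:, -1:].argmax(-1)
...     generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
...     # feed back only the new token and extend the attention mask by one position
...     attention_mask = inputs["attention_mask"]
...     attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
...     inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
...     cache_position = cache_position[-1:] + 1

>>> print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])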
@@ -139,7 +139,7 @@ Cache quantization can be detrimental in terms of latency if the context length
 >>> from transformers import AutoTokenizer, AutoModelForCausalLM

 >>> tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
->>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16).to("cuda:0")
+>>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
 >>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

 >>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"})
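The dict-style `cache_config` above can also be expressed with the `QuantizedCacheConfig` helper. A minimal sketch, assuming a transformers version recent enough to ship that class:

>>> from transformers import QuantizedCacheConfig

>>> cache_config = QuantizedCacheConfig(backend="quanto", nbits=4)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config=cache_config)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])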
@@ -168,7 +168,7 @@ Use `cache_implementation="offloaded_static"` for an offloaded static cache (see
 >>> ckpt = "microsoft/Phi-3-mini-4k-instruct"

 >>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
->>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
+>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, device_map="auto")
 >>> inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)

 >>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded")
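As the hunk header notes, there is also an offloaded *static* variant. Switching the same example to it is a one-argument change (a sketch, assuming a transformers version that includes `OffloadedStaticCache`):

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded_static")
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])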
@@ -278,7 +278,7 @@ Note that you can use this cache only for models that support sliding window, e.
 >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

 >>> tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
->>> model = AutoModelForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16).to("cuda:0")
+>>> model = AutoModelForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16, device_map="auto")
 >>> inputs = tokenizer("Yesterday I was on a rock concert and.", return_tensors="pt").to(model.device)

 >>> # can be used by passing in cache implementation
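The generate call that follows this comment sits outside the diff context; it presumably enables the sliding-window cache through `cache_implementation`. A sketch of that continuation, reusing the variables above (the token count is illustrative):

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=30, cache_implementation="sliding_window")
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])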
@@ -298,7 +298,7 @@ Unlike other cache classes, this one can't be used directly by indicating a `cac
 >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

 >>> tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
->>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16).to("cuda:0")
+>>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
 >>> inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)

 >>> # get our cache, specify number of sink tokens and window size
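The `SinkCache` construction and the generate call come right after this comment but are outside the diff context. A sketch of that continuation (the window and sink sizes are illustrative); the cache object is passed via `past_key_values` since there is no `cache_implementation` string for it:

>>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=30, past_key_values=past_key_values)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])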
@@ -377,25 +377,27 @@ Sometimes you would want to first fill-in cache object with key/values for certa
 >>> import copy
 >>> import torch
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache
+>>> from accelerate.test_utils.testing import get_backend

+>>> DEVICE, _, _ = get_backend()  # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
 >>> model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
->>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
+>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=DEVICE)
 >>> tokenizer = AutoTokenizer.from_pretrained(model_id)

 >>> # Init StaticCache with big enough max-length (1024 tokens for the below example)
 >>> # You can also init a DynamicCache, if that suits you better
->>> prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
+>>> prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device=DEVICE, dtype=torch.bfloat16)

 >>> INITIAL_PROMPT = "You are a helpful assistant. "
->>> inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
+>>> inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(DEVICE)
 >>> # This is the common prompt cached, we need to run forward without grad to be abel to copy
 >>> with torch.no_grad():
 ...     prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values

 >>> prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
 >>> responses = []
 >>> for prompt in prompts:
-...     new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
+...     new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to(DEVICE)
 ...     past_key_values = copy.deepcopy(prompt_cache)
 ...     outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20)
 ...     response = tokenizer.batch_decode(outputs)[0]
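The loop body in the diff ends at decoding; presumably each decoded string is then collected into `responses`, which the `copy.deepcopy` above keeps independent of the shared `prompt_cache`. A sketch of the assumed tail of the example:

...     responses.append(response)

>>> print(responses)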