Commit 531d151

[docs] no hard-coding cuda (#36043)
make device-agnostic
1 parent 7399f80 commit 531d151
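
In practice, the hard-coded `"cuda:0"` placements are replaced by two idioms: let `device_map="auto"` place the model, then move inputs and auxiliary tensors to `model.device`. A minimal sketch of the resulting pattern, reusing the TinyLlama checkpoint from the examples below (this snippet is illustrative and not part of the commit's diff):

>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")  # accelerate picks the device(s)
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> inputs = tokenizer("Hello", return_tensors="pt").to(model.device)  # inputs follow the model instead of a hard-coded device
>>> out = model.generate(**inputs, max_new_tokens=5)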

File tree: 1 file changed

docs/source/en/kv_cache.md

Lines changed: 13 additions & 11 deletions
@@ -57,15 +57,15 @@ More concretely, key-value cache acts as a memory bank for these generative mode
 >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

 >>> model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
->>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
+>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
 >>> tokenizer = AutoTokenizer.from_pretrained(model_id)

 >>> past_key_values = DynamicCache()
 >>> messages = [{"role": "user", "content": "Hello, what's your name."}]
->>> inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0")
+>>> inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)

 >>> generated_ids = inputs.input_ids
->>> cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device="cuda:0")
+>>> cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=model.device)
 >>> max_new_tokens = 10

 >>> for _ in range(max_new_tokens):
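
With `device_map="auto"`, the model's resolved placement is exposed as `model.device`, which is why the `cache_position` tensor above can be built with `device=model.device` instead of `"cuda:0"`. A quick check; the output shown assumes a single-GPU machine and will differ on CPU, XPU or MPS:

>>> model.device
device(type='cuda', index=0)
>>> torch.zeros(1, device=model.device)  # the same call works unchanged wherever the model was placed
tensor([0.], device='cuda:0')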
@@ -139,7 +139,7 @@ Cache quantization can be detrimental in terms of latency if the context length
 >>> from transformers import AutoTokenizer, AutoModelForCausalLM

 >>> tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
->>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16).to("cuda:0")
+>>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
 >>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

 >>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"})
@@ -168,7 +168,7 @@ Use `cache_implementation="offloaded_static"` for an offloaded static cache (see
 >>> ckpt = "microsoft/Phi-3-mini-4k-instruct"

 >>> tokenizer = AutoTokenizer.from_pretrained(ckpt)
->>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
+>>> model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, device_map="auto")
 >>> inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)

 >>> out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded")
@@ -278,7 +278,7 @@ Note that you can use this cache only for models that support sliding window, e.
 >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

 >>> tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
->>> model = AutoModelForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16).to("cuda:0")
+>>> model = AutoModelForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16, device_map="auto")
 >>> inputs = tokenizer("Yesterday I was on a rock concert and.", return_tensors="pt").to(model.device)

 >>> # can be used by passing in cache implementation
@@ -298,7 +298,7 @@ Unlike other cache classes, this one can't be used directly by indicating a `cac
 >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

 >>> tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
->>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16).to("cuda:0")
+>>> model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16, device_map="auto")
 >>> inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)

 >>> # get our cache, specify number of sink tokens and window size
@@ -377,25 +377,27 @@ Sometimes you would want to first fill-in cache object with key/values for certa
 >>> import copy
 >>> import torch
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache
+>>> from accelerate.test_utils.testing import get_backend

+>>> DEVICE, _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
 >>> model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
->>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
+>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=DEVICE)
 >>> tokenizer = AutoTokenizer.from_pretrained(model_id)

 >>> # Init StaticCache with big enough max-length (1024 tokens for the below example)
 >>> # You can also init a DynamicCache, if that suits you better
->>> prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
+>>> prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device=DEVICE, dtype=torch.bfloat16)

 >>> INITIAL_PROMPT = "You are a helpful assistant. "
->>> inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
+>>> inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(DEVICE)
 >>> # This is the common prompt cached, we need to run forward without grad to be abel to copy
 >>> with torch.no_grad():
 ...     prompt_cache = model(**inputs_initial_prompt, past_key_values = prompt_cache).past_key_values

 >>> prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
 >>> responses = []
 >>> for prompt in prompts:
-...     new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
+...     new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to(DEVICE)
 ...     past_key_values = copy.deepcopy(prompt_cache)
 ...     outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
 ...     response = tokenizer.batch_decode(outputs)[0]
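
When an explicit device identifier is needed at construction time, as for the `StaticCache` above, the snippet now takes it from accelerate's `get_backend()` test utility instead of the literal `"cuda"`. A small illustration of the returned value; the outputs are hardware-dependent (shown for a CUDA machine) and the remaining return values are ignored here:

>>> import torch
>>> from accelerate.test_utils.testing import get_backend
>>> DEVICE, _, _ = get_backend()  # "cuda", "cpu", "xpu", "mps", ... depending on the machine
>>> DEVICE
'cuda'
>>> torch.ones(2, device=DEVICE)  # any argument that previously hard-coded "cuda" accepts DEVICE
tensor([1., 1.], device='cuda:0')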
