
Commit bceca59

jla524 and AlanPonnachan authored and committed
Update docs for sdpa_kernel (huggingface#35410)
update: sdp_kernel -> sdpa_kernel
1 parent c9fe423 commit bceca59

File tree: 1 file changed (+5 −3 lines)

docs/source/en/perf_infer_gpu_one.md

Lines changed: 5 additions & 3 deletions
@@ -332,10 +332,11 @@ In that case, you should see a warning message and we will fall back to the (slo
 
 </Tip>
 
-By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
+By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.nn.attention.sdpa_kernel`](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) as a context manager:
 
 ```diff
 import torch
++ from torch.nn.attention import SDPBackend, sdpa_kernel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
@@ -344,7 +345,7 @@ model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=to
 input_text = "Hello my dog is cute and"
 inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
 
-+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
++ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
 outputs = model.generate(**inputs)
 
 print(tokenizer.decode(outputs[0], skip_special_tokens=True))
@@ -518,6 +519,7 @@ It is often possible to combine several of the optimization techniques described
 
 ```py
 import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
 # load model in 4-bit
@@ -536,7 +538,7 @@ input_text = "Hello my dog is cute and"
 inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
 
 # enable FlashAttention
-with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
 outputs = model.generate(**inputs)
 
 print(tokenizer.decode(outputs[0], skip_special_tokens=True))
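
For context, here is a minimal, self-contained sketch of the `torch.nn.attention.sdpa_kernel` context manager that this commit documents. It is not part of the diff: the tensor shapes, the choice of `SDPBackend.MATH`, and the fallback list are illustrative assumptions, and it presumes a PyTorch version that ships `torch.nn.attention` (2.3 or later).

```py
# Illustrative sketch only, not part of the commit. Assumes PyTorch >= 2.3,
# where torch.nn.attention.sdpa_kernel supersedes torch.backends.cuda.sdp_kernel.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy tensors in the (batch, heads, seq_len, head_dim) layout expected by SDPA.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Pin SDPA to a single backend. If the chosen backend cannot handle the inputs
# (e.g. FLASH_ATTENTION on unsupported hardware), the call raises an error
# instead of silently falling back, which is how you check availability.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)

# A list of backends is also accepted; SDPA picks the best available one.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 128, 64])
```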
