Commit e9f9780

Update on "Add utils to replace torchtune SDPA with ET Custom SDPA"

Differential Revision: [D67878161](https://our.internmc.facebook.com/intern/diff/D67878161) [ghstack-poisoned]

2 parents 158779f + 8c087bf

File tree

1 file changed: +4 additions, -4 deletions
  • examples/models/llama/source_transformation

examples/models/llama/source_transformation/sdpa.py

Lines changed: 4 additions & 4 deletions

@@ -9,7 +9,7 @@
 # Example script for exporting Llama2 to flatbuffer

 import math
-from typing import Tuple, Union, Optional
+from typing import Optional, Tuple, Union

 import torch

@@ -24,7 +24,7 @@ def __init__(
         self,
         kv_cache: Optional[Union[KVCache, QuantizedKVCache]] = None,
         dim: int = -1,
-        is_causal = True,
+        is_causal=True,
     ):
         super().__init__()
         # Custom op only supports float32 currently. Converting to/from float32 is

@@ -48,8 +48,8 @@ def forward(
         k: torch.Tensor,
         v: torch.Tensor,
         bsz,
-        seqlen = None,
-        mask = None,
+        seqlen=None,
+        mask=None,
     ):
         # Custom op only supports float32 currently. Converting to/from float32 is
         # faster than not having the op.
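For orientation, the wrapper being edited forwards to ExecuTorch's custom SDPA op, whose `is_causal` and `mask` parameters appear in the diff above. The sketch below is a minimal pure-Python illustration of the scaled dot-product attention math such an op computes for a single head; the function name `sdpa` and the list-of-lists representation are illustrative stand-ins, not the real ET implementation (which operates on float32 tensors, as the diff's comments note).

```python
import math
from typing import List, Optional

Matrix = List[List[float]]  # [seq_len][dim], a toy stand-in for a tensor


def _softmax(row: List[float]) -> List[float]:
    # Numerically stable softmax over one row of attention scores.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]


def sdpa(q: Matrix, k: Matrix, v: Matrix,
         is_causal: bool = True,
         mask: Optional[Matrix] = None) -> Matrix:
    # Single-head scaled dot-product attention, written out longhand.
    # q, k, v have shape [seq_len][head_dim].
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out: Matrix = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            s = sum(a * b for a, b in zip(qi, kj)) * scale
            if mask is not None:
                s += mask[i][j]          # additive attention mask
            if is_causal and j > i:
                s = float("-inf")        # causal: no attending to the future
            scores.append(s)
        w = _softmax(scores)
        # Weighted sum of value vectors.
        out.append([sum(w[j] * v[j][t] for j in range(len(v)))
                    for t in range(len(v[0]))])
    return out
```

With `is_causal=True`, position 0 can only attend to itself, so its output equals `v[0]`; later positions blend earlier values.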
