1 file changed: +2 −2 lines changed

import torch
import torchtune.modules.attention as TorchTuneAttention
+ from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache
- from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

logger = logging.getLogger(__name__)

@@ -367,7 +367,6 @@ def forward(
            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

-
        output = self._attention_fn(
            q,
            k,
@@ -431,5 +430,6 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):

def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
    from executorch.extension.llm.custom_ops import custom_ops
+
    _replace_sdpa_with_custom_op(module)
    return module
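For context, a minimal usage sketch of the replace_sdpa_with_custom_op entrypoint touched at the bottom of this diff. The import path of the attention module and the TinyModel stand-in are assumptions for illustration, not part of this change; per the diff, the function imports the custom ops, rewrites SDPA modules in the given module tree, and returns the same module.

import torch
from torch import nn

# Assumed import path, mirroring the kv_cache import visible in the diff above;
# verify against the actual file this PR modifies.
from executorch.extension.llm.modules.attention import replace_sdpa_with_custom_op


class TinyModel(nn.Module):
    """Hypothetical stand-in for a torchtune-based decoder whose attention
    layers carry the SDPA wrapper that _replace_sdpa_with_custom_op rewrites."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = TinyModel()

# Swap eager SDPA modules for ExecuTorch's custom SDPA op in place; the
# entrypoint shown in the diff returns the mutated module, so reassignment
# is only for readability. On a model with no SDPA wrappers this is a no-op.
model = replace_sdpa_with_custom_op(model)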