
Commit 7d6f521 (1 parent: 0bbe0b2)

Update base for Update on "Add test case to export, quantize and lower vision encoder model for ET"

Differential Revision: [D67878162](https://our.internmc.facebook.com/intern/diff/D67878162) [ghstack-poisoned]

File tree: 1 file changed (+2 additions, -2 deletions)


extension/llm/modules/attention.py

Lines changed: 2 additions & 2 deletions
@@ -9,11 +9,11 @@
 
 import torch
 import torchtune.modules.attention as TorchTuneAttention
+from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
-from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
 
 logger = logging.getLogger(__name__)
 
@@ -367,7 +367,6 @@ def forward(
             k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
             v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
-
         output = self._attention_fn(
             q,
             k,
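
Note on the context lines in the hunk above: the unsqueeze/expand/flatten pattern is the grouped-query-attention trick of repeating each KV head so k and v line up with the number of query heads before the attention call. A minimal standalone sketch of that pattern follows; all shapes and variable names here are illustrative assumptions, not taken from the file.

import torch

# Illustrative GQA KV-head expansion (shapes are assumptions for demonstration).
bsz, n_kv_heads, q_per_kv, seq_len, head_dim = 2, 4, 2, 16, 64

k = torch.randn(bsz, n_kv_heads, seq_len, head_dim)  # [b, n_kv, s, d]

# Insert a size-1 dim, broadcast it q_per_kv times, then merge it into the
# head dim so k matches the query heads (n_kv * q_per_kv == n_heads).
expand_shape = (bsz, n_kv_heads, q_per_kv, seq_len, head_dim)
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)

assert k.shape == (bsz, n_kv_heads * q_per_kv, seq_len, head_dim)  # [2, 8, 16, 64]
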
@@ -431,5 +430,6 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
 
 def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops
+
     _replace_sdpa_with_custom_op(module)
     return module
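
For readers of this commit, a hedged usage sketch of the function touched in the last hunk: replace_sdpa_with_custom_op swaps the eager SDPA for the ExecuTorch custom op before export. Only the function name and signature come from the diff; the import path and the placeholder model below are assumptions.

from torch import nn

# Hypothetical usage (not part of the commit). The module path mirrors the
# other executorch.extension.llm.modules imports in the diff and is assumed.
from executorch.extension.llm.modules.attention import replace_sdpa_with_custom_op

model = nn.Sequential(nn.Linear(16, 16))  # stand-in; a real model would contain the SDPA modules this pass rewrites
model = replace_sdpa_with_custom_op(model)  # loads the custom ops, rewrites SDPA in place, returns the module
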
