Commit 158779f

Add utils to replace torchtune SDPA with ET Custom SDPA
[ghstack-poisoned]
1 parent 3979fc8 commit 158779f

1 file changed: +23 −1

extension/llm/modules/attention.py

Lines changed: 23 additions & 1 deletion
@@ -13,6 +13,7 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom
 
 logger = logging.getLogger(__name__)
 
@@ -310,7 +311,9 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        if input_pos is None:
+            input_pos = torch.tensor(0)
+        output = self._sdpa(input_pos, q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)
 
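(The SDPA call now takes input_pos as its first argument, defaulting to torch.tensor(0) when the caller passes none; this presumably matches SDPACustom's forward, which expects the input position before q, k, and v.)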

@@ -364,6 +367,7 @@ def forward(
         k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
         v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
 
+
         output = self._attention_fn(
             q,
             k,
@@ -411,3 +415,21 @@ def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
     """
     _replace_mha_with_inference_mha(module)
     return module
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(is_causal=child.is_causal),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.extension.llm.custom_ops import custom_ops
+    _replace_sdpa_with_custom_op(module)
+    return module
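
For context, a minimal usage sketch of the new utility. This is illustrative, not part of the commit: `model` stands in for a hypothetical torchtune-based nn.Module, and the import path assumes this file is importable as executorch.extension.llm.modules.attention:

    from executorch.extension.llm.modules.attention import (
        replace_mha_with_inference_mha,
        replace_sdpa_with_custom_op,
    )

    model = ...  # hypothetical torchtune-based model using MultiHeadAttention
    # First swap torchtune MHA for the inference-friendly MHA defined in this
    # file (whose SDPA submodule the new utility targets), then swap each SDPA
    # child for ET's SDPACustom.
    model = replace_mha_with_inference_mha(model)
    model = replace_sdpa_with_custom_op(model)

Note that replace_sdpa_with_custom_op imports executorch.extension.llm.custom_ops.custom_ops for its side effect, presumably so the custom SDPA operator is registered before the transformed model is exported or run.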
