@@ -13,6 +13,7 @@
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
+from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom

 logger = logging.getLogger(__name__)

@@ -310,7 +311,9 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)

-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        if input_pos is None:
+            input_pos = torch.tensor(0)
+        output = self._sdpa(input_pos, q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)


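Review note, not part of the diff: `SDPACustom` takes the cache position as its first argument, which is why `input_pos` is now threaded into the `_sdpa` call, with `torch.tensor(0)` as a fallback when the caller supplies no position. A minimal runnable sketch of that calling convention follows; `sdpa_custom_like` is a hypothetical stand-in, using plain `scaled_dot_product_attention` in place of the custom op:

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in mirroring the call shape used in the diff:
# (input_pos, q, k, v, batch, seq_len, mask). The real SDPACustom dispatches
# to ExecuTorch's custom SDPA op; F.scaled_dot_product_attention substitutes here.
def sdpa_custom_like(input_pos, q, k, v, b, s_x, mask=None):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

b, s_x, n_heads, head_dim = 1, 16, 8, 64  # illustrative shapes
q = torch.randn(b, n_heads, s_x, head_dim)
k = torch.randn(b, n_heads, s_x, head_dim)
v = torch.randn(b, n_heads, s_x, head_dim)

# With no position supplied, the diff defaults to torch.tensor(0).
out = sdpa_custom_like(torch.tensor(0), q, k, v, b, s_x)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```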
@@ -364,6 +367,7 @@ def forward(
         k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
         v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

+
         output = self._attention_fn(
             q,
             k,
@@ -411,3 +415,21 @@ def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
     """
     _replace_mha_with_inference_mha(module)
     return module
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(is_causal=child.is_causal),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.extension.llm.custom_ops import custom_ops
+    _replace_sdpa_with_custom_op(module)
+    return module
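Usage note: the two public transforms are meant to compose. A minimal sketch, assuming this file is importable as `executorch.extension.llm.modules.attention`; `build_model()` is a hypothetical constructor for a torchtune-style model:

```python
from executorch.extension.llm.modules.attention import (
    replace_mha_with_inference_mha,
    replace_sdpa_with_custom_op,
)

# Hypothetical: builds a torchtune-style model with MultiHeadAttention layers.
model = build_model()

# Swap torchtune MultiHeadAttention for the inference-friendly variant first,
# so the SDPA children it installs can then be replaced with SDPACustom.
model = replace_mha_with_inference_mha(model)

# The inner import of executorch.extension.llm.custom_ops.custom_ops runs for
# its side effect of registering the custom SDPA operators that SDPACustom
# dispatches to at runtime.
model = replace_sdpa_with_custom_op(model)
model.eval()
```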