3 changes: 3 additions & 0 deletions optimum/executorch/__init__.py
@@ -24,6 +24,7 @@
         "ExecuTorchModelForMaskedLM",
         "ExecuTorchModelForSeq2SeqLM",
         "ExecuTorchModelForSpeechSeq2Seq",
+        "ExecuTorchModelForMultiModalToText",
     ],
 }

@@ -34,6 +35,8 @@
         ExecuTorchModelForMaskedLM,
         ExecuTorchModelForSeq2SeqLM,
         ExecuTorchModelForSpeechSeq2Seq,
+        ExecuTorchModelForImageTextToTextCausalLM,
+        ExecuTorchModelForMultiModalToText,
     )
 else:
     import sys
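Note: both hunks extend what looks like the standard Hugging Face lazy-module pattern, where _import_structure drives a _LazyModule at runtime and the TYPE_CHECKING branch exists only for static analysis. ExecuTorchModelForImageTextToTextCausalLM is imported under TYPE_CHECKING but is not added to _import_structure above, so under the usual pattern it would resolve for type checkers but not through the lazy module at runtime; worth double-checking. A minimal sketch of the assumed pattern (the "modeling" key and the _LazyModule wiring are assumptions, not taken from this diff):

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule

_import_structure = {
    "modeling": [
        "ExecuTorchModelForMultiModalToText",  # new runtime-visible export from this PR
    ],
}

if TYPE_CHECKING:
    # Static-analysis-only imports; at runtime the lazy module below serves these attributes.
    from .modeling import (
        ExecuTorchModelForImageTextToTextCausalLM,
        ExecuTorchModelForMultiModalToText,
    )
else:
    import sys

    # Replace this module with a lazy proxy that imports submodules on first attribute access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)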
28 changes: 18 additions & 10 deletions optimum/executorch/attentions/custom_sdpa.py
@@ -45,7 +45,7 @@ def custom_sdpa_with_start_pos_forward(

     # Ignore the causal flag from kwargs but use the one in module
     kwargs.pop("is_causal", None)
-    assert module.is_causal, "Current variant supports only causal attention"
+    # assert module.is_causal, "Current variant supports only causal attention"
Collaborator: I don't know if this supports the non-causal case.

     is_causal = module.is_causal
     if kwargs.get("is_sliding", False):
@@ -56,13 +56,16 @@ def custom_sdpa_with_start_pos_forward(
         start_pos = 0
     else:
         attn_mask = None
-        # Calculate the input pos from attention mask.
-        # Branch out for float vs bool mask
-        # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
-        attention_mask = attention_mask.reshape(-1, max_seq_len)
-        first_row_mask = attention_mask[0, :]
-        # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
-        start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1
+        if is_causal:
+            # Calculate the input pos from attention mask.
+            # Branch out for float vs bool mask
+            # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
+            attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])
+            first_row_mask = attention_mask[0, :]
+            # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
+            start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1
+        else:
+            start_pos = 0
 
     output = torch.ops.llama.custom_sdpa(
         query,
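Note: in the new causal branch, start_pos is still recovered from the additive attention mask rather than passed in explicitly: the first row of the reshaped mask is 0.0 for visible positions and -inf for masked ones, so argmin over the long-cast row lands on the first masked slot and the current position is one before that. The reshape now uses attention_mask.shape[-1] instead of the cached max_seq_len, and the non-causal path skips the computation entirely with start_pos = 0. A standalone sketch of the argmin trick with toy values (not taken from the PR):

import torch

# Additive mask row for a cache of length 8 while decoding the token at position 3:
# positions 0..3 are visible (0.0), positions 4..7 are masked (-inf).
first_row_mask = torch.tensor([0.0, 0.0, 0.0, 0.0] + [float("-inf")] * 4)

# Casting -inf to int64 produces a very large negative value, so argmin returns the index
# of the first masked slot (4); the current position is the slot just before it.
start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1
assert start_pos == 3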
@@ -81,14 +84,19 @@ def get_custom_sdpa_for_ring_kv_cache(
     exportable_module: torch.nn.Module,
 ) -> Callable:
     # lazy importing to avoid version dependent class definition
-    from executorch import version
+    # try:
+    #     from executorch import __version__ as version
+    # except ImportError:
+    #     # Fallback if version is not available
+    #     version = None
 
     try:
         from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
             CustomRingKVCache,
         )
     except ImportError:
-        raise ImportError(f"CustomRingKVCache not available in version {version.__version__} of ExecuTorch.")
+        # raise ImportError(f"CustomRingKVCache not available in version {version.__version__} of ExecuTorch.")
+        print()
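Note: the removed import (from executorch import version) appears to have been needed only for the old error message; with the raise commented out, the except branch now ends in a bare print(), which silently swallows the ImportError even though CustomRingKVCache is presumably still referenced further down in this factory, so the failure would only resurface later as a NameError. If the intent is to degrade gracefully instead of failing hard, a logging-based guard along these lines would keep the failure visible (a reviewer sketch, not something in the PR):

import logging

logger = logging.getLogger(__name__)

try:
    from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
        CustomRingKVCache,
    )
except ImportError:
    # Sentinel so callers can detect that ring KV cache support is unavailable.
    CustomRingKVCache = None
    logger.warning(
        "CustomRingKVCache is not available in this ExecuTorch install; "
        "ring KV cache support will be disabled."
    )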

     def _custom_sdpa_for_ring_kv_cache(
         module: torch.nn.Module,