From 1273002ca491f8cef8129135db6422fab2f24f8e Mon Sep 17 00:00:00 2001
From: Martin Yuan
Date: Fri, 31 Jan 2025 17:58:34 -0800
Subject: [PATCH] Single location to update optional args for all attentions

Summary:
Incremental improvement to the UX. When users need to add an optional
argument for a new attention, there is now a single centralized location,
`ForwardOptions`, to update.

Differential Revision: D68988021
---
 examples/models/llama/attention.py | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index c6f01dbadce..ec55f2f1ee0 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple, Type, TypedDict
 
 import torch
 import torch.nn as nn
@@ -8,6 +8,15 @@
 from executorch.examples.models.llama.rope import Rope
 
 
+class ForwardOptions(TypedDict, total=False):
+    """Optional parameters for `Attention.forward` (compatible with Python 3.10+)."""
+
+    mask: Optional[torch.Tensor]
+    input_pos: Optional[torch.Tensor]
+    in_cache_state: Optional[Any]
+    out_cache_state: Optional[Any]
+
+
 class Attention(nn.Module, ABC):
     """Abstract base class for attention mechanisms with unified interface."""
 
@@ -17,19 +26,14 @@ def forward(
         x: torch.Tensor,
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        input_pos: Optional[torch.Tensor] = None,
-        in_cache_state: Optional[Any] = None,
-        out_cache_state: Optional[Any] = None,
+        **kwargs: ForwardOptions,
     ) -> Tuple[torch.Tensor, Optional[Any]]:
         """Forward pass for attention mechanism.
 
         Args:
             x: Input tensor of shape (batch_size, seq_len, dim)
             freqs_cos, freqs_sin: Rotary position embedding frequencies
-            mask: Optional attention mask
-            input_pos: Positions for KV cache updates
-            in_cache_state/out_cache_state: Cache states
+            **kwargs: Grouped optional args; see `ForwardOptions`
 
         Returns:
             Tuple of (output tensor, updated cache state)
@@ -209,11 +213,9 @@ def forward(
         x: torch.Tensor,
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        input_pos: Optional[torch.Tensor] = None,
-        in_cache_state: Optional[Any] = None,
-        out_cache_state: Optional[Any] = None,
+        **kwargs: ForwardOptions,
     ) -> Tuple[torch.Tensor, Optional[Any]]:
+        input_pos = kwargs.get("input_pos")
         bsz, seqlen, _ = x.shape
 
         # QKV
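
For reviewers, a minimal runnable sketch of the pattern this patch introduces: optional arguments travel as `**kwargs` keyed by `ForwardOptions`, and implementations fetch them with `kwargs.get(...)`. The `ToyAttention` class, the tensor shapes, and the trivial "attention" body below are hypothetical stand-ins for illustration, not code from this diff:

```python
from typing import Any, Optional, Tuple, TypedDict

import torch
import torch.nn as nn


class ForwardOptions(TypedDict, total=False):
    # Same shape as the TypedDict added in the patch: every key is optional.
    mask: Optional[torch.Tensor]
    input_pos: Optional[torch.Tensor]
    in_cache_state: Optional[Any]
    out_cache_state: Optional[Any]


class ToyAttention(nn.Module):
    """Hypothetical implementation showing only the kwargs access pattern."""

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        **kwargs: ForwardOptions,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # Absent keys come back as None, matching the old `= None` defaults.
        mask = kwargs.get("mask")
        input_pos = kwargs.get("input_pos")
        # Real attention math elided; pass the input through unchanged.
        out = x if mask is None else x.masked_fill(mask, 0.0)
        return out, kwargs.get("out_cache_state")


attn = ToyAttention()
x = torch.randn(1, 8, 64)
cos = sin = torch.randn(8, 32)
# Call sites are unchanged: optional args are still passed by keyword.
out, cache = attn(x, cos, sin, input_pos=torch.tensor([0]))
```

With this layout, adding a new optional argument later means adding one key to `ForwardOptions` and one `kwargs.get(...)` at the point of use, rather than editing the signature of every `Attention` subclass.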