11from abc import ABC , abstractmethod
2- from typing import Any , Dict , Optional , Tuple , Type
2+ from typing import Any , Dict , Optional , Tuple , Type , TypedDict
33
44import torch
55import torch .nn as nn
88from executorch .examples .models .llama .rope import Rope
99
1010
11+ class ForwardOptions (TypedDict , total = False ):
12+ """Optional parameters for `Attention.forward` (compative with Python 3.10 and plus)."""
13+
14+ mask : Optional [torch .Tensor ]
15+ input_pos : Optional [torch .Tensor ]
16+ in_cache_state : Optional [Any ]
17+ out_cache_state : Optional [Any ]
18+
19+
1120class Attention (nn .Module , ABC ):
1221 """Abstract base class for attention mechanisms with unified interface."""
1322
@@ -17,19 +26,14 @@ def forward(
1726 x : torch .Tensor ,
1827 freqs_cos : torch .Tensor ,
1928 freqs_sin : torch .Tensor ,
20- mask : Optional [torch .Tensor ] = None ,
21- input_pos : Optional [torch .Tensor ] = None ,
22- in_cache_state : Optional [Any ] = None ,
23- out_cache_state : Optional [Any ] = None ,
29+ ** kwargs : ForwardOptions ,
2430 ) -> Tuple [torch .Tensor , Optional [Any ]]:
2531 """Forward pass for attention mechanism.
2632
2733 Args:
2834 x: Input tensor of shape (batch_size, seq_len, dim)
2935 freqs_cos, freqs_sin: Rotary position embedding frequencies
30- mask: Optional attention mask
31- input_pos: Positions for KV cache updates
32- in_cache_state/out_cache_state: Cache states
36+ ForwardOptions: grouped optional args
3337
3438 Returns:
3539 Tuple of (output tensor, updated cache state)
@@ -209,11 +213,9 @@ def forward(
209213 x : torch .Tensor ,
210214 freqs_cos : torch .Tensor ,
211215 freqs_sin : torch .Tensor ,
212- mask : Optional [torch .Tensor ] = None ,
213- input_pos : Optional [torch .Tensor ] = None ,
214- in_cache_state : Optional [Any ] = None ,
215- out_cache_state : Optional [Any ] = None ,
216+ ** kwargs : ForwardOptions ,
216217 ) -> Tuple [torch .Tensor , Optional [Any ]]:
218+ input_pos = kwargs .get ("input_pos" )
217219 bsz , seqlen , _ = x .shape
218220
219221 # QKV
0 commit comments