Enable per-call override of causal flag and window
- Allows passing the window size as an argument and forwards it instead of always using the module default (see the sketch below).
- Respects a causal flag provided via kwargs, falling back to the module value when absent.
- Clarifies the attention mask/bias shapes to include 2D masks and per-head forms.
- Improves configurability and fixes previously ignored overrides.
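The commit itself is not shown in full here, so the following is a minimal sketch of the override pattern the description implies: a module whose `causal` and `sliding_window` defaults can be superseded per call via kwargs. The class name `WindowedAttention`, the kwarg names, and the use of `F.scaled_dot_product_attention` are assumptions for illustration, not this repo's actual API.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class WindowedAttention(nn.Module):
    # Hypothetical module illustrating the per-call override pattern
    # described above; the real module in this repo may differ.
    def __init__(self, causal: bool = True, sliding_window: Optional[int] = None):
        super().__init__()
        self.causal = causal
        self.sliding_window = sliding_window

    def forward(self, query, key, value, attention_mask=None, **kwargs):
        # Respect per-call overrides, falling back to the module values.
        causal = kwargs.get("causal", self.causal)
        window = kwargs.get("sliding_window", self.sliding_window)

        q_len, k_len = query.shape[-2], key.shape[-2]
        mask = attention_mask
        if window is not None or (causal and mask is not None):
            # Fold the window and/or causal constraints into an explicit
            # boolean mask, since SDPA rejects an attn_mask combined with
            # is_causal=True.
            q_idx = torch.arange(q_len, device=query.device).unsqueeze(-1)
            k_idx = torch.arange(k_len, device=query.device)
            allowed = torch.ones(q_len, k_len, dtype=torch.bool, device=query.device)
            if window is not None:
                # Keep keys within `window` positions behind each query;
                # future positions are handled by the causal flag.
                allowed &= (q_idx - k_idx) < window
            if causal:
                allowed &= k_idx <= q_idx
            mask = allowed if mask is None else mask & allowed
            causal = False
        # GQA head expansion (num_kv_heads < num_heads) is omitted for brevity.
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=mask, is_causal=causal
        )
```

Under this sketch, a call like `attn(q, k, v, causal=False, sliding_window=128)` takes effect for that call only, which is the behavior the commit restores for previously ignored overrides.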
The updated docstring hunk:

```diff
     query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
     key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
     value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-    attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len).
-    attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len), if attention_mask is None, also supports (batch_size, {num_heads|num_kv_heads|1}, key_len).
+    attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+        (batch_size, seq_len) or (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+    attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape
```
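To make the newly documented shapes concrete, here is a hedged sketch of how they could be normalized to broadcast against `(batch_size, num_heads, query_len, key_len)` attention scores. The helper name `normalize_mask_and_bias` and its behavior are assumptions drawn only from the shapes listed in the docstring, not from this repo's code.

```python
import torch


def normalize_mask_and_bias(attention_mask, attention_bias):
    # Hypothetical helper: expands the accepted mask/bias shapes from the
    # docstring above into 4D tensors broadcastable against attention scores.
    mask = attention_mask
    if mask is not None and mask.dim() == 2:
        # (batch_size, seq_len) padding mask -> (batch_size, 1, 1, key_len),
        # broadcasting over heads and query positions.
        mask = mask[:, None, None, :]
    bias = attention_bias
    if bias is not None and bias.dim() == 3:
        # (batch_size, {num_heads|num_kv_heads|1}, key_len) -> insert a
        # query_len axis of size 1; the docstring only allows this form
        # when attention_mask is None.
        assert attention_mask is None, "3D bias requires attention_mask=None"
        bias = bias[:, :, None, :]
    return mask, bias
```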