@@ -30,10 +30,10 @@ def flash_dynamic_mask_attention_forward(
         query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
         key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
         value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
-            (batch_size, seq_len) or (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
         attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape
-            (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
         scaling (Optional[float]): The scaling factor for the attention scores.
         window_size (Optional[int]): The size of the window to keep.
         softcap (Optional[float]): The softcap value for the attention scores.
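
For context, a minimal sketch of the broadcastable mask/bias shapes the updated docstring allows. All sizes are illustrative assumptions; flash_dynamic_mask_attention_forward itself is the patched function and is not redefined here.

import torch

# Illustrative sizes only (assumptions, not values from the patch).
batch_size, num_heads, num_kv_heads = 2, 8, 2
query_len, key_len, head_dim = 4, 16, 64

query = torch.randn(batch_size, num_heads, query_len, head_dim)
key = torch.randn(batch_size, num_kv_heads, key_len, head_dim)
value = torch.randn(batch_size, num_kv_heads, key_len, head_dim)

# Under the updated convention, any of the four mask/bias dimensions may be 1
# and broadcast; here one mask/bias row is shared across the batch and all queries.
attention_mask = torch.ones(1, num_kv_heads, 1, key_len, dtype=torch.bool)
attention_bias = torch.zeros(1, num_kv_heads, 1, key_len)

The old convention pinned the batch and key dimensions to their full sizes; the new `{dim|1}` notation makes every dimension independently broadcastable.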