
Commit a0d6ee5

Clarifies attention bias parameter type in docstring
Specifies that the attention_bias parameter expects a float tensor, improving API documentation clarity and helping developers understand the expected data type.
Parent: ab4513a

File tree

1 file changed: +1, -1 lines

flash_dmattn/integrations/flash_dynamic_mask_attention.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def flash_dynamic_mask_attention_forward(
         key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
         value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
         attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_kv_heads, query_len, key_len).
-        attention_bias (Optional[torch.Tensor]): The attention bias tensor of shape (batch_size, num_kv_heads, query_len, key_len).
+        attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_kv_heads, query_len, key_len).
         scaling (Optional[float]): The scaling factor for the attention scores.
         softcap (Optional[float]): The softcap value for the attention scores.
         **kwargs: Additional keyword arguments.
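To illustrate the dtype distinction the updated docstring draws, a boolean attention_mask versus a float attention_bias, here is a minimal sketch of tensors matching the shapes documented above. The dimension values and specific dtypes are illustrative assumptions, not part of the commit.

```python
import torch

# Illustrative dimensions only; any consistent values work.
batch_size, num_kv_heads = 2, 4
query_len, key_len, head_dim = 16, 128, 64

key = torch.randn(batch_size, num_kv_heads, key_len, head_dim)
value = torch.randn(batch_size, num_kv_heads, key_len, head_dim)

# attention_mask is documented as a boolean tensor: it gates which
# (query, key) positions may attend.
attention_mask = torch.ones(
    batch_size, num_kv_heads, query_len, key_len, dtype=torch.bool
)

# attention_bias is documented (per this commit) as a float tensor: it
# typically holds additive values applied to the attention scores, so a
# boolean dtype would be incorrect here.
attention_bias = torch.zeros(
    batch_size, num_kv_heads, query_len, key_len, dtype=torch.float32
)

assert attention_mask.dtype == torch.bool
assert attention_bias.is_floating_point()
```

The one-word docstring change makes this expectation explicit, so callers are less likely to pass a boolean mask-style tensor where a float bias is expected.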
