@@ -65,9 +65,9 @@ def create_mask(
 
     Args:
         attention_bias (torch.Tensor): The attention bias tensor of shape
-            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, { key_len|1} ).
-        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
-            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, { key_len|1} ).
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
         batch_size (int): The batch size.
         query_len (int): The sequence length of the query.
         key_len (int): The sequence length of the key.
@@ -76,7 +76,7 @@ def create_mask(
 
     Returns:
         attention (Tensor): The attention mask tensor of shape
-            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, { key_len|1} ).
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
 
     # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
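
To make the reshape step in that last comment concrete, here is a minimal standalone sketch (not the `create_mask` implementation from this PR, whose full signature is not shown) of how a 2D boolean `attention_mask` of shape (batch_size, seq_len) can be expanded to the documented 4D shape and combined with an `attention_bias` by broadcasting. The tensor values and the masked-fill combination below are illustrative assumptions.

```python
import torch

batch_size, num_heads, query_len, key_len = 2, 4, 5, 5

# Boolean padding mask of shape (batch_size, seq_len); True marks valid key positions.
attention_mask = torch.ones(batch_size, key_len, dtype=torch.bool)
attention_mask[1, -2:] = False  # pretend the last two keys of the second sequence are padding

# Attention bias of shape ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
attention_bias = torch.zeros(batch_size, num_heads, query_len, key_len)

# Reshape (batch_size, seq_len) -> (batch_size, 1, 1, key_len) so it broadcasts
# against the (batch, heads, query, key) bias.
mask_4d = attention_mask[:, None, None, :]

# Masked positions get a very negative value so softmax assigns them ~0 weight.
combined = attention_bias.masked_fill(~mask_4d, torch.finfo(attention_bias.dtype).min)
print(combined.shape)  # torch.Size([2, 4, 5, 5])
```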