
Commit d209826

Refactor attention bias calculation for improved clarity and correctness

1 parent 337f552

File tree

1 file changed (+2, -1 lines)

1 file changed

+2
-1
lines changed

flash_sparse_attn/modules/dynamic_mask_attention.py

Lines changed: 2 additions & 1 deletion
@@ -65,11 +65,12 @@ def forward(
 
         gate_states = self.g_proj(query_states)
         delta_states = self.d_proj(value_states)
-        attn_bias = (torch.sigmoid(gate_states) * delta_states).transpose(-1, -2).unsqueeze(-2)
+        attn_bias = torch.sigmoid(gate_states) * delta_states
 
         query_states = query_states.view(bsz, seq_len, -1, self.head_dim)
         key_states = key_states.view(bsz, key_len, -1, self.head_dim)
         value_states = value_states.view(bsz, key_len, -1, self.head_dim)
+        attn_bias = attn_bias.transpose(-1, -2).unsqueeze(-2)
 
         attn_mask = create_mask(
             attention_bias=attn_bias,
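
The refactor computes the raw gated bias first and defers the transpose/unsqueeze until just before the mask is built, so the reshape that makes the bias broadcastable over attention scores sits next to the code that consumes it. For intuition, here is a minimal, self-contained sketch of the shape flow; the projection layers and dimensions (g_proj/d_proj emitting one value per head, seq_len equal to key_len) are assumptions for illustration, not the repository's actual definitions:

import torch

# Hypothetical dimensions; names mirror the diff.
bsz, seq_len, key_len = 2, 16, 16
num_heads, head_dim = 4, 8
hidden = num_heads * head_dim

# Assumed stand-ins for self.g_proj / self.d_proj: one scalar per head.
g_proj = torch.nn.Linear(hidden, num_heads)
d_proj = torch.nn.Linear(hidden, num_heads)

query_states = torch.randn(bsz, seq_len, hidden)
value_states = torch.randn(bsz, key_len, hidden)

gate_states = g_proj(query_states)    # [bsz, seq_len, num_heads]
delta_states = d_proj(value_states)   # [bsz, key_len, num_heads]

# Step 1 (new line 68): raw per-position, per-head bias.
attn_bias = torch.sigmoid(gate_states) * delta_states  # [bsz, key_len, num_heads]

# Step 2 (new line 73): reshape just before mask creation so the bias
# broadcasts over attention scores of shape [bsz, num_heads, seq_len, key_len].
attn_bias = attn_bias.transpose(-1, -2).unsqueeze(-2)
assert attn_bias.shape == (bsz, num_heads, 1, key_len)

Splitting the computation this way leaves the intermediate [bsz, key_len, num_heads] tensor available between the two steps and keeps the broadcast-oriented layout change adjacent to create_mask, which is the only consumer that needs it.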
