Commit 7c4b102

Add unsqueeze to attention bias computation in DogeAttention
1 parent 3d91162 commit 7c4b102

File tree

1 file changed: +1, −1

1 file changed

+1
-1
lines changed

examples/modeling/modeling_doge.py

Lines changed: 1 addition & 1 deletion
@@ -217,7 +217,7 @@ def forward(
         dt_states = self.dt_proj(
             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
         )
-        attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
+        attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)

         attention_interface: Callable = flash_dynamic_mask_attention_forward

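Why the unsqueeze matters: after the transpose, attn_bias is three-dimensional, while attention scores are typically (batch, num_heads, q_len, kv_len), so the bias cannot broadcast against them; unsqueeze(-2) inserts a singleton query dimension. The sketch below reproduces this shape arithmetic with stand-in tensors; the concrete shapes (A as a per-head parameter, dt_states as the dt_proj output) are assumptions inferred from the diff, not taken from the repository.

import torch
import torch.nn.functional as F

batch, num_heads, q_len, kv_len = 2, 4, 8, 8

# Stand-ins for the quantities in the diff (hypothetical shapes).
A = torch.randn(num_heads)                         # assumed per-head parameter
dt_states = torch.randn(batch, kv_len, num_heads)  # assumed dt_proj output

# Before the fix: shape (batch, num_heads, kv_len); adding this to a
# (batch, num_heads, q_len, kv_len) score tensor fails to broadcast.
bias_old = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)

# After the fix: unsqueeze(-2) yields (batch, num_heads, 1, kv_len),
# which broadcasts cleanly across the query-length axis.
bias_new = bias_old.unsqueeze(-2)

scores = torch.randn(batch, num_heads, q_len, kv_len)
print((scores + bias_new).shape)  # torch.Size([2, 4, 8, 8])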
0 commit comments