
Commit f5ebc35

Fixes attn-bias order; passes window size
Corrects the parenthesization so the scaling by self.A is applied before the transpose when building the attention bias, matching the intended formula and giving the bias the correct broadcasting shape. Also passes the window size into the attention kernel so windowed masking is actually applied.
1 parent 071ab90 commit f5ebc35
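
The change in the first hunk is pure operator precedence, but it determines which axis self.A is broadcast against: in the old expression only F.softplus(dt_states) is transposed before the multiply, so A lines up with the sequence dimension rather than the head dimension. A minimal sketch of the difference, assuming illustrative shapes of [batch, seq_len, num_heads] for dt_states and [num_heads] for A (the actual shapes in modeling_doge.py may differ):

import torch
import torch.nn.functional as F

# Hypothetical shapes chosen for illustration only.
batch, seq_len, num_heads = 2, 5, 4
dt_states = torch.randn(batch, seq_len, num_heads)
A = torch.randn(num_heads)

# Fixed expression: A scales the head dimension first, then the result is
# transposed to [batch, num_heads, seq_len], ready for unsqueeze(-2).
fixed = (A * F.softplus(dt_states)).transpose(-1, -2)
print(fixed.shape)  # torch.Size([2, 4, 5])

# Old expression: only F.softplus(dt_states) is transposed before the multiply,
# so A ([num_heads]) is broadcast against the new last dim (seq_len) instead.
try:
    broken = A * F.softplus(dt_states).transpose(-1, -2)
except RuntimeError as err:
    print("broadcast mismatch:", err)  # raised whenever seq_len != num_heads

When seq_len happens to equal num_heads, the old expression does not error at all; it silently scales along the wrong axis, which is the harder failure mode to notice.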

File tree

1 file changed: +2 −1 lines changed

examples/modeling/modeling_doge.py

Lines changed: 2 additions & 1 deletion
@@ -218,7 +218,7 @@ def forward(
             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
         )
         # original formula is exp(A * softplus(delta V)), but for numerical stability, it is changed to A * softplus(delta V)
-        attn_bias = self.A * F.softplus(dt_states).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)
+        attn_bias = (self.A * F.softplus(dt_states)).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)

         attention_interface: Callable = flash_dynamic_mask_attention_forward

@@ -230,6 +230,7 @@ def forward(
             attention_mask=attention_mask,
             attention_bias=attn_bias,
             scale=self.scaling,
+            window_size=self.window_size,
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
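
The in-diff comment notes that the original formula exp(A * softplus(delta V)) was changed to A * softplus(delta V) for numerical stability, i.e. the bias is kept in log space. That is equivalent as long as the kernel consumes attention_bias additively on the pre-softmax scores; this commit does not show how flash_dynamic_mask_attention_forward uses it, so treat that as an assumption. A toy check of the equivalence:

import torch

# Toy check: adding a log-space bias to the logits before softmax equals
# multiplying the unnormalized weights by exp(bias) and renormalizing,
# without ever computing exp() of the bias explicitly.
torch.manual_seed(0)
scores = torch.randn(1, 4, 1, 5)   # [batch, heads, q_len, kv_len] -- illustrative layout
bias = torch.randn(1, 4, 1, 5)     # layout attn_bias would have after unsqueeze(-2)

additive = torch.softmax(scores + bias, dim=-1)

gated = torch.exp(scores) * torch.exp(bias)
multiplicative = gated / gated.sum(dim=-1, keepdim=True)

print(torch.allclose(additive, multiplicative, atol=1e-6))  # True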
