
Commit 9cc0ca6

Merge pull request #177 from SmallDoges/update-example
Refactor attention mask and bias handling for efficiency
2 parents: ad7a3ab + 80b25f7 · commit 9cc0ca6

2 files changed: +10 -13 lines


examples/modeling/modeling_doge.py

Lines changed: 2 additions & 6 deletions
@@ -241,16 +241,12 @@ def forward(
         dt_states = self.dt_proj(
             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
         )
-        dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-        attn_bias = dt_states[:, :, None, :].expand(
-            -1, -1, hidden_states.shape[1], -1
-        ).to(hidden_states.dtype)  # [batch_size, num_heads, query_len, key_len]
+        attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
 
         attention_interface: Callable = eager_attention_forward
         if flash_dynamic_mask_attention_forward is not None:
             attention_interface = flash_dynamic_mask_attention_forward
 
-        attention_mask = attention_mask.expand(-1, attn_bias.shape[1], -1, -1) if attention_mask is not None else None  # attention_mask: batch, num_kv_heads, query_len, key_len
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
@@ -414,7 +410,7 @@ def _init_weights(self, module):
         super()._init_weights(module)
         if isinstance(module, DogeAttention):
             if hasattr(module, "A"):
-                module.A.data.zero_()
+                module.A.data.normal_(mean=0.0, std=self.config.initializer_range)
         elif isinstance(module, DogeCDMoE):
             if hasattr(module, "router_gate"):
                 module.router_gate.weight.data.zero_()
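
The change in forward drops the explicit .expand() over the query dimension: attn_bias is now kept as a per-key tensor of shape [batch_size, num_heads, key_len], and the attention backend is left to broadcast it across queries. A minimal sketch of that broadcasting assumption (shapes and names are illustrative, not part of the commit):

import torch

batch, num_heads, query_len, key_len = 2, 4, 8, 8
scores = torch.randn(batch, num_heads, query_len, key_len)
bias = torch.randn(batch, num_heads, key_len)  # per-key bias, no query dimension

# Previous behavior: materialize the query dimension explicitly.
expanded = bias[:, :, None, :].expand(-1, -1, query_len, -1)

# Refactored behavior: an unsqueezed view broadcasts against the scores
# without allocating a full [batch, num_heads, query_len, key_len] tensor.
assert torch.equal(scores + expanded, scores + bias[:, :, None, :])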

flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 8 additions & 7 deletions
@@ -93,13 +93,14 @@ def _flash_dynamic_mask_attention_forward(
             min_dtype
         )
 
-    if keep_window_size is not None:
-        if key_length > keep_window_size:
-            topk_values, topk_indices = torch.topk(
-                attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
-            )
-            attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
-            attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
+    if keep_window_size is not None and key_length > keep_window_size:
+        topk_values, topk_indices = torch.topk(
+            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
+        )
+        attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
+        attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
+    else:
+        attention_mask = None
 
     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
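
The consolidated branch builds the keep-window mask only when the key length actually exceeds keep_window_size, and otherwise passes attention_mask=None to the flash kernel. A standalone sketch of that selection logic (function name and shapes are illustrative, not the library's API):

import torch

def build_keep_window_mask(attention_bias, keep_window_size, min_dtype):
    # attention_bias: [batch, num_heads, query_len, key_len]
    key_length = attention_bias.shape[-1]
    if keep_window_size is not None and key_length > keep_window_size:
        topk_values, topk_indices = torch.topk(
            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
        )
        mask = torch.zeros_like(attention_bias, dtype=torch.bool)
        # Keep the highest-bias keys per query, skipping positions already at min_dtype.
        return mask.scatter(-1, topk_indices, topk_values != min_dtype)
    return None  # the window already covers every key, so no mask is needed

bias = torch.randn(1, 2, 4, 16)
mask = build_keep_window_mask(bias, keep_window_size=8, min_dtype=torch.finfo(bias.dtype).min)
assert mask.shape == bias.shape and mask.sum(dim=-1).eq(8).all()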
