Skip to content

Commit 433e3ba

Browse files
authored
add unique_count for dist_attn (#85)
1 parent 6e7e269 commit 433e3ba

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

magi_attention/functional/dist_attn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ def attn_fwd_partial(
324324
**attn_arg.to_ffa_args(is_bwd=False),
325325
merge_q_ranges=None,
326326
qk_map=None,
327+
fwd_unique_count=None,
327328
softmax_scale=q.shape[-1] ** -0.5,
328329
deterministic=deterministic,
329330
softcap=0.0,
@@ -394,6 +395,7 @@ def attn_bwd_partial(
394395
**attn_arg.to_ffa_args(is_bwd=True),
395396
merge_k_ranges=None,
396397
bwd_kq_map=None,
398+
bwd_unique_count=None,
397399
softmax_scale=q.shape[-1] ** -0.5,
398400
deterministic=deterministic,
399401
softcap=0.0,

magi_attention/functional/flex_flash_attn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _flex_flash_attn_forward(
130130
attn_type_map,
131131
merge_q_ranges,
132132
qk_map,
133-
unique_count,
133+
fwd_unique_count,
134134
softmax_scale,
135135
softcap,
136136
disable_fwd_atomic_reduction,
@@ -157,7 +157,7 @@ def _flex_flash_attn_forward(
157157
attn_type_map,
158158
merge_q_ranges,
159159
qk_map,
160-
unique_count,
160+
fwd_unique_count,
161161
softmax_scale,
162162
softcap,
163163
disable_fwd_atomic_reduction,

0 commit comments

Comments
 (0)