File tree Expand file tree Collapse file tree 2 files changed +4
-2
lines changed
magi_attention/functional Expand file tree Collapse file tree 2 files changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -324,6 +324,7 @@ def attn_fwd_partial(
324324 ** attn_arg .to_ffa_args (is_bwd = False ),
325325 merge_q_ranges = None ,
326326 qk_map = None ,
327+ fwd_unique_count = None ,
327328 softmax_scale = q .shape [- 1 ] ** - 0.5 ,
328329 deterministic = deterministic ,
329330 softcap = 0.0 ,
@@ -394,6 +395,7 @@ def attn_bwd_partial(
394395 ** attn_arg .to_ffa_args (is_bwd = True ),
395396 merge_k_ranges = None ,
396397 bwd_kq_map = None ,
398+ bwd_unique_count = None ,
397399 softmax_scale = q .shape [- 1 ] ** - 0.5 ,
398400 deterministic = deterministic ,
399401 softcap = 0.0 ,
Original file line number Diff line number Diff line change @@ -130,7 +130,7 @@ def _flex_flash_attn_forward(
130130 attn_type_map ,
131131 merge_q_ranges ,
132132 qk_map ,
133- unique_count ,
133+ fwd_unique_count ,
134134 softmax_scale ,
135135 softcap ,
136136 disable_fwd_atomic_reduction ,
@@ -157,7 +157,7 @@ def _flex_flash_attn_forward(
157157 attn_type_map ,
158158 merge_q_ranges ,
159159 qk_map ,
160- unique_count ,
160+ fwd_unique_count ,
161161 softmax_scale ,
162162 softcap ,
163163 disable_fwd_atomic_reduction ,
You can’t perform that action at this time.
0 commit comments