|
1192 | 1192 | data_type : q |
1193 | 1193 |
|
1194 | 1194 | - backward_op : flash_attn_v3_varlen_grad |
1195 | | - forward : flash_attn_v3_varlen(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, int max_seqlen_q, int max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin) -> Tensor(out), Tensor(softmax_lse) |
1196 | | - args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor out_grad, float softmax_scale, int max_seqlen_q, int max_seqlen_k, bool causal, int window_size_left, int window_size_right, float softcap, int sm_margin) |
| 1195 | + forward : flash_attn_v3_varlen(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, Scalar max_seqlen_q, Scalar max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin) -> Tensor(out), Tensor(softmax_lse) |
| 1196 | + args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor out_grad, float softmax_scale, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, int window_size_left, int window_size_right, float softcap, int sm_margin) |
1197 | 1197 | optional : seqused_q, seqused_k |
1198 | 1198 | output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) |
1199 | 1199 | infer_meta : |
|
0 commit comments