Commit fecfacd

[PIR] Promote flash_attn_v3_varlen arguments max_seqlen_q/max_seqlen_k (positions 12 and 13) from int to Scalar to support pir.Value
1 parent: feaccbb

6 files changed: +33 additions, -19 deletions
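In practice the promotion means max_seqlen_q and max_seqlen_k no longer have to be Python ints fixed at trace time: a 0-D integer Tensor, which lowers to a pir.Value under the PIR static graph, is accepted too, and the kernels recover a plain integer via Scalar::to<int64_t>(). A minimal sketch of the two accepted forms (the full 22-argument op invocation is elided):

```python
import paddle

# Old requirement: max_seqlen_* had to be compile-time Python ints.
max_seqlen_q = 128

# Now a 0-D integer Tensor works as well; under paddle.jit.to_static it
# traces to a pir.Value, which the Scalar attribute can bind to.
max_seqlen_q_dyn = paddle.to_tensor(128, dtype='int64')
```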

paddle/phi/kernels/gpu/flash_attn_v3_grad_kernel.cu

Lines changed: 6 additions & 4 deletions
@@ -707,8 +707,8 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
                                  const paddle::optional<DenseTensor> &seqused_k,
                                  const DenseTensor &out_grad,
                                  float const softmax_scale,
-                                 int const max_seqlen_q,
-                                 int const max_seqlen_k,
+                                 const Scalar &max_seqlen_q,
+                                 const Scalar &max_seqlen_k,
                                  bool const causal,
                                  int const window_size_left,
                                  int const window_size_right,
@@ -756,6 +756,8 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
   DenseTensor dq_accum;
   DenseTensor dk_accum;
   DenseTensor dv_accum;
+  const int64_t max_seqlen_q_ = max_seqlen_q.to<int64_t>();
+  const int64_t max_seqlen_k_ = max_seqlen_k.to<int64_t>();
   FlashAttnV3GradBaseKernel<T, Context>(dev_ctx,
                                         out_grad,
                                         q,
@@ -770,8 +772,8 @@ void FlashAttnV3VarlenGradKernel(const Context &dev_ctx,
                                         cu_seqlens_k,
                                         seqused_q,
                                         seqused_k,
-                                        max_seqlen_q,
-                                        max_seqlen_k,
+                                        max_seqlen_q_,
+                                        max_seqlen_k_,
                                         softmax_scale,
                                         causal,
                                         window_size_left,

paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu

Lines changed: 7 additions & 5 deletions
@@ -1082,8 +1082,8 @@ void FlashAttnV3VarlenKernel(const Context &dev_ctx,
                              const paddle::optional<DenseTensor> &q_descale,
                              const paddle::optional<DenseTensor> &k_descale,
                              const paddle::optional<DenseTensor> &v_descale,
-                             const int max_seqlen_q,
-                             const int max_seqlen_k,
+                             const Scalar &max_seqlen_q,
+                             const Scalar &max_seqlen_k,
                              const float softmax_scale,
                              const bool causal,
                              const int window_size_left,
@@ -1150,6 +1150,8 @@ void FlashAttnV3VarlenKernel(const Context &dev_ctx,

   DenseTensor out_accum;
   DenseTensor softmax_lse_accum;
+  const int64_t max_seqlen_q_ = max_seqlen_q.to<int64_t>();
+  const int64_t max_seqlen_k_ = max_seqlen_k.to<int64_t>();
   FlashAttnV3BaseKernel<T, Context>(dev_ctx,
                                     q,
                                     k,
@@ -1171,9 +1173,9 @@ void FlashAttnV3VarlenKernel(const Context &dev_ctx,
                                     q_descale,
                                     k_descale,
                                     v_descale,
-                                    paddle::none,  // scheduler_metadata
-                                    max_seqlen_q,  // max_seqlen_q_
-                                    max_seqlen_k,  // max_seqlen_k_
+                                    paddle::none,   // scheduler_metadata
+                                    max_seqlen_q_,  // max_seqlen_q_
+                                    max_seqlen_k_,  // max_seqlen_k_
                                     softmax_scale,
                                     causal,
                                     window_size_left,

paddle/phi/ops/yaml/backward.yaml

Lines changed: 2 additions & 2 deletions
@@ -1192,8 +1192,8 @@
     data_type : q

 - backward_op : flash_attn_v3_varlen_grad
-  forward : flash_attn_v3_varlen(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, int max_seqlen_q, int max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin) -> Tensor(out), Tensor(softmax_lse)
-  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor out_grad, float softmax_scale, int max_seqlen_q, int max_seqlen_k, bool causal, int window_size_left, int window_size_right, float softcap, int sm_margin)
+  forward : flash_attn_v3_varlen(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, Scalar max_seqlen_q, Scalar max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin) -> Tensor(out), Tensor(softmax_lse)
+  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor out_grad, float softmax_scale, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, int window_size_left, int window_size_right, float softcap, int sm_margin)
   optional : seqused_q, seqused_k
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :

paddle/phi/ops/yaml/op_compat.yaml

Lines changed: 10 additions & 0 deletions
@@ -1337,6 +1337,16 @@
     data_type : int64_t
     support_tensor : true

+- op : flash_attn_v3_varlen
+  backward : flash_attn_v3_varlen_grad
+  scalar :
+    max_seqlen_q :
+      data_type : int64_t
+      support_tensor : true
+    max_seqlen_k :
+      data_type : int64_t
+      support_tensor : true
+
 - op : flash_attn_varlen_qkvpacked
   backward : flash_attn_varlen_qkvpacked_grad
   scalar :
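The support_tensor : true entries above are what allow a graph-built Tensor to feed these Scalar attributes at PIR trace time. flash_attn_v3_varlen itself needs GPU inputs to demonstrate directly, so here is a hedged sketch of the same mechanism using paddle.slice as a stand-in op, since its ends attribute is already Scalar-like and tensor-capable:

```python
import paddle

@paddle.jit.to_static
def head(x, n):
    # When traced, n is a pir.Value; the op's Scalar-like attribute
    # consumes it directly instead of requiring a Python int.
    return paddle.slice(x, axes=[0], starts=[0], ends=[n])

x = paddle.rand([8, 4])
out = head(x, paddle.to_tensor(4))  # out has shape [4, 4]
```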

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
@@ -2113,7 +2113,7 @@
   backward : flash_attn_v3_grad

 - op : flash_attn_v3_varlen
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, int max_seqlen_q, int max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor seqused_q, Tensor seqused_k, Tensor qv, Tensor q_descale, Tensor k_descale, Tensor v_descale, Scalar max_seqlen_q, Scalar max_seqlen_k, float softmax_scale, bool causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa, int sm_margin)
   output : Tensor(out), Tensor(softmax_lse)
   optional : seqused_q, seqused_k, qv, q_descale, k_descale, v_descale
   infer_meta :

python/paddle/utils/decorator_utils.py

Lines changed: 7 additions & 7 deletions
@@ -30,7 +30,7 @@
 _RetT = TypeVar("_RetT")


-def _is_in_or_scalar_tensor(x):
+def _is_int_or_scalar_tensor(x):
     if isinstance(x, int):
         return True
     if isinstance(x, (paddle.Tensor, paddle.pir.Value)):
@@ -420,8 +420,8 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
                 kwargs["shape_or_dtype"] = kwargs.pop("dtype")
             elif ("size" in kwargs) and ("shape_or_dtype" not in kwargs):
                 kwargs["shape_or_dtype"] = kwargs.pop("size")
-            elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
-                if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
+            elif len(args) >= 2 and _is_int_or_scalar_tensor(args[1]):
+                if all(_is_int_or_scalar_tensor(arg) for arg in args[1:]):
                     kwargs["x"] = args[0]
                     kwargs['shape_or_dtype'] = list(args[1:])
                     args = ()
@@ -552,8 +552,8 @@ def decorator(func: Callable[_InputT, _RetT]) -> Callable[_InputT, _RetT]:
         def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
             if ("input" in kwargs) and ("x" not in kwargs):
                 kwargs["x"] = kwargs.pop("input")
-            elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
-                if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
+            elif len(args) >= 2 and _is_int_or_scalar_tensor(args[1]):
+                if all(_is_int_or_scalar_tensor(arg) for arg in args[1:]):
                     kwargs["x"] = args[0]
                     kwargs['shape'] = list(args[1:])
                     args = ()
@@ -624,8 +624,8 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
                 kwargs["x"] = kwargs.pop("input")
             if ("size" in kwargs) and ("shape" not in kwargs):
                 kwargs["shape"] = kwargs.pop("size")
-            elif len(args) >= 2 and _is_in_or_scalar_tensor(args[1]):
-                if all(_is_in_or_scalar_tensor(arg) for arg in args[1:]):
+            elif len(args) >= 2 and _is_int_or_scalar_tensor(args[1]):
+                if all(_is_int_or_scalar_tensor(arg) for arg in args[1:]):
                     kwargs["x"] = args[0]
                     kwargs['shape'] = list(args[1:])
                     args = ()
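The decorator_utils.py change is a pure rename fixing the _is_in_or_scalar_tensor typo. For context, a self-contained sketch of the renamed helper; the diff shows only the isinstance tests, so the 0-D shape check in the tensor branch is an assumption:

```python
import paddle

def _is_int_or_scalar_tensor(x):
    # Plain Python ints qualify.
    if isinstance(x, int):
        return True
    # Tensors / pir.Values qualify when scalar-shaped (the 0-D check is
    # assumed; the diff truncates the function body at this point).
    if isinstance(x, (paddle.Tensor, paddle.pir.Value)):
        return len(x.shape) == 0
    return False

print(_is_int_or_scalar_tensor(3))                    # True
print(_is_int_or_scalar_tensor(paddle.to_tensor(3)))  # True (0-D tensor)
print(_is_int_or_scalar_tensor(2.5))                  # False
```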
