2020#include " paddle/extension.h"
2121#include " utils/utils.h"
2222
// Sets the SDPA_FLAGS_<flag_name> bit in the local `flags` variable when
// `condition` is true. No space is allowed between the macro name and `(`
// (otherwise this becomes an object-like macro and every call site breaks).
// NOTE: deliberately NOT wrapped in do { } while (0) — call sites invoke this
// macro without a trailing semicolon (same convention as
// SDPA_SET_INPUT_AND_FLAGS below); the braced body still protects against
// dangling-else.
#define SDPA_SET_FLAGS(condition, flag_name)        \
  if (condition) {                                  \
    flags |= SdpaFlags_t::SDPA_FLAGS_##flag_name;   \
  }
2327#define SDPA_SET_INPUT_AND_FLAGS (ptr, flag_name ) \
2428 if (ptr) { \
2529 flags |= SdpaFlags_t::SDPA_FLAGS_##flag_name; \
@@ -35,7 +39,7 @@ struct SDPAParams {
3539
3640class FusedFp8Sdpa : public HpuOperator {
3741 public:
38- FusedFp8Sdpa () : HpuOperator(" sdpa_recomp_fwd_hf8 " ) {}
42+ explicit FusedFp8Sdpa (std::string guid ) : HpuOperator(guid ) {}
3943 void AddNode (ConvertTensors& ct, SDPAParams& params) {
4044 auto inputs = ct.GetTensors ();
4145 auto outputs = ct.GetTensors (false );
@@ -67,12 +71,24 @@ class FusedFp8Sdpa : public HpuOperator {
6771 }
6872
6973 std::vector<synTensor> sync_outputs;
70- for (size_t i = 0 ; i < outputs.size (); i++) {
71- sync_outputs.push_back (createTensor (outputs[i].dims .size (),
72- outputs[i].type ,
73- outputs[i].dims ,
74- true ,
75- outputs[i].name ));
74+ // [0] out, bf16
75+ sync_outputs.push_back (createTensor (outputs[0 ].dims .size (),
76+ outputs[0 ].type ,
77+ outputs[0 ].dims ,
78+ true ,
79+ outputs[0 ].name ));
80+ if (params.params .flags & SdpaFlags_t::SDPA_FLAGS_AMAX_S) {
81+ // [1] m, bf16 [1]
82+ sync_outputs.push_back (createTensor (1 , syn_type_bf16, {1 }, false , " m" ));
83+ // [2] linv, float32 [1]
84+ sync_outputs.push_back (
85+ createTensor (1 , syn_type_float, {1 }, false , " linv" ));
86+ // [3] seed, int32 [1]
87+ sync_outputs.push_back (
88+ createTensor (1 , syn_type_int32, {1 }, false , " seed" ));
89+ // [4] amax_s, float32 [1]
90+ sync_outputs.push_back (
91+ createTensor (1 , syn_type_float, {1 }, true , outputs[1 ].name ));
7692 }
7793
7894 status = synNodeCreate (graphHandle_,
@@ -105,9 +121,13 @@ void fused_fp8_sdpa(const Context& dev_ctx,
105121 const paddle::optional<phi::DenseTensor>& d_scale_s,
106122 float scale,
107123 bool causal,
108- phi::DenseTensor* out) {
124+ bool is_amax_s,
125+ phi::DenseTensor* out,
126+ phi::DenseTensor* amax) {
109127 // allocate memory on device.
110128 dev_ctx.template Alloc <T>(out);
129+ dev_ctx.template Alloc <float >(amax);
130+
111131 if (out->numel () == 0 ) {
112132 return ;
113133 }
@@ -117,6 +137,7 @@ void fused_fp8_sdpa(const Context& dev_ctx,
117137 ct.Add (k);
118138 ct.Add (v);
119139
140+ std::string guid = " sdpa_recomp_fwd_hf8" ;
120141 unsigned int flags = 0 ;
121142
122143 SDPA_SET_INPUT_AND_FLAGS (d_scale_q.get_ptr (), D_SCALE_Q)
@@ -125,6 +146,10 @@ void fused_fp8_sdpa(const Context& dev_ctx,
125146 SDPA_SET_INPUT_AND_FLAGS (q_scale_s.get_ptr (), Q_SCALE_S)
126147 SDPA_SET_INPUT_AND_FLAGS (q_scale_o.get_ptr (), Q_SCALE_O)
127148 SDPA_SET_INPUT_AND_FLAGS (d_scale_s.get_ptr (), D_SCALE_S)
149+ if (flags == 0 ) {
150+ guid = " sdpa_recomp_fwd_bf16" ;
151+ }
152+ SDPA_SET_FLAGS (is_amax_s, AMAX_S)
128153
129154 SDPAParams params{};
130155
@@ -141,6 +166,8 @@ void fused_fp8_sdpa(const Context& dev_ctx,
141166 params.params .flags = flags;
142167
143168 ct.Add (*out, false );
169+ ct.Add (*amax, false );
170+
144171 std::vector<DIMS> inputs_dims = ct.GetDims ();
145172
146173 OpCacheOperator op_info;
@@ -149,7 +176,7 @@ void fused_fp8_sdpa(const Context& dev_ctx,
149176 auto recipe = op_info.GetRecipe ();
150177
151178 if (recipe == nullptr ) {
152- FusedFp8Sdpa op;
179+ FusedFp8Sdpa op (guid) ;
153180 op.AddNode (ct, params);
154181 op.Compile ();
155182 op_info.setOp (op);
@@ -175,7 +202,8 @@ std::vector<paddle::Tensor> FusedFp8SdpaForward(
175202 const paddle::optional<paddle::Tensor>& q_scale_o,
176203 const paddle::optional<paddle::Tensor>& d_scale_s,
177204 bool causal,
178- float scale) {
205+ float scale,
206+ bool is_amax_s) {
179207 auto dev_ctx = static_cast <const phi::CustomContext*>(
180208 paddle::experimental::DeviceContextPool::Instance ().Get (q.place ()));
181209
@@ -242,6 +270,9 @@ std::vector<paddle::Tensor> FusedFp8SdpaForward(
242270 auto out_tensor = std::make_shared<phi::DenseTensor>();
243271 out_tensor->Resize (q_tensor->dims ());
244272
273+ auto amax_tensor = std::make_shared<phi::DenseTensor>();
274+ amax_tensor->Resize ({1 });
275+
245276 custom_kernel::fused_fp8_sdpa<phi::dtype::bfloat16>(
246277 *dev_ctx,
247278 *q_tensor,
@@ -256,11 +287,11 @@ std::vector<paddle::Tensor> FusedFp8SdpaForward(
256287 d_scale_s ? *d_scale_s_tensor : paddle::optional<phi::DenseTensor>(),
257288 scale,
258289 causal,
259- out_tensor. get ());
260-
261- paddle::Tensor out (out_tensor );
290+ is_amax_s,
291+ out_tensor. get (),
292+ amax_tensor. get () );
262293
263- return {out };
294+ return {paddle::Tensor (out_tensor), paddle::Tensor (amax_tensor) };
264295}
265296
266297std::vector<std::vector<int64_t >> FusedFp8SdpaForwardShape (
@@ -271,7 +302,7 @@ std::vector<std::vector<int64_t>> FusedFp8SdpaForwardShape(
271302 int64_t num_heads = query_states_shape[1 ];
272303 int64_t seq_len = query_states_shape[2 ];
273304 int head_dim = query_states_shape[3 ];
274- return {{bsz, num_heads, seq_len, head_dim}};
305+ return {{bsz, num_heads, seq_len, head_dim}, { 1 } };
275306}
276307
277308std::vector<paddle::DataType> FusedFp8SdpaForwardDtype (
@@ -294,8 +325,8 @@ PD_BUILD_OP(fused_fp8_sdpa)
294325 paddle::Optional (" q_scale_o" ),
295326 paddle::Optional (" d_scale_s" ),
296327 })
297- .Attrs({" causal: bool" , " scaling_factor: float" })
298- .Outputs({" out" })
328+ .Attrs({" causal: bool" , " scaling_factor: float" , " is_amax_s: bool " })
329+ .Outputs({" out" , " amax " })
299330 .SetKernelFn(PD_KERNEL(FusedFp8SdpaForward))
300331 .SetInferShapeFn(PD_INFER_SHAPE(FusedFp8SdpaForwardShape))
301332 .SetInferDtypeFn(PD_INFER_DTYPE(FusedFp8SdpaForwardDtype));
0 commit comments