@@ -29,13 +29,118 @@ void TransposeKernel(const Context &dev_ctx,
29
29
30
30
class FSDPA : public HpuOperator {
31
31
public:
32
- explicit FSDPA (std::string guid_prefix) : HpuOperator(guid_prefix, false ) {}
32
+ explicit FSDPA (std::string guid_prefix, synDataType dtype)
33
+ : HpuOperator(guid_prefix), dtype_(dtype) {}
33
34
void AddNode (ConvertTensors &ct, ns_Sdpa::ParamsV2 params) {
34
35
auto inputs = ct.GetTensors ();
35
36
auto outputs = ct.GetTensors (false );
36
37
38
+ std::vector<int64_t > q_dims = std::vector<int64_t >(inputs[0 ].dims );
39
+ std::vector<int64_t > qt_dims (q_dims.cbegin (), q_dims.cend ());
40
+ std::vector<int64_t > kv_dims = std::vector<int64_t >(inputs[1 ].dims );
41
+ std::vector<int64_t > kvt_dims (kv_dims.cbegin (), kv_dims.cend ());
42
+
43
+ int rank = q_dims.size ();
44
+
45
+ std::vector<int > axis = {0 , 2 , 1 , 3 };
46
+ synTransposeParams trans_params;
47
+ for (size_t i = 0 ; i < axis.size (); i++) {
48
+ trans_params.permutation [i] =
49
+ static_cast <TransposePermutationDim>(axis[i]);
50
+ }
51
+ trans_params.tensorDim = rank;
52
+
53
+ qt_dims[rank - 3 ] = q_dims[rank - 2 ];
54
+ qt_dims[rank - 2 ] = q_dims[rank - 3 ];
55
+ kvt_dims[rank - 3 ] = kv_dims[rank - 2 ];
56
+ kvt_dims[rank - 2 ] = kv_dims[rank - 3 ];
57
+
58
+ synTensor q_transpose_inputs[1 ] = {createTensor (inputs[0 ].dims .size (),
59
+ inputs[0 ].type ,
60
+ inputs[0 ].dims ,
61
+ true ,
62
+ inputs[0 ].name )};
63
+
64
+ synTensor q_transpose_outputs[1 ] = {createTensor (
65
+ inputs[0 ].dims .size (), inputs[0 ].type , qt_dims, false , " q_t" )};
66
+
67
+ synTensor k_transpose_inputs[1 ] = {createTensor (inputs[1 ].dims .size (),
68
+ inputs[1 ].type ,
69
+ inputs[1 ].dims ,
70
+ true ,
71
+ inputs[1 ].name )};
72
+
73
+ synTensor k_transpose_outputs[1 ] = {createTensor (
74
+ inputs[1 ].dims .size (), inputs[1 ].type , kvt_dims, false , " k_t" )};
75
+
76
+ synTensor v_transpose_inputs[1 ] = {createTensor (inputs[2 ].dims .size (),
77
+ inputs[2 ].type ,
78
+ inputs[2 ].dims ,
79
+ true ,
80
+ inputs[2 ].name )};
81
+
82
+ synTensor v_transpose_outputs[1 ] = {createTensor (
83
+ inputs[2 ].dims .size (), inputs[2 ].type , kvt_dims, false , " v_t" )};
84
+
85
+ std::string trans = " transpose" ;
86
+ if (dtype_ == syn_type_fp16) {
87
+ trans = trans + " _f16" ;
88
+ } else if (dtype_ == syn_type_bf16) {
89
+ trans = trans + " _bf16" ;
90
+ } else if (dtype_ == syn_type_single) {
91
+ trans = trans + " _f32" ;
92
+ }
93
+
94
+ synStatus status = synNodeCreate (graphHandle_,
95
+ q_transpose_inputs,
96
+ q_transpose_outputs,
97
+ 1 ,
98
+ 1 ,
99
+ &trans_params,
100
+ sizeof (trans_params),
101
+ trans.c_str (),
102
+ " q_transpose" ,
103
+ nullptr ,
104
+ nullptr );
105
+ PD_CHECK (status == synSuccess,
106
+ " [RUNTIME] FSDPA q_transpose synNodeCreate () failed = " ,
107
+ status);
108
+
109
+ status = synNodeCreate (graphHandle_,
110
+ k_transpose_inputs,
111
+ k_transpose_outputs,
112
+ 1 ,
113
+ 1 ,
114
+ &trans_params,
115
+ sizeof (trans_params),
116
+ trans.c_str (),
117
+ " k_transpose" ,
118
+ nullptr ,
119
+ nullptr );
120
+ PD_CHECK (status == synSuccess,
121
+ " [RUNTIME] FSDPA k_transpose synNodeCreate () failed = " ,
122
+ status);
123
+
124
+ status = synNodeCreate (graphHandle_,
125
+ v_transpose_inputs,
126
+ v_transpose_outputs,
127
+ 1 ,
128
+ 1 ,
129
+ &trans_params,
130
+ sizeof (trans_params),
131
+ trans.c_str (),
132
+ " v_transpose" ,
133
+ nullptr ,
134
+ nullptr );
135
+ PD_CHECK (status == synSuccess,
136
+ " [RUNTIME] FSDPA v_transpose synNodeCreate () failed = " ,
137
+ status);
138
+
37
139
std::vector<synTensor> syn_inputs;
38
- for (size_t i = 0 ; i < inputs.size (); i++) {
140
+ syn_inputs.push_back (q_transpose_outputs[0 ]);
141
+ syn_inputs.push_back (k_transpose_outputs[0 ]);
142
+ syn_inputs.push_back (v_transpose_outputs[0 ]);
143
+ for (size_t i = 3 ; i < inputs.size (); i++) {
39
144
syn_inputs.push_back (createTensor (inputs[i].dims .size (),
40
145
inputs[i].type ,
41
146
inputs[i].dims ,
@@ -44,13 +149,11 @@ class FSDPA : public HpuOperator {
44
149
}
45
150
46
151
std::vector<synTensor> syn_outputs;
47
- for (size_t i = 0 ; i < 1 ; i++) {
48
- syn_outputs.push_back (createTensor (outputs[i].dims .size (),
49
- outputs[i].type ,
50
- outputs[i].dims ,
51
- true ,
52
- outputs[i].name ));
53
- }
152
+
153
+ synTensor attn_outputs[1 ] = {createTensor (
154
+ inputs[0 ].dims .size (), inputs[0 ].type , qt_dims, false , " attn_t" )};
155
+ syn_outputs.push_back (attn_outputs[0 ]);
156
+
54
157
if (!params.is_inference ) {
55
158
for (size_t i = 1 ; i < outputs.size (); i++) {
56
159
syn_outputs.push_back (createTensor (outputs[i].dims .size (),
@@ -61,20 +164,46 @@ class FSDPA : public HpuOperator {
61
164
}
62
165
}
63
166
64
- synStatus status = synNodeCreate (graphHandle_,
65
- syn_inputs.data (),
66
- syn_outputs.data (),
67
- syn_inputs.size (),
68
- syn_outputs.size (),
69
- &params,
70
- sizeof (params),
71
- guid_.c_str (),
72
- " FSDPA" ,
73
- nullptr ,
74
- nullptr );
75
- PD_CHECK (
76
- status == synSuccess, " [RUNTIME] synNodeCreate () failed = %d" , status);
167
+ status = synNodeCreate (graphHandle_,
168
+ syn_inputs.data (),
169
+ syn_outputs.data (),
170
+ syn_inputs.size (),
171
+ syn_outputs.size (),
172
+ &params,
173
+ sizeof (params),
174
+ guid_.c_str (),
175
+ " FSDPA" ,
176
+ nullptr ,
177
+ nullptr );
178
+ PD_CHECK (status == synSuccess,
179
+ " [RUNTIME] FSDPA sdpa_recomp_fwd synNodeCreate () failed = " ,
180
+ status);
181
+
182
+ synTensor attn_transpose_outputs[1 ] = {createTensor (outputs[0 ].dims .size (),
183
+ outputs[0 ].type ,
184
+ outputs[0 ].dims ,
185
+ true ,
186
+ outputs[0 ].name )};
187
+
188
+ status = synNodeCreate (graphHandle_,
189
+ attn_outputs,
190
+ attn_transpose_outputs,
191
+ 1 ,
192
+ 1 ,
193
+ &trans_params,
194
+ sizeof (trans_params),
195
+ trans.c_str (),
196
+ " attn_transpose" ,
197
+ nullptr ,
198
+ nullptr );
199
+
200
+ PD_CHECK (status == synSuccess,
201
+ " [RUNTIME] FSDPA attn_transpose synNodeCreate () failed = " ,
202
+ status);
77
203
}
204
+
205
+ protected:
206
+ synDataType dtype_;
78
207
};
79
208
80
209
template <typename T, typename Context>
@@ -83,61 +212,29 @@ void FusedDotProductAttentionKernel(
83
212
const phi::DenseTensor &q,
84
213
const phi::DenseTensor &k,
85
214
const phi::DenseTensor &v,
86
- const phi::DenseTensor &mask,
87
- // const paddle::optional<phi::DenseTensor> &attention_mask,
88
- // const paddle::optional<phi::DenseTensor> &cu_seqlen_q,
89
- // const paddle::optional<phi::DenseTensor> &cu_seqlen_kv,
215
+ const paddle::optional<phi::DenseTensor> &attention_mask,
216
+ const paddle::optional<phi::DenseTensor> &cu_seqlen_q,
217
+ const paddle::optional<phi::DenseTensor> &cu_seqlen_kv,
90
218
float scaling_factor,
91
219
float dropout_probability,
92
220
bool is_training,
93
- bool is_causal_masking,
94
- // const std::string &mask_type_str,
95
- // const std::string &bias_type_str,
221
+ const std::string &mask_type_str,
222
+ const std::string &bias_type_str,
96
223
phi::DenseTensor *out,
97
224
phi::DenseTensor *softmax_out,
98
225
phi::DenseTensor *rng_state) {
99
- std::vector<int > axis = {0 , 2 , 1 , 3 };
100
- phi::DenseTensor qt;
101
- // auto q_dims = q.dims();
102
- std::vector<int64_t > q_dims = phi::vectorize<int64_t >(q.dims ());
103
- std::vector<int64_t > qt_dims (q_dims.cbegin (), q_dims.cend ());
104
-
105
- int rank = q_dims.size ();
106
- qt_dims[rank - 3 ] = q_dims[rank - 2 ];
107
- qt_dims[rank - 2 ] = q_dims[rank - 3 ];
108
-
109
- phi::DenseTensorMeta qt_meta ({q.dtype (), phi::make_ddim (qt_dims)});
110
- qt.set_meta (qt_meta);
111
- custom_kernel::TransposeKernel<T, Context>(dev_ctx, q, axis, &qt);
112
-
113
- phi::DenseTensor kt;
114
- phi::DenseTensor vt;
115
- std::vector<int64_t > kv_dims = phi::vectorize<int64_t >(k.dims ());
116
- std::vector<int64_t > kvt_dims (kv_dims.cbegin (), kv_dims.cend ());
117
- kvt_dims[rank - 3 ] = kv_dims[rank - 2 ];
118
- kvt_dims[rank - 2 ] = kv_dims[rank - 3 ];
119
- phi::DenseTensorMeta kvt_meta ({k.dtype (), phi::make_ddim (kvt_dims)});
120
- kt.set_meta (kvt_meta);
121
- vt.set_meta (kvt_meta);
122
- custom_kernel::TransposeKernel<T, Context>(dev_ctx, k, axis, &kt);
123
- custom_kernel::TransposeKernel<T, Context>(dev_ctx, v, axis, &vt);
124
-
125
- out->Resize (phi::make_ddim (qt_dims));
126
226
dev_ctx.template Alloc <T>(out);
127
227
if (is_training) {
128
228
dev_ctx.template Alloc <T>(softmax_out);
129
229
}
130
230
131
231
ConvertTensors ct;
132
- ct.Add (qt);
133
- ct.Add (kt);
134
- ct.Add (vt);
135
- ct.Add (mask);
136
- /*
232
+ ct.Add (q);
233
+ ct.Add (k);
234
+ ct.Add (v);
137
235
if (attention_mask.get_ptr ()) {
138
236
ct.Add (attention_mask.get_ptr ());
139
237
}
140
- */
141
238
ct.Add (out, false );
142
239
if (is_training) {
143
240
ct.Add (softmax_out, false );
@@ -149,8 +246,7 @@ void FusedDotProductAttentionKernel(
149
246
ns_Sdpa::ParamsV2 params;
150
247
memset (reinterpret_cast <void *>(¶ms), 0x00 , sizeof (ns_Sdpa::ParamsV2));
151
248
params.scale = scaling_factor;
152
- params.is_causal = is_causal_masking;
153
- // params.is_causal = (mask_type_str == "causal");
249
+ params.is_causal = (mask_type_str == " causal" );
154
250
params.dropout .ratio = dropout_probability;
155
251
params.dropout .disableMaskOut = false ;
156
252
params.is_inference = !is_training;
@@ -163,7 +259,7 @@ void FusedDotProductAttentionKernel(
163
259
auto recipe = op_info.GetRecipe ();
164
260
165
261
if (recipe == nullptr ) {
166
- FSDPA op (op_info.guid_ );
262
+ FSDPA op (op_info.guid_ , op_info. datatype_ );
167
263
168
264
op.AddNode (ct, params);
169
265
op.Compile ();
0 commit comments