
Commit 72fcaca

[Intel_HPU]Fused SDPA API update (#1494)
1 parent 453da78 commit 72fcaca

File tree

1 file changed: +161 −65 lines changed

backends/intel_hpu/kernels/sdpa.cc

Lines changed: 161 additions & 65 deletions
@@ -29,13 +29,118 @@ void TransposeKernel(const Context &dev_ctx,
 
 class FSDPA : public HpuOperator {
  public:
-  explicit FSDPA(std::string guid_prefix) : HpuOperator(guid_prefix, false) {}
+  explicit FSDPA(std::string guid_prefix, synDataType dtype)
+      : HpuOperator(guid_prefix), dtype_(dtype) {}
   void AddNode(ConvertTensors &ct, ns_Sdpa::ParamsV2 params) {
     auto inputs = ct.GetTensors();
     auto outputs = ct.GetTensors(false);
 
+    std::vector<int64_t> q_dims = std::vector<int64_t>(inputs[0].dims);
+    std::vector<int64_t> qt_dims(q_dims.cbegin(), q_dims.cend());
+    std::vector<int64_t> kv_dims = std::vector<int64_t>(inputs[1].dims);
+    std::vector<int64_t> kvt_dims(kv_dims.cbegin(), kv_dims.cend());
+
+    int rank = q_dims.size();
+
+    std::vector<int> axis = {0, 2, 1, 3};
+    synTransposeParams trans_params;
+    for (size_t i = 0; i < axis.size(); i++) {
+      trans_params.permutation[i] =
+          static_cast<TransposePermutationDim>(axis[i]);
+    }
+    trans_params.tensorDim = rank;
+
+    qt_dims[rank - 3] = q_dims[rank - 2];
+    qt_dims[rank - 2] = q_dims[rank - 3];
+    kvt_dims[rank - 3] = kv_dims[rank - 2];
+    kvt_dims[rank - 2] = kv_dims[rank - 3];
+
+    synTensor q_transpose_inputs[1] = {createTensor(inputs[0].dims.size(),
+                                                    inputs[0].type,
+                                                    inputs[0].dims,
+                                                    true,
+                                                    inputs[0].name)};
+
+    synTensor q_transpose_outputs[1] = {createTensor(
+        inputs[0].dims.size(), inputs[0].type, qt_dims, false, "q_t")};
+
+    synTensor k_transpose_inputs[1] = {createTensor(inputs[1].dims.size(),
+                                                    inputs[1].type,
+                                                    inputs[1].dims,
+                                                    true,
+                                                    inputs[1].name)};
+
+    synTensor k_transpose_outputs[1] = {createTensor(
+        inputs[1].dims.size(), inputs[1].type, kvt_dims, false, "k_t")};
+
+    synTensor v_transpose_inputs[1] = {createTensor(inputs[2].dims.size(),
+                                                    inputs[2].type,
+                                                    inputs[2].dims,
+                                                    true,
+                                                    inputs[2].name)};
+
+    synTensor v_transpose_outputs[1] = {createTensor(
+        inputs[2].dims.size(), inputs[2].type, kvt_dims, false, "v_t")};
+
+    std::string trans = "transpose";
+    if (dtype_ == syn_type_fp16) {
+      trans = trans + "_f16";
+    } else if (dtype_ == syn_type_bf16) {
+      trans = trans + "_bf16";
+    } else if (dtype_ == syn_type_single) {
+      trans = trans + "_f32";
+    }
+
+    synStatus status = synNodeCreate(graphHandle_,
+                                     q_transpose_inputs,
+                                     q_transpose_outputs,
+                                     1,
+                                     1,
+                                     &trans_params,
+                                     sizeof(trans_params),
+                                     trans.c_str(),
+                                     "q_transpose",
+                                     nullptr,
+                                     nullptr);
+    PD_CHECK(status == synSuccess,
+             "[RUNTIME] FSDPA q_transpose synNodeCreate () failed = ",
+             status);
+
+    status = synNodeCreate(graphHandle_,
+                           k_transpose_inputs,
+                           k_transpose_outputs,
+                           1,
+                           1,
+                           &trans_params,
+                           sizeof(trans_params),
+                           trans.c_str(),
+                           "k_transpose",
+                           nullptr,
+                           nullptr);
+    PD_CHECK(status == synSuccess,
+             "[RUNTIME] FSDPA k_transpose synNodeCreate () failed = ",
+             status);
+
+    status = synNodeCreate(graphHandle_,
+                           v_transpose_inputs,
+                           v_transpose_outputs,
+                           1,
+                           1,
+                           &trans_params,
+                           sizeof(trans_params),
+                           trans.c_str(),
+                           "v_transpose",
+                           nullptr,
+                           nullptr);
+    PD_CHECK(status == synSuccess,
+             "[RUNTIME] FSDPA v_transpose synNodeCreate () failed = ",
+             status);
+
     std::vector<synTensor> syn_inputs;
-    for (size_t i = 0; i < inputs.size(); i++) {
+    syn_inputs.push_back(q_transpose_outputs[0]);
+    syn_inputs.push_back(k_transpose_outputs[0]);
+    syn_inputs.push_back(v_transpose_outputs[0]);
+    for (size_t i = 3; i < inputs.size(); i++) {
       syn_inputs.push_back(createTensor(inputs[i].dims.size(),
                                         inputs[i].type,
                                         inputs[i].dims,
@@ -44,13 +149,11 @@ class FSDPA : public HpuOperator {
     }
 
     std::vector<synTensor> syn_outputs;
-    for (size_t i = 0; i < 1; i++) {
-      syn_outputs.push_back(createTensor(outputs[i].dims.size(),
-                                         outputs[i].type,
-                                         outputs[i].dims,
-                                         true,
-                                         outputs[i].name));
-    }
+
+    synTensor attn_outputs[1] = {createTensor(
+        inputs[0].dims.size(), inputs[0].type, qt_dims, false, "attn_t")};
+    syn_outputs.push_back(attn_outputs[0]);
+
     if (!params.is_inference) {
       for (size_t i = 1; i < outputs.size(); i++) {
         syn_outputs.push_back(createTensor(outputs[i].dims.size(),
@@ -61,20 +164,46 @@ class FSDPA : public HpuOperator {
       }
     }
 
-    synStatus status = synNodeCreate(graphHandle_,
-                                     syn_inputs.data(),
-                                     syn_outputs.data(),
-                                     syn_inputs.size(),
-                                     syn_outputs.size(),
-                                     &params,
-                                     sizeof(params),
-                                     guid_.c_str(),
-                                     "FSDPA",
-                                     nullptr,
-                                     nullptr);
-    PD_CHECK(
-        status == synSuccess, "[RUNTIME] synNodeCreate () failed = %d", status);
+    status = synNodeCreate(graphHandle_,
+                           syn_inputs.data(),
+                           syn_outputs.data(),
+                           syn_inputs.size(),
+                           syn_outputs.size(),
+                           &params,
+                           sizeof(params),
+                           guid_.c_str(),
+                           "FSDPA",
+                           nullptr,
+                           nullptr);
+    PD_CHECK(status == synSuccess,
+             "[RUNTIME] FSDPA sdpa_recomp_fwd synNodeCreate () failed = ",
+             status);
+
+    synTensor attn_transpose_outputs[1] = {createTensor(outputs[0].dims.size(),
+                                                        outputs[0].type,
+                                                        outputs[0].dims,
+                                                        true,
+                                                        outputs[0].name)};
+
+    status = synNodeCreate(graphHandle_,
+                           attn_outputs,
+                           attn_transpose_outputs,
+                           1,
+                           1,
+                           &trans_params,
+                           sizeof(trans_params),
+                           trans.c_str(),
+                           "attn_transpose",
+                           nullptr,
+                           nullptr);
+
+    PD_CHECK(status == synSuccess,
+             "[RUNTIME] FSDPA attn_transpose synNodeCreate () failed = ",
+             status);
   }
+
+ protected:
+  synDataType dtype_;
 };
 
 template <typename T, typename Context>
@@ -83,61 +212,29 @@ void FusedDotProductAttentionKernel(
     const phi::DenseTensor &q,
     const phi::DenseTensor &k,
     const phi::DenseTensor &v,
-    const phi::DenseTensor &mask,
-    // const paddle::optional<phi::DenseTensor> &attention_mask,
-    // const paddle::optional<phi::DenseTensor> &cu_seqlen_q,
-    // const paddle::optional<phi::DenseTensor> &cu_seqlen_kv,
+    const paddle::optional<phi::DenseTensor> &attention_mask,
+    const paddle::optional<phi::DenseTensor> &cu_seqlen_q,
+    const paddle::optional<phi::DenseTensor> &cu_seqlen_kv,
     float scaling_factor,
     float dropout_probability,
     bool is_training,
-    bool is_causal_masking,
-    // const std::string &mask_type_str,
-    // const std::string &bias_type_str,
+    const std::string &mask_type_str,
+    const std::string &bias_type_str,
     phi::DenseTensor *out,
     phi::DenseTensor *softmax_out,
     phi::DenseTensor *rng_state) {
-  std::vector<int> axis = {0, 2, 1, 3};
-  phi::DenseTensor qt;
-  // auto q_dims = q.dims();
-  std::vector<int64_t> q_dims = phi::vectorize<int64_t>(q.dims());
-  std::vector<int64_t> qt_dims(q_dims.cbegin(), q_dims.cend());
-
-  int rank = q_dims.size();
-  qt_dims[rank - 3] = q_dims[rank - 2];
-  qt_dims[rank - 2] = q_dims[rank - 3];
-
-  phi::DenseTensorMeta qt_meta({q.dtype(), phi::make_ddim(qt_dims)});
-  qt.set_meta(qt_meta);
-  custom_kernel::TransposeKernel<T, Context>(dev_ctx, q, axis, &qt);
-
-  phi::DenseTensor kt;
-  phi::DenseTensor vt;
-  std::vector<int64_t> kv_dims = phi::vectorize<int64_t>(k.dims());
-  std::vector<int64_t> kvt_dims(kv_dims.cbegin(), kv_dims.cend());
-  kvt_dims[rank - 3] = kv_dims[rank - 2];
-  kvt_dims[rank - 2] = kv_dims[rank - 3];
-  phi::DenseTensorMeta kvt_meta({k.dtype(), phi::make_ddim(kvt_dims)});
-  kt.set_meta(kvt_meta);
-  vt.set_meta(kvt_meta);
-  custom_kernel::TransposeKernel<T, Context>(dev_ctx, k, axis, &kt);
-  custom_kernel::TransposeKernel<T, Context>(dev_ctx, v, axis, &vt);
-
-  out->Resize(phi::make_ddim(qt_dims));
   dev_ctx.template Alloc<T>(out);
   if (is_training) {
     dev_ctx.template Alloc<T>(softmax_out);
   }
 
   ConvertTensors ct;
-  ct.Add(qt);
-  ct.Add(kt);
-  ct.Add(vt);
-  ct.Add(mask);
-  /*
+  ct.Add(q);
+  ct.Add(k);
+  ct.Add(v);
   if (attention_mask.get_ptr()) {
     ct.Add(attention_mask.get_ptr());
   }
-  */
   ct.Add(out, false);
   if (is_training) {
     ct.Add(softmax_out, false);
@@ -149,8 +246,7 @@ void FusedDotProductAttentionKernel(
   ns_Sdpa::ParamsV2 params;
   memset(reinterpret_cast<void *>(&params), 0x00, sizeof(ns_Sdpa::ParamsV2));
   params.scale = scaling_factor;
-  params.is_causal = is_causal_masking;
-  // params.is_causal = (mask_type_str == "causal");
+  params.is_causal = (mask_type_str == "causal");
   params.dropout.ratio = dropout_probability;
   params.dropout.disableMaskOut = false;
   params.is_inference = !is_training;
@@ -163,7 +259,7 @@ void FusedDotProductAttentionKernel(
   auto recipe = op_info.GetRecipe();
 
   if (recipe == nullptr) {
-    FSDPA op(op_info.guid_);
+    FSDPA op(op_info.guid_, op_info.datatype_);
 
     op.AddNode(ct, params);
     op.Compile();
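Note on the layout handling in this diff: previously the kernel transposed q/k/v on the host via custom_kernel::TransposeKernel before building the graph; the updated code instead inserts transpose nodes into the Synapse graph around the fused SDPA node, plus a final attn_transpose back to the caller's layout, so the whole sequence compiles into one recipe. The qt_dims[rank - 3] / qt_dims[rank - 2] swap is just the shape effect of the {0, 2, 1, 3} permutation. Below is a minimal standalone sketch of that shape bookkeeping (plain C++, independent of the Synapse/Paddle types; the [B, S, H, D] layout labels are an illustrative assumption, not taken from the commit):

// Sketch only: reproduces the shape arithmetic from the diff, not the
// Synapse graph calls. The {0, 2, 1, 3} permutation swaps the two middle
// dims; reading them as [B, S, H, D] -> [B, H, S, D] is an assumption.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> PermutedDims(const std::vector<int64_t> &dims,
                                  const std::vector<int> &axis) {
  std::vector<int64_t> out(dims.size());
  for (size_t i = 0; i < axis.size(); i++) out[i] = dims[axis[i]];
  return out;
}

int main() {
  std::vector<int64_t> q_dims = {2, 128, 16, 64};  // e.g. [B, S, H, D]
  std::vector<int> axis = {0, 2, 1, 3};  // same permutation as the diff

  // General form: gather the shape through the permutation.
  std::vector<int64_t> qt_dims = PermutedDims(q_dims, axis);

  // The diff's shorthand: for this permutation only the two middle
  // entries change, so the transposed shape is an in-place swap.
  int rank = static_cast<int>(q_dims.size());
  std::vector<int64_t> qt_swap(q_dims);
  qt_swap[rank - 3] = q_dims[rank - 2];
  qt_swap[rank - 2] = q_dims[rank - 3];

  assert(qt_dims == qt_swap);  // both yield [2, 16, 128, 64]
  for (int64_t d : qt_dims) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}

Because the fused node now produces its result in the transposed layout ("attn_t") and the new attn_transpose node restores outputs[0]'s original dims, the old out->Resize(phi::make_ddim(qt_dims)) disappears from the kernel.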
