Commit bcd77bc

[INTEL_HPU] add support of scatter and scatter_ (#1499)
1 parent 8b1b87b commit bcd77bc

File tree

2 files changed: +484 −0 lines changed
Lines changed: 358 additions & 0 deletions
@@ -0,0 +1,358 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "habanalabs/perf_lib_layer_params.h"
#include "habanalabs/synapse_api.h"
#include "habanalabs/synapse_common_types.h"
#include "kernels/funcs.h"
#include "kernels/hpu_operator.h"
#include "utils/utils.h"

namespace custom_kernel {

template <typename T, typename Context>
void ExpandKernel(const Context& dev_ctx,
                  const phi::DenseTensor& x,
                  const phi::IntArray& shape,
                  phi::DenseTensor* out);

template <typename T, typename Context>
void CastKernel(const Context& dev_ctx,
                const phi::DenseTensor& x,
                phi::DataType dtype,
                phi::DenseTensor* out);

template <typename T, typename Context>
void FullKernel(const Context& dev_ctx,
                const phi::IntArray& shape,
                const phi::Scalar& val,
                phi::DataType dtype,
                phi::DenseTensor* out);

template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx,
                    const phi::DenseTensor& x,
                    const phi::Scalar& val,
                    phi::DataType dtype,
                    phi::DenseTensor* out);

struct ScatterParams {
  ns_ScatterKernel::Params params;
};

class Scatter : public HpuOperator {
 public:
  Scatter() : HpuOperator("scatter_fwd_") {}

  void AddNode(ConvertTensors& ct, ScatterParams params, bool is_inplace) {
    auto inputs = ct.GetTensors();
    auto outputs = ct.GetTensors(false);

    std::vector<synTensor> syn_inputs;
    synSectionHandle section_shared = nullptr;
    for (size_t i = 0; i < inputs.size(); i++) {
      if (i == 0 && is_inplace) {
        section_shared = createSection();
        syn_inputs.push_back(createTensor(inputs[i].dims.size(),
                                          inputs[i].type,
                                          inputs[i].dims,
                                          true,
                                          inputs[i].name,
                                          section_shared));
      } else {
        syn_inputs.push_back(createTensor(inputs[i].dims.size(),
                                          inputs[i].type,
                                          inputs[i].dims,
                                          true,
                                          inputs[i].name));
      }
    }

    std::vector<synTensor> syn_outputs;
    syn_outputs.push_back(createTensor(outputs[0].dims.size(),
                                       outputs[0].type,
                                       outputs[0].dims,
                                       true,
                                       outputs[0].name,
                                       section_shared));

    guid_ = guid_ + SynDataTypeToStr(inputs[0].type);
    synStatus status = synNodeCreate(graphHandle_,
                                     syn_inputs.data(),
                                     syn_outputs.data(),
                                     syn_inputs.size(),
                                     syn_outputs.size(),
                                     &params.params,
                                     sizeof(params.params),
                                     guid_.c_str(),
                                     "Scatter",
                                     nullptr,
                                     nullptr);
    PD_CHECK(
        status == synSuccess, "[RUNTIME] synNodeCreate() failed = %d", status);
  }
};

class ScatterAdd : public HpuOperator {
 public:
  ScatterAdd() : HpuOperator("unsorted_scatter_add_fwd_") {}

  void AddNode(ConvertTensors& ct, ScatterParams params) {
    auto inputs = ct.GetTensors();
    auto outputs = ct.GetTensors(false);

    std::vector<synTensor> syn_inputs;
    for (size_t i = 0; i < inputs.size(); i++) {
      syn_inputs.push_back(createTensor(inputs[i].dims.size(),
                                        inputs[i].type,
                                        inputs[i].dims,
                                        true,
                                        inputs[i].name));
    }

    std::vector<synTensor> syn_outputs;
    for (size_t i = 0; i < outputs.size(); i++) {
      syn_outputs.push_back(createTensor(outputs[i].dims.size(),
                                         outputs[i].type,
                                         outputs[i].dims,
                                         true,
                                         outputs[i].name));
    }

    guid_ = guid_ + SynDataTypeToStr(inputs[0].type);

    synStatus status = synNodeCreate(graphHandle_,
                                     syn_inputs.data(),
                                     syn_outputs.data(),
                                     syn_inputs.size(),
                                     syn_outputs.size(),
                                     &params.params,
                                     sizeof(params.params),
                                     guid_.c_str(),
                                     "ScatterAdd",
                                     nullptr,
                                     nullptr);
    PD_CHECK(
        status == synSuccess, "[RUNTIME] synNodeCreate() failed = %d", status);
  }
};

template <typename T, typename Context>
void ScatterKernelOverwrite(const Context& dev_ctx,
                            const phi::DenseTensor& x,
                            const phi::DenseTensor& index,
                            const phi::DenseTensor& update,
                            phi::DenseTensor* out) {
  PD_CHECK(index.dtype() == phi::DataType::INT32 ||
               index.dtype() == phi::DataType::INT64,
           "Scatter requires the index type to be either int32 or int64");

  auto index_dims = phi::vectorize<int>(index.dims());
  auto update_dims = phi::vectorize<int>(update.dims());
  PD_CHECK(update_dims[0] == index_dims[0],
           "Scatter requires the 1st dim of update to match the 1st dim of "
           "index");

  // Normalize the index shape to [N, 1]; only 1-D indices and 2-D indices
  // with a trailing dim of 1 are supported.
  if (index_dims.size() == 2) {
    PD_CHECK(index_dims[1] == 1,
             "Scatter's index 2nd dim must be 1 for 2D index");
  } else if (index_dims.size() == 1) {
    index_dims.push_back(1);
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Scatter requires the index to be a 1-D or 2-D tensor."));
  }

  phi::DenseTensor index_i32;
  phi::DenseTensor fake_index(index);
  phi::DenseTensor* expand_src = &fake_index;
  phi::DenseTensorMeta fake_meta(index.dtype(), phi::make_ddim(index_dims));
  fake_index.set_meta(fake_meta);

  // A 64-bit index is cast down to int32 before it is handed to the device.
  if (index.dtype() == phi::DataType::INT64) {
    index_i32.Resize(phi::make_ddim(index_dims));
    dev_ctx.template Alloc<int32_t>(&index_i32);

    custom_kernel::CastKernel<int64_t, Context>(
        dev_ctx, fake_index, phi::DataType::INT32, &index_i32);
    expand_src = &index_i32;
  }

  // Broadcast the [N, 1] index to the full shape of update.
  phi::IntArray out_shape(update_dims);
  phi::DenseTensor index_expand;
  index_expand.Resize(phi::make_ddim(update_dims));
  dev_ctx.template Alloc<int32_t>(&index_expand);

  custom_kernel::ExpandKernel<int32_t, Context>(
      dev_ctx, *expand_src, out_shape, &index_expand);

  dev_ctx.template Alloc<T>(out);
  bool is_inplace = (out->data() == x.data());

  ConvertTensors ct;
  ct.Add(x);
  ct.Add(index_expand);
  ct.Add(update);
  ct.Add(out, false);

  OpCacheOperator op_info;
  ScatterParams params;
  params.params.axis = x.dims().size() - 1;
  std::vector<DIMS> inputs_dims = ct.GetDims();
  // Need to cache different nodes for inplace and non-inplace scatter.
  if (is_inplace) {
    op_info.prepareOpInfo<T, ScatterParams>(
        "ScatterKernel_", inputs_dims, &params);
  } else {
    op_info.prepareOpInfo<T, ScatterParams>(
        "ScatterKernel", inputs_dims, &params);
  }
  auto recipe = op_info.GetRecipe();

  if (recipe == nullptr) {
    Scatter op;
    op.AddNode(ct, params, is_inplace);
    op.Compile();
    op_info.setOp(op);

    recipe = op_info.GetRecipe();
  }

  std::map<std::string, uint64_t> tensors = ct.GetDeviceAddr();
  RecipeRunner runner(recipe);
  runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
}

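The normalization above reduces every supported index to an [N, 1] int32 tensor that is then broadcast to the shape of update, and the in-place scatter_ case is detected by comparing the output and input device pointers. The end result matches Paddle's scatter with overwrite=true: row index[i] of the output is replaced by row i of update. A minimal host-side sketch of that semantic (plain C++; ScatterOverwriteRef is a hypothetical helper for illustration, not part of this commit):

#include <cstddef>
#include <vector>

// Overwrite-scatter on a row-major [rows, cols] matrix:
// out starts as a copy of x, then row index[i] is replaced by update row i.
// With duplicate indices, the last matching update row wins.
std::vector<float> ScatterOverwriteRef(const std::vector<float>& x,
                                       const std::vector<int>& index,
                                       const std::vector<float>& update,
                                       std::size_t cols) {
  std::vector<float> out(x);
  for (std::size_t i = 0; i < index.size(); ++i) {
    for (std::size_t c = 0; c < cols; ++c) {
      out[index[i] * cols + c] = update[i * cols + c];
    }
  }
  return out;
}
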
template <typename T, typename Context>
void ScatterKernelAdd(const Context& dev_ctx,
                      const phi::DenseTensor& x,
                      const phi::DenseTensor& index,
                      const phi::DenseTensor& update,
                      phi::DenseTensor* out) {
  PD_CHECK(index.dtype() == phi::DataType::INT32 ||
               index.dtype() == phi::DataType::INT64,
           "ScatterAdd requires the index type to be either int32 or int64");

  auto index_dims = phi::vectorize<int>(index.dims());
  auto update_dims = phi::vectorize<int>(update.dims());
  PD_CHECK(
      update_dims[0] == index_dims[0],
      "ScatterAdd requires the 1st dim of update to match the 1st dim of "
      "index");

  // Normalize the index shape to [N, 1], as in the overwrite path.
  if (index_dims.size() == 2) {
    PD_CHECK(index_dims[1] == 1,
             "ScatterAdd's index 2nd dim must be 1 for 2D index");
  } else if (index_dims.size() == 1) {
    index_dims.push_back(1);
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "ScatterAdd requires the index to be a 1-D or 2-D tensor."));
  }

  phi::DenseTensor index_i32;
  phi::DenseTensor fake_index(index);
  phi::DenseTensor* expand_src = &fake_index;
  phi::DenseTensorMeta fake_meta(index.dtype(), phi::make_ddim(index_dims));
  fake_index.set_meta(fake_meta);

  if (index.dtype() == phi::DataType::INT64) {
    index_i32.Resize(phi::make_ddim(index_dims));
    dev_ctx.template Alloc<int32_t>(&index_i32);

    custom_kernel::CastKernel<int64_t, Context>(
        dev_ctx, fake_index, phi::DataType::INT32, &index_i32);
    expand_src = &index_i32;
  }

  phi::IntArray out_shape(update_dims);
  phi::DenseTensor index_expand;
  index_expand.Resize(phi::make_ddim(update_dims));
  dev_ctx.template Alloc<int32_t>(&index_expand);

  custom_kernel::ExpandKernel<int32_t, Context>(
      dev_ctx, *expand_src, out_shape, &index_expand);

  dev_ctx.template Alloc<T>(out);

  ConvertTensors ct;
  ct.Add(x);
  ct.Add(index_expand);
  ct.Add(update);
  ct.Add(out, false);

  OpCacheOperator op_info;
  ScatterParams params;
  params.params.axis = x.dims().size() - 1;
  std::vector<DIMS> inputs_dims = ct.GetDims();
  op_info.prepareOpInfo<T, ScatterParams>(
      "ScatterAddKernel", inputs_dims, &params);
  auto recipe = op_info.GetRecipe();

  if (recipe == nullptr) {
    ScatterAdd op;
    op.AddNode(ct, params);
    op.Compile();
    op_info.setOp(op);

    recipe = op_info.GetRecipe();
  }

  std::map<std::string, uint64_t> tensors = ct.GetDeviceAddr();
  RecipeRunner runner(recipe);
  runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
}

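ScatterKernelAdd prepares its operands exactly like the overwrite path but lowers to the unsorted_scatter_add_fwd_ GUID, which accumulates instead of replacing: rows hit by duplicate indices receive the sum of all matching update rows. The corresponding host-side reference (again a hypothetical helper, for illustration only):

// Add-scatter on a row-major [rows, cols] matrix:
// out starts as a copy of x, then update row i is added into row index[i].
std::vector<float> ScatterAddRef(const std::vector<float>& x,
                                 const std::vector<int>& index,
                                 const std::vector<float>& update,
                                 std::size_t cols) {
  std::vector<float> out(x);
  for (std::size_t i = 0; i < index.size(); ++i) {
    for (std::size_t c = 0; c < cols; ++c) {
      out[index[i] * cols + c] += update[i * cols + c];
    }
  }
  return out;
}
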
template <typename T, typename Context>
void ScatterKernel(const Context& dev_ctx,
                   const phi::DenseTensor& x,
                   const phi::DenseTensor& index,
                   const phi::DenseTensor& update,
                   bool overwrite,
                   phi::DenseTensor* out) {
  if (overwrite) {
    ScatterKernelOverwrite<T, Context>(dev_ctx, x, index, update, out);
  } else {
    // overwrite == false is composed from the two device ops above: first
    // overwrite the indexed rows of a copy of x with zeros, then scatter-add
    // the updates, so duplicate indices accumulate.
    auto value = static_cast<T>(0);

    phi::DenseTensor zero;
    phi::DenseTensorMeta zero_meta = {update.dtype(), update.dims()};
    zero.set_meta(zero_meta);
    custom_kernel::FullLikeKernel<T, Context>(
        dev_ctx, update, phi::Scalar(value), zero.dtype(), &zero);

    phi::DenseTensor x1;
    phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, &x1);

    phi::DenseTensor x2;
    phi::DenseTensorMeta x2_meta = {x.dtype(), x.dims()};
    x2.set_meta(x2_meta);
    ScatterKernelOverwrite<T, Context>(dev_ctx, x1, index, zero, &x2);

    ScatterKernelAdd<T, Context>(dev_ctx, x2, index, update, out);
  }
}

}  // namespace custom_kernel

PD_REGISTER_PLUGIN_KERNEL(scatter,
                          intel_hpu,
                          ALL_LAYOUT,
                          custom_kernel::ScatterKernel,
                          float,
                          phi::dtype::float16,
                          phi::dtype::bfloat16) {}
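ScatterKernel realizes overwrite=false by composition: the indexed rows of a copy of x are first overwritten with zeros, then the updates are scatter-added, so the original values of touched rows are discarded while duplicate indices accumulate. A small check of that equivalence using the two hypothetical reference helpers sketched above:

#include <cassert>

int main() {
  const std::size_t cols = 2;
  std::vector<float> x = {1, 1, 2, 2, 3, 3};   // 3 rows of width 2
  std::vector<int> index = {0, 0};             // duplicate index
  std::vector<float> update = {10, 10, 5, 5};  // 2 update rows

  // overwrite == false: zero row 0 first, then accumulate both updates.
  std::vector<float> zeros(update.size(), 0.0f);
  auto tmp = ScatterOverwriteRef(x, index, zeros, cols);
  auto out = ScatterAddRef(tmp, index, update, cols);

  assert(out[0] == 15 && out[1] == 15);  // row 0: 10 + 5; original 1s dropped
  assert(out[4] == 3 && out[5] == 3);    // untouched rows pass through
  return 0;
}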
