
Commit df14333

[INTEL_HPU] support custom_op index_reduce_ and rename index_copy (#1617)
1 parent 59cf6ab commit df14333

File tree: 5 files changed, +644 -3 lines changed


backends/intel_hpu/custom_ops/src/index_copy.cc

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ std::vector<paddle::DataType> IndexCopyInferDtype(
   return {input_dtype};
 }
 
-PD_BUILD_OP(index_copy)
+PD_BUILD_OP(index_copy_)
     .Inputs({"input", "index", "source"})
     .Outputs({"out"})
     .Attrs({"dim: int"})
backends/intel_hpu/custom_ops/src/index_reduce.cc

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "habanalabs/perf_lib_layer_params.h"
+#include "kernels/funcs.h"
+#include "kernels/hpu_operator.h"
+#include "paddle/extension.h"
+#include "utils/utils.h"
+
+namespace custom_kernel {
+
+class IndexReduce : public HpuOperator {
+ public:
+  explicit IndexReduce(synDataType dtype)
+      : HpuOperator("index_reduce_fwd"), dtype_(dtype) {}
+
+  void AddNode(ConvertTensors& ct, ns_IndexReduce::Params params) {
+    auto inputs = ct.GetTensors();
+    auto outputs = ct.GetTensors(false);
+
+    synSectionHandle section = createSection();
+
+    std::vector<synTensor> syn_inputs;
+    syn_inputs.push_back(createTensor(inputs[0].dims.size(),
+                                      inputs[0].type,
+                                      inputs[0].dims,
+                                      true,
+                                      inputs[0].name,
+                                      section));
+
+    syn_inputs.push_back(createTensor(inputs[1].dims.size(),
+                                      inputs[1].type,
+                                      inputs[1].dims,
+                                      true,
+                                      inputs[1].name));
+
+    syn_inputs.push_back(createTensor(inputs[2].dims.size(),
+                                      inputs[2].type,
+                                      inputs[2].dims,
+                                      true,
+                                      inputs[2].name));
+
+    std::vector<synTensor> syn_outputs;
+    syn_outputs.push_back(createTensor(outputs[0].dims.size(),
+                                       outputs[0].type,
+                                       outputs[0].dims,
+                                       true,
+                                       outputs[0].name,
+                                       section));
+
+    std::string guid = guid_ + "_" + SynDataTypeToStr(outputs[0].type);
+    synStatus status = synNodeCreate(graphHandle_,
+                                     syn_inputs.data(),
+                                     syn_outputs.data(),
+                                     syn_inputs.size(),
+                                     syn_outputs.size(),
+                                     &params,
+                                     sizeof(params),
+                                     guid.c_str(),
+                                     "index_copy",
+                                     nullptr,
+                                     nullptr);
+
+    PD_CHECK(
+        status == synSuccess, "[RUNTIME] synNodeCreate () failed = %d", status);
+  }
+
+ protected:
+  synDataType dtype_;
+};
+
+template <typename T, typename Context>
+void IndexReduceKernel(const Context& dev_ctx,
+                       const phi::DenseTensor& input,
+                       const phi::Scalar& dim,
+                       const phi::DenseTensor& index,
+                       const phi::DenseTensor& source) {
+  ConvertTensors ct;
+  ct.Add(input);
+  ct.Add(index);
+  ct.Add(source);
+
+  ct.Add(input, false);
+
+  std::vector<DIMS> inputs_dims = ct.GetDims();
+  ns_IndexReduce::Params params{};
+  params.mode = INDEX_REDUCE_AMAX;
+  params.include_self = true;
+  params.axis = dim.to<unsigned>();
+
+  OpCacheOperator op_info;
+  op_info.prepareOpInfo<T, ns_IndexReduce::Params>(
+      "IndexReduceKernel_", inputs_dims, &params);
+
+  auto recipe = op_info.GetRecipe();
+
+  if (recipe == nullptr) {
+    IndexReduce op(op_info.datatype_);
+    op.AddNode(ct, params);
+    op.Compile();
+    op_info.setOp(op);
+    recipe = op_info.GetRecipe();
+  }
+
+  RecipeRunner runner(recipe);
+  auto tensors = ct.GetDeviceAddr();
+  runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
+}
+
+}  // namespace custom_kernel
+
+template <typename Context>
+void CallIndexReduceKernel(const Context& dev_ctx,
+                           const phi::DenseTensor& input,
+                           const phi::Scalar& dim,
+                           const phi::DenseTensor& index,
+                           const phi::DenseTensor& source,
+                           const std::string reduce = "amax",
+                           const bool include_self = true) {
+  if (input.dtype() == phi::DataType::FLOAT32) {
+    custom_kernel::IndexReduceKernel<float>(dev_ctx, input, dim, index, source);
+  } else if (input.dtype() == phi::DataType::INT32) {
+    custom_kernel::IndexReduceKernel<int32_t>(
+        dev_ctx, input, dim, index, source);
+  } else if (input.dtype() == phi::DataType::BFLOAT16) {
+    custom_kernel::IndexReduceKernel<phi::dtype::bfloat16>(
+        dev_ctx, input, dim, index, source);
+  } else {
+    throw std::runtime_error("Unsupported data type for IndexReduceKernel");
+  }
+}
+
+void IndexReduceForward(const paddle::Tensor& input,
+                        const int dim,
+                        const paddle::Tensor& index,
+                        const paddle::Tensor& source,
+                        const std::string reduce = "amax",
+                        const bool include_self = true) {
+  PD_CHECK(reduce == "amax", "only support reduce = amax");
+  PD_CHECK(include_self == true, "only support include_self = true");
+  auto dev_ctx = static_cast<const phi::CustomContext*>(
+      paddle::experimental::DeviceContextPool::Instance().Get(input.place()));
+
+  auto input_tensor = static_cast<phi::DenseTensor*>(input.impl().get());
+  auto index_tensor = static_cast<const phi::DenseTensor*>(index.impl().get());
+  auto source_tensor =
+      static_cast<const phi::DenseTensor*>(source.impl().get());
+
+  CallIndexReduceKernel(
+      *dev_ctx, *input_tensor, phi::Scalar(dim), *index_tensor, *source_tensor);
+}
+
+std::vector<std::vector<int64_t>> IndexReduceInferShape(
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& index_shape,
+    const std::vector<int64_t>& source_shape) {
+  return {input_shape};
+}
+
+std::vector<paddle::DataType> IndexReduceInferDtype(
+    const paddle::DataType& input_dtype,
+    const paddle::DataType& index_dtype,
+    const paddle::DataType& source_dtype) {
+  return {input_dtype};
+}
+
+PD_BUILD_OP(index_reduce_)
+    .Inputs({"input", "index", "source"})
+    .Outputs({"out"})
+    .Attrs({"dim: int", "reduce: std::string", "include_self: bool"})
+    .SetInplaceMap({{"input", "out"}})
+    .SetKernelFn(PD_KERNEL(IndexReduceForward))
+    .SetInferShapeFn(PD_INFER_SHAPE(IndexReduceInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(IndexReduceInferDtype));
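
For reference, a minimal sketch of calling the new op from Python, assuming it is exposed through the backend's compiled extension (the `paddlenlp_ops` module name, the `"intel_hpu"` device string, and the int32 index dtype are illustrative assumptions). Per the PD_CHECK guards in IndexReduceForward, only reduce="amax" with include_self=True is accepted, and the input tensor is updated in place via the SetInplaceMap binding:

import paddle
from paddlenlp_ops import index_reduce_  # hypothetical import path

paddle.set_device("intel_hpu")

x = paddle.zeros([4, 3])
index = paddle.to_tensor([0, 2], dtype="int32")
source = paddle.to_tensor([[1.0, 2.0, 3.0],
                           [4.0, 5.0, 6.0]])

# Rows 0 and 2 of x become the elementwise max of the existing row and
# the corresponding source row, mirroring params.mode = INDEX_REDUCE_AMAX
# with params.include_self = true in the kernel above.
index_reduce_(x, index, source, 0, "amax", True)
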
