Commit f15d7d1

[INTEL_HPU] Add in_place mark in OpCacheOperator (#1613)

1 parent acb0ff7 commit f15d7d1

7 files changed: +154 −160 lines changed
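The change threads an explicit in_place mark from each kernel into its HPU operator and the recipe cache: kernels detect aliasing by pointer equality (x.data() == out->data()), cache the in-place variant under a key with a trailing underscore, and let AddNode bind the aliased input and the output to one synSection. A minimal, self-contained sketch of the cache-key convention follows; MakeCacheKey is a hypothetical stand-in (the real prepareOpInfo also folds the dtype and op params into the key):

// Hypothetical illustration of the naming scheme used in this commit: the
// in-place variant of a kernel is cached as "<Kernel>_", so it can never
// collide with the out-of-place recipe compiled for the same shapes.
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

using DIMS = std::vector<int64_t>;

std::string MakeCacheKey(const std::string& kernel_name,
                         bool in_place,
                         const std::vector<DIMS>& dims) {
  std::ostringstream key;
  key << (in_place ? kernel_name + "_" : kernel_name);
  for (const auto& d : dims) {
    key << ':';
    for (int64_t v : d) key << v << 'x';  // shapes folded into the key
  }
  return key.str();  // e.g. "CumSumKernel_:4x8x" vs "CumSumKernel:4x8x"
}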

backends/intel_hpu/kernels/cum_kernel.cc

Lines changed: 15 additions & 18 deletions
@@ -25,16 +25,14 @@ namespace custom_kernel {

 class CumsumOperator : public HpuOperator {
  public:
-  CumsumOperator(std::string guid_prefix, std::string node_name)
-      : HpuOperator(guid_prefix), pName_(node_name) {}
-  void AddNode(ConvertTensors& ct, ns_CumSumKernel::Params params) {
+  CumsumOperator() : HpuOperator("cumsum_fwd_") {}
+  void AddNode(ConvertTensors& ct,
+               ns_CumSumKernel::Params params,
+               bool in_place = false) {
     auto inputs = ct.GetTensors();
     auto outputs = ct.GetTensors(false);

-    synSectionHandle section = nullptr;
-    if (inputs[0].device_addr == outputs[0].device_addr) {
-      section = createSection();
-    }
+    synSectionHandle section = in_place ? createSection() : nullptr;

     std::vector<synTensor> syn_inputs;
     for (size_t i = 0; i < inputs.size(); i++) {
@@ -56,23 +54,21 @@ class CumsumOperator : public HpuOperator {
                                      section));
     }

-    std::string guid = +"cumsum_fwd_" + SynDataTypeToStr(inputs[0].type);
-
+    guid_ = guid_ + SynDataTypeToStr(inputs[0].type);
     synStatus status = synNodeCreate(graphHandle_,
                                      syn_inputs.data(),
                                      syn_outputs.data(),
                                      syn_inputs.size(),
                                      syn_outputs.size(),
                                      &params,
                                      sizeof(params),
-                                     guid.c_str(),
-                                     pName_.c_str(),
+                                     guid_.c_str(),
+                                     "cumsum",
                                      nullptr,
                                      nullptr);
     PD_CHECK(
         status == synSuccess, "[RUNTIME] synNodeCreate () failed = %d", status);
   }
-  std::string pName_;
 };

 template <typename T, typename Context>
@@ -107,22 +103,23 @@ void CumsumKernel(const Context& dev_ctx,
   std::vector<int64_t> inputs_dim =
       phi::vectorize<int64_t>(input_tensor.dims());
   ct.Add(input_tensor);
+  ct.Add(out, false);
+
   int params_exclusive = static_cast<int>(exclusive);
   int params_reverse = static_cast<int>(reverse);
   ns_CumSumKernel::Params params{params_axis, params_exclusive, params_reverse};

+  bool in_place = (input_tensor.data() == out->data());
+
   OpCacheOperator op_info;
   op_info.prepareOpInfo<T, ns_CumSumKernel::Params>(
-      "cumsum_fwd_", {inputs_dim}, &params);
+      in_place ? "CumSumKernel_" : "CumSumKernel", {inputs_dim}, &params);

   auto recipe = op_info.GetRecipe();
-  ct.Add(out, false);
   if (recipe == nullptr) {
     // compile
-    std::string op_node_name =
-        (input_tensor.data() == out->data()) ? "_cumsum_op" : "cumsum_op";
-    CumsumOperator op(op_info.guid_, op_node_name);
-    op.AddNode(ct, params);
+    CumsumOperator op;
+    op.AddNode(ct, params, in_place);
     op.Compile();
     op_info.setOp(op);
     recipe = op_info.GetRecipe();

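The pattern above repeats in the other kernels of this commit: when in_place is set, AddNode creates a single synSection and passes it to both the aliased input ("x") and the output tensor, which tells the Synapse graph compiler that the two occupy one device buffer. A simplified, dependency-free model of that rule, with Section and Tensor as stand-ins for the Synapse handles used in the diff:

#include <memory>
#include <string>
#include <utility>

struct Section {};  // stand-in for synSectionHandle
struct Tensor {
  std::string name;
  std::shared_ptr<Section> section;  // tensors sharing a section alias memory
};

// Mirrors "section = in_place ? createSection() : nullptr": one shared
// section when executing in place, no section otherwise.
std::pair<Tensor, Tensor> MakeEndpoints(bool in_place) {
  std::shared_ptr<Section> section =
      in_place ? std::make_shared<Section>() : nullptr;
  return {Tensor{"x", section}, Tensor{"output", section}};
}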
backends/intel_hpu/kernels/elementwise_kernel.cc

Lines changed: 53 additions & 59 deletions
@@ -30,29 +30,24 @@ void FullKernel(const Context& dev_ctx,

 class BinaryOperator : public HpuOperator {
  public:
-  BinaryOperator(std::string guid_prefix,
-                 std::string node_name,
-                 bool in_place = false)
-      : HpuOperator(guid_prefix), pName_(node_name) {
-    inPlace_ = in_place;
-  }
+  explicit BinaryOperator(std::string guid_prefix) : HpuOperator(guid_prefix) {}

   void AddNode(const std::vector<DIMS>& ins,
                const std::vector<DIMS>& outs,
-               synDataType datatype) {
+               synDataType datatype,
+               bool in_place = false) {
     assert(ins.size() == 2 && "input size should be 2");
     assert(outs.size() == 1 && "output size should be 1");

-    synSectionHandle section = nullptr;
-    if (inPlace_) {
-      section = createSection();
-    }
+    synSectionHandle section = in_place ? createSection() : nullptr;

     synTensor inputs[ins.size()] = {
         createTensor(ins[0].size(), datatype, ins[0], true, "x", section),
         createTensor(ins[1].size(), datatype, ins[1], true, "y")};
     synTensor outputs[outs.size()] = {createTensor(
         outs[0].size(), datatype, outs[0], true, "output", section)};
+
+    guid_ = guid_ + SynDataTypeToStr(datatype);
     synStatus status = synNodeCreate(graphHandle_,
                                      inputs,
                                      outputs,
@@ -61,62 +56,61 @@ class BinaryOperator : public HpuOperator {
                                      nullptr,
                                      0,
                                      guid_.c_str(),
-                                     pName_.c_str(),
+                                     "bianary",
                                      nullptr,
                                      nullptr);
     PD_CHECK(status == synSuccess,
              "[RUNTIME] synNodeCreate binary fwd () failed = %d",
              status);
   }
-  std::string pName_;
-  bool inPlace_;
 };

-#define BINARY_RAW_KERNEL(kernel_func, node_name)                            \
-  template <typename T, typename Context>                                    \
-  void kernel_func##RawKernel(const Context& dev_ctx,                        \
-                              const phi::DenseTensor& x,                     \
-                              const phi::DenseTensor& y,                     \
-                              int axis,                                      \
-                              phi::DenseTensor* out) {                       \
-    dev_ctx.template Alloc<T>(out);                                          \
-    VLOG(6) << "CALL HPU " << #kernel_func << "RawKernel";                   \
-    std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x.dims());          \
-    std::vector<int64_t> y_dim = phi::vectorize<int64_t>(y.dims());          \
-    if (y_dim.size() == 0) {                                                 \
-      y_dim.push_back(1);                                                    \
-    }                                                                        \
-    if (x_dim.size() == 0) {                                                 \
-      x_dim.push_back(1);                                                    \
-    }                                                                        \
-    bool in_place = (x.data() == out->data());                               \
-    std::vector<int64_t> outputs_dim = phi::vectorize<int64_t>(out->dims()); \
-    if (outputs_dim.size() == 0) {                                           \
-      outputs_dim.push_back(1);                                              \
-    }                                                                        \
-    OpCacheOperator op_info;                                                 \
-    op_info.prepareOpInfo<T, nullptr_t>(                                     \
-        #node_name "_fwd", {x_dim, y_dim}, nullptr);                         \
-    auto recipe = op_info.GetRecipe();                                       \
-                                                                             \
-    if (recipe == nullptr) {                                                 \
-      std::string op_node_name = in_place ? "_" #node_name : #node_name;     \
-      BinaryOperator op(op_info.guid_, op_node_name, in_place);              \
-      op.AddNode({x_dim, y_dim}, {outputs_dim}, op_info.datatype_);          \
-      op.Compile();                                                          \
-      op_info.setOp(op);                                                     \
-      recipe = op_info.GetRecipe();                                          \
-    }                                                                        \
-                                                                             \
-    std::map<std::string, uint64_t> tensors;                                 \
-    tensors["x"] = reinterpret_cast<uint64_t>(x.data<T>());                  \
-    tensors["y"] = reinterpret_cast<uint64_t>(y.data<T>());                  \
-    tensors["output"] = reinterpret_cast<uint64_t>(out->data<T>());          \
-                                                                             \
-    RecipeRunner runner(recipe);                                             \
-    runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);       \
-                                                                             \
-    return;                                                                  \
+#define BINARY_RAW_KERNEL(kernel_func, node_name)                             \
+  template <typename T, typename Context>                                     \
+  void kernel_func##RawKernel(const Context& dev_ctx,                         \
+                              const phi::DenseTensor& x,                      \
+                              const phi::DenseTensor& y,                      \
+                              int axis,                                       \
+                              phi::DenseTensor* out) {                        \
+    dev_ctx.template Alloc<T>(out);                                           \
+    VLOG(6) << "CALL HPU " << #kernel_func << "RawKernel";                    \
+    std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x.dims());           \
+    std::vector<int64_t> y_dim = phi::vectorize<int64_t>(y.dims());           \
+    if (y_dim.size() == 0) {                                                  \
+      y_dim.push_back(1);                                                     \
+    }                                                                         \
+    if (x_dim.size() == 0) {                                                  \
+      x_dim.push_back(1);                                                     \
+    }                                                                         \
+    bool in_place = (x.data() == out->data());                                \
+    std::vector<int64_t> outputs_dim = phi::vectorize<int64_t>(out->dims());  \
+    if (outputs_dim.size() == 0) {                                            \
+      outputs_dim.push_back(1);                                               \
+    }                                                                         \
+    OpCacheOperator op_info;                                                  \
+    op_info.prepareOpInfo<T, nullptr_t>(                                      \
+        in_place ? (std::string(#node_name) + "_") : std::string(#node_name), \
+        {x_dim, y_dim},                                                       \
+        nullptr);                                                             \
+    auto recipe = op_info.GetRecipe();                                        \
+                                                                              \
+    if (recipe == nullptr) {                                                  \
+      BinaryOperator op(std::string(#node_name) + "_");                       \
+      op.AddNode({x_dim, y_dim}, {outputs_dim}, op_info.datatype_, in_place); \
+      op.Compile();                                                           \
+      op_info.setOp(op);                                                      \
+      recipe = op_info.GetRecipe();                                           \
+    }                                                                         \
+                                                                              \
+    std::map<std::string, uint64_t> tensors;                                  \
+    tensors["x"] = reinterpret_cast<uint64_t>(x.data<T>());                   \
+    tensors["y"] = reinterpret_cast<uint64_t>(y.data<T>());                   \
+    tensors["output"] = reinterpret_cast<uint64_t>(out->data<T>());           \
+                                                                              \
+    RecipeRunner runner(recipe);                                              \
+    runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);        \
+                                                                              \
+    return;                                                                   \
   }

 #define BINARY_KERNEL(kernel_func) \
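For reference, a hypothetical instantiation (the concrete ops registered in this file are outside this diff): BINARY_RAW_KERNEL(Add, add) would expand to an AddRawKernel whose recipes are cached under "add" or "add_" depending on aliasing, and whose BinaryOperator is constructed with the guid prefix "add_" before AddNode appends the dtype suffix.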

backends/intel_hpu/kernels/logical_kernel.cc

Lines changed: 29 additions & 20 deletions
@@ -23,14 +23,13 @@ struct LogicalParams {
 class Logical : public HpuOperator {
  public:
   Logical() : HpuOperator("logical") {}
-  void AddNode(ConvertTensors& ct, LogicalParams& params) {
+  void AddNode(ConvertTensors& ct,
+               LogicalParams& params,
+               bool in_place = false) {
     auto inputs = ct.GetTensors();
     auto outputs = ct.GetTensors(false);

-    synSectionHandle section = nullptr;
-    if (inputs[0].device_addr == outputs[0].device_addr) {
-      section = createSection();
-    }
+    synSectionHandle section = in_place ? createSection() : nullptr;

     std::vector<synTensor> syn_inputs;
     for (size_t i = 0; i < inputs.size(); i++) {
@@ -110,7 +109,6 @@ void LogicalOrKernel(const Context& dev_ctx,
                      const phi::DenseTensor& y,
                      phi::DenseTensor* out) {
   dev_ctx.template Alloc<bool>(out);
-  OpCacheOperator op_info;
   ConvertTensors ct;
   ct.Add(x);
   ct.Add(y);
@@ -120,15 +118,18 @@ void LogicalOrKernel(const Context& dev_ctx,
   LogicalParams params = {};
   snprintf(params.op, MAX_OPNAME_LEN, "%s", "or");
   std::vector<DIMS> inputs_dims = ct.GetDims();
-  std::string op_name =
-      (x.data() == out->data()) ? "_LogicalOrKernel" : "LogicalOrKernel";
-  op_info.prepareOpInfo<T, nullptr_t>(op_name, inputs_dims, nullptr);
+
+  bool in_place = (x.data() == out->data());
+
+  OpCacheOperator op_info;
+  op_info.prepareOpInfo<T, nullptr_t>(
+      in_place ? "LogicalOrKernel_" : "LogicalOrKernel", inputs_dims, nullptr);
   auto recipe = op_info.GetRecipe();

   if (recipe == nullptr) {
     Logical op;

-    op.AddNode(ct, params);
+    op.AddNode(ct, params, in_place);
     op.Compile();
     op_info.setOp(op);

@@ -146,7 +147,6 @@ void LogicalAndKernel(const Context& dev_ctx,
                       const phi::DenseTensor& y,
                       phi::DenseTensor* out) {
   dev_ctx.template Alloc<bool>(out);
-  OpCacheOperator op_info;
   ConvertTensors ct;
   ct.Add(x);
   ct.Add(y);
@@ -155,15 +155,20 @@ void LogicalAndKernel(const Context& dev_ctx,
   LogicalParams params = {};
   snprintf(params.op, MAX_OPNAME_LEN, "%s", "and");
   std::vector<DIMS> inputs_dims = ct.GetDims();
-  std::string op_name =
-      (x.data() == out->data()) ? "_LogicalAndKernel" : "LogicalAndKernel";
-  op_info.prepareOpInfo<T, nullptr_t>(op_name, inputs_dims, nullptr);
+
+  bool in_place = (x.data() == out->data());
+
+  OpCacheOperator op_info;
+  op_info.prepareOpInfo<T, nullptr_t>(
+      in_place ? "LogicalAndKernel_" : "LogicalAndKernel",
+      inputs_dims,
+      nullptr);
   auto recipe = op_info.GetRecipe();

   if (recipe == nullptr) {
     Logical op;

-    op.AddNode(ct, params);
+    op.AddNode(ct, params, in_place);
     op.Compile();
     op_info.setOp(op);

@@ -181,7 +186,6 @@ void LogicalXorKernel(const Context& dev_ctx,
                       const phi::DenseTensor& y,
                       phi::DenseTensor* out) {
   dev_ctx.template Alloc<bool>(out);
-  OpCacheOperator op_info;
   ConvertTensors ct;
   ct.Add(x);
   ct.Add(y);
@@ -190,15 +194,20 @@ void LogicalXorKernel(const Context& dev_ctx,
   LogicalParams params = {};
   snprintf(params.op, MAX_OPNAME_LEN, "%s", "xor");
   std::vector<DIMS> inputs_dims = ct.GetDims();
-  std::string op_name =
-      (x.data() == out->data()) ? "_LogicalXorKernel" : "LogicalXorKernel";
-  op_info.prepareOpInfo<T, nullptr_t>(op_name, inputs_dims, nullptr);
+
+  bool in_place = (x.data() == out->data());
+
+  OpCacheOperator op_info;
+  op_info.prepareOpInfo<T, nullptr_t>(
+      in_place ? "LogicalXorKernel_" : "LogicalXorKernel",
+      inputs_dims,
+      nullptr);
   auto recipe = op_info.GetRecipe();

   if (recipe == nullptr) {
     Logical op;

-    op.AddNode(ct, params);
+    op.AddNode(ct, params, in_place);
     op.Compile();
     op_info.setOp(op);

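Across all three files the convention is now uniform: the kernel decides in_place once by pointer equality, the in-place recipe-cache key carries a trailing underscore (for example "LogicalOrKernel_" vs "LogicalOrKernel"), and the node name passed to synNodeCreate becomes a fixed literal, since aliasing is expressed through the shared section rather than through a "_"-prefixed node name.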