Skip to content

Commit 30c5ab3

Browse files
[cherry-pick-28] add meshgrid host kernel,test=develop (#5590) (#5712)
1 parent 0c49dc6 commit 30c5ab3

File tree

7 files changed

+269
-1
lines changed

7 files changed

+269
-1
lines changed

lite/kernels/host/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ add_kernel(strided_slice_compute_host Host extra SRCS strided_slice_compute.cc D
6161
add_kernel(tile_compute_host Host extra SRCS tile_compute.cc DEPS ${lite_kernel_deps})
6262
add_kernel(topk_v2_compute_host Host extra SRCS topk_v2_compute.cc DEPS ${lite_kernel_deps})
6363
add_kernel(fill_any_like_compute_host Host extra SRCS fill_any_like_compute.cc DEPS ${lite_kernel_deps})
64-
64+
add_kernel(meshgrid_compute_host Host extra SRCS meshgrid_compute.cc DEPS ${lite_kernel_deps})
6565

6666
if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
6767
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)

lite/kernels/host/meshgrid_compute.cc

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/host/meshgrid_compute.h"
16+
#include <vector>
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace kernels {
21+
namespace host {
22+
23+
template <typename T, PrecisionType PType>
24+
void MeshgridCompute<T, PType>::Run() {
25+
auto& param = this->template Param<operators::MeshgridParam>();
26+
std::vector<lite::Tensor*>& ins = param.X;
27+
std::vector<lite::Tensor*>& outs = param.Out;
28+
int64_t size = ins.size();
29+
std::vector<int64_t> shape(size);
30+
for (int64_t i = 0; i < size; ++i) {
31+
switch (ins[i]->dims().size()) {
32+
case 0:
33+
shape[i] = 1;
34+
break;
35+
case 1:
36+
shape[i] = ins[i]->dims()[0];
37+
break;
38+
default:
39+
LOG(FATAL) << "Meshgrid Op expected scalar or 1D tensor in the input "
40+
"tensor list";
41+
break;
42+
}
43+
}
44+
45+
DDim out_dims;
46+
out_dims.ConstructFrom(shape);
47+
48+
for (int64_t i = 0; i < size; ++i) {
49+
T* dst = outs[i]->template mutable_data<T>();
50+
outs[i]->Resize(out_dims);
51+
Tensor reshape_ins_tensor;
52+
reshape_ins_tensor.ShareDataWith(*ins[i]);
53+
std::vector<int64_t> view_shape(size, 1);
54+
view_shape[i] = shape[i];
55+
DDim in_dims_reshape;
56+
in_dims_reshape.ConstructFrom(view_shape);
57+
reshape_ins_tensor.Resize(in_dims_reshape);
58+
const T* src = reshape_ins_tensor.data<T>();
59+
std::vector<int> bcast_dims(size);
60+
for (int64_t j = 0; j < size; j++) {
61+
bcast_dims[j] = shape[j];
62+
}
63+
bcast_dims[i] = 1;
64+
int inner_num = 1;
65+
int idx = size - 1;
66+
int outer_num = in_dims_reshape.count(0, idx);
67+
inner_num *= in_dims_reshape[idx];
68+
for (int j = 0; j < outer_num; ++j) {
69+
for (int k = 0; k < bcast_dims[idx]; ++k) {
70+
memcpy(dst + (j * bcast_dims[idx] + k) * inner_num,
71+
src + j * inner_num,
72+
sizeof(T) * inner_num);
73+
}
74+
}
75+
inner_num *= bcast_dims[idx];
76+
for (int idx = size - 2; idx >= 0; --idx) {
77+
int outer_num = in_dims_reshape.count(0, idx);
78+
inner_num *= in_dims_reshape[idx];
79+
for (int j = outer_num - 1; j >= 0; --j) {
80+
for (int k = bcast_dims[idx] - 1; k >= 0; --k) {
81+
memcpy(dst + (j * bcast_dims[idx] + k) * inner_num,
82+
dst + j * inner_num,
83+
sizeof(T) * inner_num);
84+
}
85+
}
86+
inner_num *= bcast_dims[idx];
87+
}
88+
}
89+
}
90+
91+
} // namespace host
92+
} // namespace kernels
93+
} // namespace lite
94+
} // namespace paddle
95+
96+
using meshgrid_float =
97+
paddle::lite::kernels::host::MeshgridCompute<float, PRECISION(kFloat)>;
98+
REGISTER_LITE_KERNEL(meshgrid, kHost, kFloat, kAny, meshgrid_float, float32)
99+
.BindInput("X",
100+
{LiteType::GetTensorTy(TARGET(kHost),
101+
PRECISION(kFloat),
102+
DATALAYOUT(kAny))})
103+
.BindOutput("Out",
104+
{LiteType::GetTensorTy(TARGET(kHost),
105+
PRECISION(kFloat),
106+
DATALAYOUT(kAny))})
107+
.Finalize();
108+
109+
using meshgrid_int32 =
110+
paddle::lite::kernels::host::MeshgridCompute<int, PRECISION(kFloat)>;
111+
REGISTER_LITE_KERNEL(meshgrid, kHost, kFloat, kAny, meshgrid_int32, int32)
112+
.BindInput("X",
113+
{LiteType::GetTensorTy(TARGET(kHost),
114+
PRECISION(kInt32),
115+
DATALAYOUT(kAny))})
116+
.BindOutput("Out",
117+
{LiteType::GetTensorTy(TARGET(kHost),
118+
PRECISION(kInt32),
119+
DATALAYOUT(kAny))})
120+
.Finalize();

lite/kernels/host/meshgrid_compute.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include "lite/core/kernel.h"
17+
#include "lite/core/op_registry.h"
18+
19+
namespace paddle {
20+
namespace lite {
21+
namespace kernels {
22+
namespace host {
23+
24+
template <typename T, PrecisionType PType>
25+
class MeshgridCompute
26+
: public KernelLite<TARGET(kHost), PType, DATALAYOUT(kAny)> {
27+
public:
28+
void Run() override;
29+
30+
virtual ~MeshgridCompute() = default;
31+
};
32+
33+
} // namespace host
34+
} // namespace kernels
35+
} // namespace lite
36+
} // namespace paddle

lite/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ add_operator(tensor_array_to_tensor_op extra SRCS tensor_array_to_tensor_op.cc D
144144
add_operator(expand_v2_op_lite extra SRCS expand_v2_op.cc DEPS ${op_DEPS})
145145
add_operator(tile_op extra SRCS tile_op.cc DEPS ${op_DEPS})
146146
add_operator(sum_op extra SRCS sum_op.cc DEPS ${op_DEPS})
147+
add_operator(meshgrid_op_lite extra SRCS meshgrid_op.cc DEPS ${op_DEPS})
147148

148149
# for OCR specific
149150
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})

lite/operators/meshgrid_op.cc

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/operators/meshgrid_op.h"
16+
#include "lite/core/op_registry.h"
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace operators {
21+
22+
bool MeshgridOpLite::CheckShape() const {
23+
int x_size = param_.X.size();
24+
int out_size = param_.Out.size();
25+
CHECK_GE(x_size, 1) << "Input(X) should not be empty.";
26+
CHECK_GE(out_size, 1) << "Output(Out) should not be empty.";
27+
CHECK_LE(x_size, 6) << "The rank of Input(X) must not be greater than 6.";
28+
return true;
29+
}
30+
31+
bool MeshgridOpLite::InferShapeImpl() const {
32+
int inputs_num = param_.X.size();
33+
int outputs_num = param_.Out.size();
34+
std::vector<int64_t> out_shape(inputs_num);
35+
for (size_t i = 0; i < inputs_num; ++i) {
36+
out_shape[i] = param_.X[i]->dims()[0];
37+
}
38+
for (size_t i = 0; i < outputs_num; ++i) {
39+
param_.Out[i]->Resize(out_shape);
40+
}
41+
return true;
42+
}
43+
44+
bool MeshgridOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
45+
auto input_list = opdesc.Input("X");
46+
param_.X.clear();
47+
for (auto var : input_list) {
48+
param_.X.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
49+
}
50+
auto output_list = opdesc.Output("Out");
51+
param_.Out.clear();
52+
for (auto var : output_list) {
53+
param_.Out.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
54+
}
55+
return true;
56+
}
57+
58+
} // namespace operators
59+
} // namespace lite
60+
} // namespace paddle
61+
62+
REGISTER_LITE_OP(meshgrid, paddle::lite::operators::MeshgridOpLite);

lite/operators/meshgrid_op.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <string>
17+
#include <vector>
18+
#include "lite/core/op_lite.h"
19+
20+
namespace paddle {
21+
namespace lite {
22+
namespace operators {
23+
24+
class MeshgridOpLite : public OpLite {
25+
public:
26+
MeshgridOpLite() {}
27+
explicit MeshgridOpLite(const std::string &op_type) : OpLite(op_type) {}
28+
29+
bool CheckShape() const override;
30+
31+
bool InferShapeImpl() const override;
32+
33+
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
34+
35+
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
36+
std::string DebugString() const override { return "meshgrid"; }
37+
38+
private:
39+
mutable MeshgridParam param_;
40+
};
41+
42+
} // namespace operators
43+
} // namespace lite
44+
} // namespace paddle

lite/operators/op_params.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,11 @@ struct SequenceConcatParam : ParamBase {
12611261
lite::Tensor* Out{};
12621262
};
12631263

1264+
struct MeshgridParam : ParamBase {
1265+
std::vector<lite::Tensor*> X{};
1266+
std::vector<lite::Tensor*> Out{};
1267+
};
1268+
12641269
struct AttentionPaddingMaskParam : ParamBase {
12651270
const lite::Tensor* X{};
12661271
const lite::Tensor* Y{};

0 commit comments

Comments
 (0)