Skip to content

Commit b4088cc

Browse files
[cherry-pick-28] add linspace op,test=develop (#5601) (#5717)
1 parent 30c5ab3 commit b4088cc

File tree

7 files changed

+272
-0
lines changed

7 files changed

+272
-0
lines changed

lite/kernels/host/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ add_kernel(tile_compute_host Host extra SRCS tile_compute.cc DEPS ${lite_kernel_
6262
add_kernel(topk_v2_compute_host Host extra SRCS topk_v2_compute.cc DEPS ${lite_kernel_deps})
6363
add_kernel(fill_any_like_compute_host Host extra SRCS fill_any_like_compute.cc DEPS ${lite_kernel_deps})
6464
add_kernel(meshgrid_compute_host Host extra SRCS meshgrid_compute.cc DEPS ${lite_kernel_deps})
65+
add_kernel(linspace_compute_host Host extra SRCS linspace_compute.cc DEPS ${lite_kernel_deps})
6566

6667
if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
6768
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)

lite/kernels/host/linspace_compute.cc

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/host/linspace_compute.h"
16+
#include <vector>
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace kernels {
21+
namespace host {
22+
23+
template <typename Tin, typename Tout>
24+
static void LinspaceFunc(const operators::LinspaceParam& param) {
25+
const auto* start_tensor = param.Start;
26+
const auto* stop_tensor = param.Stop;
27+
const auto* num_tensor = param.Num;
28+
auto* out_tensor = param.Out;
29+
const Tout start = static_cast<Tout>(start_tensor->template data<Tin>()[0]);
30+
const Tout stop = static_cast<Tout>(stop_tensor->template data<Tin>()[0]);
31+
const int num = num_tensor->data<int>()[0];
32+
Tout* out_data = out_tensor->template mutable_data<Tout>();
33+
34+
if (num > 1) {
35+
// step should be of double type for all types
36+
double step = (static_cast<double>(stop - start)) / (num - 1);
37+
int half_num = num / 2;
38+
for (int i = 0; i < num; ++i) {
39+
if (i < half_num) {
40+
out_data[i] = static_cast<Tout>(start + step * i);
41+
} else {
42+
out_data[i] = static_cast<Tout>(stop - step * (num - i - 1));
43+
}
44+
}
45+
} else {
46+
out_data[0] = static_cast<Tout>(start);
47+
}
48+
}
49+
50+
template <typename T, PrecisionType PType>
51+
void LinspaceCompute<T, PType>::Run() {
52+
auto& param = this->template Param<operators::LinspaceParam>();
53+
switch (param.Out->precision()) {
54+
case PRECISION(kFloat):
55+
LinspaceFunc<T, float>(param);
56+
break;
57+
case PRECISION(kInt32):
58+
LinspaceFunc<T, int32_t>(param);
59+
break;
60+
default:
61+
LOG(FATAL) << "Linspace op unsupport output data type: "
62+
<< lite_api::PrecisionToStr(param.Out->precision());
63+
}
64+
return;
65+
}
66+
} // namespace host
67+
} // namespace kernels
68+
} // namespace lite
69+
} // namespace paddle
70+
71+
using linspace_float =
72+
paddle::lite::kernels::host::LinspaceCompute<float, PRECISION(kFloat)>;
73+
REGISTER_LITE_KERNEL(linspace, kHost, kFloat, kAny, linspace_float, float32)
74+
.BindInput("Start",
75+
{LiteType::GetTensorTy(TARGET(kHost),
76+
PRECISION(kFloat),
77+
DATALAYOUT(kAny))})
78+
.BindInput("Stop",
79+
{LiteType::GetTensorTy(TARGET(kHost),
80+
PRECISION(kFloat),
81+
DATALAYOUT(kAny))})
82+
.BindInput("Num",
83+
{LiteType::GetTensorTy(TARGET(kHost),
84+
PRECISION(kInt32),
85+
DATALAYOUT(kAny))})
86+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
87+
.Finalize();
88+
89+
using linspace_int32 =
90+
paddle::lite::kernels::host::LinspaceCompute<int, PRECISION(kInt32)>;
91+
REGISTER_LITE_KERNEL(linspace, kHost, kInt32, kAny, linspace_int32, int32)
92+
.BindInput("Start",
93+
{LiteType::GetTensorTy(TARGET(kHost),
94+
PRECISION(kInt32),
95+
DATALAYOUT(kAny))})
96+
.BindInput("Stop",
97+
{LiteType::GetTensorTy(TARGET(kHost),
98+
PRECISION(kInt32),
99+
DATALAYOUT(kAny))})
100+
.BindInput("Num",
101+
{LiteType::GetTensorTy(TARGET(kHost),
102+
PRECISION(kInt32),
103+
DATALAYOUT(kAny))})
104+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
105+
.Finalize();

lite/kernels/host/linspace_compute.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include "lite/core/kernel.h"
17+
#include "lite/core/op_registry.h"
18+
19+
namespace paddle {
20+
namespace lite {
21+
namespace kernels {
22+
namespace host {
23+
24+
template <typename T, PrecisionType PType>
25+
class LinspaceCompute
26+
: public KernelLite<TARGET(kHost), PType, DATALAYOUT(kAny)> {
27+
public:
28+
void Run() override;
29+
30+
virtual ~LinspaceCompute() = default;
31+
};
32+
33+
} // namespace host
34+
} // namespace kernels
35+
} // namespace lite
36+
} // namespace paddle

lite/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ add_operator(expand_v2_op_lite extra SRCS expand_v2_op.cc DEPS ${op_DEPS})
145145
add_operator(tile_op extra SRCS tile_op.cc DEPS ${op_DEPS})
146146
add_operator(sum_op extra SRCS sum_op.cc DEPS ${op_DEPS})
147147
add_operator(meshgrid_op_lite extra SRCS meshgrid_op.cc DEPS ${op_DEPS})
148+
add_operator(linspace_op extra SRCS linspace_op.cc DEPS ${op_DEPS})
148149

149150
# for OCR specific
150151
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})

lite/operators/linspace_op.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/operators/linspace_op.h"
16+
#include "lite/core/op_registry.h"
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace operators {
21+
22+
bool LinspaceOpLite::CheckShape() const {
23+
CHECK_OR_FALSE(param_.Start);
24+
CHECK_OR_FALSE(param_.Stop);
25+
CHECK_OR_FALSE(param_.Num);
26+
CHECK_OR_FALSE(param_.Out);
27+
28+
int start_dims_size = param_.Start->dims().size();
29+
CHECK_EQ(start_dims_size, 1) << "The shape of input start must be 1.";
30+
int stop_dims_size = param_.Stop->dims().size();
31+
CHECK_EQ(stop_dims_size, 1) << "The shape of input stop must be 1.";
32+
int num_dims_size = param_.Num->dims().size();
33+
CHECK_EQ(num_dims_size, 1) << "The shape of input num must be 1.";
34+
35+
return true;
36+
}
37+
38+
bool LinspaceOpLite::InferShapeImpl() const {
39+
// param_.dtype(int) is defined in paddle/fluid/framework/framework.proto
40+
// param_.dtype(int) means output dtype and lite supports fp32/int32.
41+
// if param_.dtype is not defined, output dtype is fp32.
42+
switch (param_.dtype) {
43+
case 2:
44+
param_.Out->set_precision(PRECISION(kInt32));
45+
break;
46+
case 5:
47+
param_.Out->set_precision(PRECISION(kFloat));
48+
break;
49+
default:
50+
param_.Out->set_precision(PRECISION(kFloat));
51+
break;
52+
}
53+
param_.Out->Resize(param_.Num->dims());
54+
return true;
55+
}
56+
57+
bool LinspaceOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
58+
auto start_name = opdesc.Input("Start").front();
59+
auto stop_name = opdesc.Input("Stop").front();
60+
auto num_name = opdesc.Input("Num").front();
61+
auto Out_name = opdesc.Output("Out").front();
62+
param_.Start = GetVar<lite::Tensor>(scope, start_name);
63+
param_.Stop = GetVar<lite::Tensor>(scope, stop_name);
64+
param_.Num = GetVar<lite::Tensor>(scope, num_name);
65+
param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
66+
67+
if (opdesc.HasAttr("dtype")) {
68+
param_.dtype = opdesc.GetAttr<int>("dtype");
69+
}
70+
return true;
71+
}
72+
73+
} // namespace operators
74+
} // namespace lite
75+
} // namespace paddle
76+
77+
REGISTER_LITE_OP(linspace, paddle::lite::operators::LinspaceOpLite);

lite/operators/linspace_op.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <string>
17+
#include <vector>
18+
#include "lite/core/op_lite.h"
19+
20+
namespace paddle {
21+
namespace lite {
22+
namespace operators {
23+
24+
class LinspaceOpLite : public OpLite {
25+
public:
26+
LinspaceOpLite() {}
27+
explicit LinspaceOpLite(const std::string &op_type) : OpLite(op_type) {}
28+
29+
bool CheckShape() const override;
30+
31+
bool InferShapeImpl() const override;
32+
33+
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
34+
35+
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
36+
std::string DebugString() const override { return "linspace"; }
37+
38+
private:
39+
mutable LinspaceParam param_;
40+
};
41+
42+
} // namespace operators
43+
} // namespace lite
44+
} // namespace paddle

lite/operators/op_params.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2254,6 +2254,14 @@ struct PNormParam : ParamBase {
22542254
bool keepdim{false};
22552255
bool asvector{false};
22562256
};
2257+
2258+
struct LinspaceParam : ParamBase {
2259+
const lite::Tensor* Start{};
2260+
const lite::Tensor* Stop{};
2261+
const lite::Tensor* Num{};
2262+
lite::Tensor* Out{};
2263+
int dtype{};
2264+
};
22572265
} // namespace operators
22582266
} // namespace lite
22592267
} // namespace paddle

0 commit comments

Comments
 (0)