
Commit 6135491

[HOST] add tril_triu; fix expand_as (#5507) (#5720)
Co-authored-by: zhupengyang <[email protected]>
1 parent 574883b commit 6135491

12 files changed: +424 -30 lines changed

lite/kernels/host/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ add_kernel(topk_v2_compute_host Host extra SRCS topk_v2_compute.cc DEPS ${lite_k
 add_kernel(fill_any_like_compute_host Host extra SRCS fill_any_like_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(meshgrid_compute_host Host extra SRCS meshgrid_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(linspace_compute_host Host extra SRCS linspace_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(tril_triu_compute_host Host extra SRCS tril_triu_compute.cc DEPS ${lite_kernel_deps})
 
 if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
   lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)

lite/kernels/host/expand_as_compute.cc

Lines changed: 19 additions & 3 deletions
@@ -30,7 +30,6 @@ void ExpandAsCompute<T, PType>::Run() {
   const T* src = x->template data<T>();
   T* dst = out->template mutable_data<T>();
 
-  // int dims = expand_times.size();
   for (int i = 0; i < target->dims().size(); ++i) {
     int times = target->dims()[i] / x->dims()[i];
     expand_times.push_back(times);
@@ -75,12 +74,29 @@ REGISTER_LITE_KERNEL(expand_as, kHost, kFloat, kAny, expand_as_float, def)
                {LiteType::GetTensorTy(TARGET(kHost),
                                       PRECISION(kFloat),
                                       DATALAYOUT(kAny))})
-    .BindInput("Target",
+    .BindInput("target_tensor",
                {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kFloat),
+                                      PRECISION(kAny),
                                       DATALAYOUT(kAny))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kHost),
                                        PRECISION(kFloat),
                                        DATALAYOUT(kAny))})
     .Finalize();
+
+using expand_as_int64 =
+    paddle::lite::kernels::host::ExpandAsCompute<int64_t, PRECISION(kFloat)>;
+REGISTER_LITE_KERNEL(expand_as, kHost, kFloat, kAny, expand_as_int64, int64)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kInt64),
+                                      DATALAYOUT(kAny))})
+    .BindInput("target_tensor",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kInt64),
+                                       DATALAYOUT(kAny))})
+    .Finalize();
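
For context: Run() derives each expand factor as target->dims()[i] / x->dims()[i], so the fix above only touches the registration side (the target input is now bound as "target_tensor" with kAny precision, and an int64 kernel variant is added); the compute path is unchanged. Below is a minimal standalone sketch of that factor derivation using plain std::vector shapes rather than lite::Tensor; the helper name ExpandTimesFromShapes is made up for illustration.

// Sketch only (not the Lite API): derive per-dimension expand factors the same
// way ExpandAsCompute::Run() does, i.e. times[i] = target_dim[i] / x_dim[i].
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper; the real kernel reads the dims from lite::Tensor objects.
std::vector<int> ExpandTimesFromShapes(const std::vector<int64_t>& x_dims,
                                       const std::vector<int64_t>& target_dims) {
  assert(x_dims.size() == target_dims.size());
  std::vector<int> expand_times;
  for (size_t i = 0; i < target_dims.size(); ++i) {
    expand_times.push_back(static_cast<int>(target_dims[i] / x_dims[i]));
  }
  return expand_times;
}

int main() {
  // x of shape [1, 3] expanded as a target of shape [4, 3] -> factors {4, 1}.
  for (int t : ExpandTimesFromShapes({1, 3}, {4, 3})) std::cout << t << " ";
  std::cout << "\n";
  return 0;
}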

lite/kernels/host/tril_triu_compute.cc

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/host/tril_triu_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+template <class T>
+void TrilTriu(const T* in,
+              const int64_t diagonal,
+              const bool lower,
+              const int64_t h,
+              const int64_t w,
+              T* out) {
+  int64_t size = h * w;
+  for (int64_t idx = 0; idx < size; idx++) {
+    const int64_t row = idx / w;
+    const int64_t col = idx % w;
+    const bool mask = lower ? (col - row > diagonal) : (col - row < diagonal);
+    out[idx] = mask ? 0 : in[idx];
+  }
+  return;
+}
+
+template <class T>
+void TrilTriuCompute<T>::Run() {
+  auto& param = this->template Param<param_t>();
+  const lite::Tensor* x = param.x;
+  lite::Tensor* out = param.out;
+  int64_t diagonal = param.diagonal;
+  bool lower = param.lower;
+
+  const T* x_data = x->template data<T>();
+  T* out_data = out->template mutable_data<T>();
+  auto x_dims = x->dims();
+  int64_t h = x_dims[x_dims.size() - 2];
+  int64_t w = x_dims[x_dims.size() - 1];
+  int64_t n = x_dims.production() / h / w;
+
+  for (int64_t i = 0; i < n; i++) {
+    TrilTriu(x_data, diagonal, lower, h, w, out_data);
+    x_data += h * w;
+    out_data += h * w;
+  }
+  return;
+}
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+using TrilTriuFloat32 = paddle::lite::kernels::host::TrilTriuCompute<float>;
+REGISTER_LITE_KERNEL(tril_triu, kHost, kAny, kNCHW, TrilTriuFloat32, float32)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kFloat))})
+    .Finalize();
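
The kernel above treats the input as n slices of shape h x w (n = production / h / w) and, within each flattened slice, zeroes element (row, col) when col - row > diagonal for tril (lower == true) or col - row < diagonal for triu. A standalone sketch of that rule on a single 3 x 3 slice, using plain arrays instead of lite::Tensor (TrilTriuRef is a hypothetical name for this illustration):

// Standalone sketch of the TrilTriu masking rule (plain C++, no Lite types).
#include <cstdint>
#include <iostream>
#include <vector>

template <class T>
void TrilTriuRef(const T* in, int64_t diagonal, bool lower,
                 int64_t h, int64_t w, T* out) {
  for (int64_t idx = 0; idx < h * w; idx++) {
    const int64_t row = idx / w;
    const int64_t col = idx % w;
    // Zero out above the chosen diagonal for tril, below it for triu.
    const bool zero_out =
        lower ? (col - row > diagonal) : (col - row < diagonal);
    out[idx] = zero_out ? static_cast<T>(0) : in[idx];
  }
}

int main() {
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8, 9};  // 3x3, row-major
  std::vector<float> y(9);

  TrilTriuRef(x.data(), /*diagonal=*/0, /*lower=*/true, 3, 3, y.data());
  // tril, diagonal=0 keeps the main diagonal and below: 1 0 0 / 4 5 0 / 7 8 9
  for (float v : y) std::cout << v << " ";
  std::cout << "\n";

  TrilTriuRef(x.data(), /*diagonal=*/0, /*lower=*/false, 3, 3, y.data());
  // triu, diagonal=0 keeps the main diagonal and above: 1 2 3 / 0 5 6 / 0 0 9
  for (float v : y) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}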

lite/kernels/host/tril_triu_compute.h

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+template <class T>
+class TrilTriuCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
+ public:
+  using param_t = operators::TrilTriuParam;
+
+  void Run() override;
+
+  virtual ~TrilTriuCompute() = default;
+};
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle

lite/operators/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -143,9 +143,9 @@ add_operator(select_input_op extra SRCS select_input_op.cc DEPS ${op_DEPS})
 add_operator(tensor_array_to_tensor_op extra SRCS tensor_array_to_tensor_op.cc DEPS ${op_DEPS})
 add_operator(expand_v2_op_lite extra SRCS expand_v2_op.cc DEPS ${op_DEPS})
 add_operator(tile_op extra SRCS tile_op.cc DEPS ${op_DEPS})
-add_operator(sum_op extra SRCS sum_op.cc DEPS ${op_DEPS})
 add_operator(meshgrid_op_lite extra SRCS meshgrid_op.cc DEPS ${op_DEPS})
 add_operator(linspace_op extra SRCS linspace_op.cc DEPS ${op_DEPS})
+add_operator(tril_triu_op extra SRCS tril_triu_op.cc DEPS ${op_DEPS})
 
 # for OCR specific
 add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})

lite/operators/expand_as_op.cc

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ bool ExpandAsOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
   auto Out_name = opdesc.Output("Out").front();
   param_.X = GetVar<lite::Tensor>(scope, X_name);
   param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
-  auto Target_name = opdesc.Input("Target").front();
+  auto Target_name = opdesc.Input("target_tensor").front();
   param_.Target = GetVar<lite::Tensor>(scope, Target_name);
   return true;
 }

lite/operators/op_params.h

Lines changed: 8 additions & 0 deletions
@@ -646,6 +646,14 @@ struct TransposeParam : ParamBase {
   }
 };
 
+struct TrilTriuParam : ParamBase {
+  const lite::Tensor* x{nullptr};
+  lite::Tensor* out{nullptr};
+
+  int diagonal{0};
+  bool lower{true};
+};
+
 /// ----------------------- element wise operators ----------------------
 struct ElementwiseParam : ParamBase {
   const lite::Tensor* X{};

lite/operators/tril_triu_op.cc

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/tril_triu_op.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool TrilTriuOp::CheckShape() const {
+  CHECK(param_.x);
+  CHECK(param_.out);
+  return true;
+}
+
+bool TrilTriuOp::InferShapeImpl() const {
+  CHECK_GE(param_.x->dims().size(), 2UL);
+  param_.out->Resize(param_.x->dims());
+  param_.out->set_lod(param_.x->lod());
+  return true;
+}
+
+bool TrilTriuOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+  param_.x = scope->FindTensor(op_desc.Input("X").front());
+  param_.out = scope->FindMutableTensor(op_desc.Output("Out").front());
+
+  param_.diagonal = op_desc.GetAttr<int>("diagonal");
+  param_.lower = op_desc.GetAttr<bool>("lower");
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(tril_triu, paddle::lite::operators::TrilTriuOp);

lite/operators/tril_triu_op.h

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class TrilTriuOp : public OpLite {
+ public:
+  TrilTriuOp() {}
+  explicit TrilTriuOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "tril_triu"; }
+
+ private:
+  mutable TrilTriuParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle

lite/tests/kernels/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@ if(LITE_BUILD_EXTRA)
   lite_cc_test(test_kernel_sequence_expand_as_compute SRCS sequence_expand_as_compute_test.cc DEPS ${test_kernel_deps})
   lite_cc_test(test_kernel_sin_compute SRCS sin_compute_test.cc DEPS arena_framework ${test_kernel_deps})
   lite_cc_test(test_kernel_cos_compute SRCS cos_compute_test.cc DEPS arena_framework ${test_kernel_deps})
+  lite_cc_test(test_kernel_tril_triu_compute SRCS tril_triu_compute_test.cc DEPS arena_framework ${test_kernel_deps})
   lite_cc_test(test_kernel_pad3d_compute SRCS pad3d_compute_test.cc DEPS arena_framework ${test_kernel_deps})
   lite_cc_test(test_kernel_select_input_compute SRCS select_input_compute_test.cc DEPS arena_framework ${test_kernel_deps})
   # lite_cc_test(test_kernel_tensor_array_to_tensor_compute SRCS tensor_array_to_tensor_compute_test.cc DEPS arena_framework ${test_kernel_deps})
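
The new test target references tril_triu_compute_test.cc, which is not part of this diff. Purely as an illustration of what such a test could compare against (not the actual test code), here is a naive batched reference that applies the same keep/zero rule with explicit batch/row/column loops; the name TrilTriuNaive and the example shapes are made up for this sketch.

// Illustrative reference only: zero entries of every h x w slice according to
// (diagonal, lower), suitable for element-wise comparison with kernel output.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> TrilTriuNaive(const std::vector<float>& x,
                                 int64_t n, int64_t h, int64_t w,
                                 int64_t diagonal, bool lower) {
  assert(static_cast<int64_t>(x.size()) == n * h * w);
  std::vector<float> out(x.size(), 0.f);
  for (int64_t b = 0; b < n; ++b) {
    for (int64_t r = 0; r < h; ++r) {
      for (int64_t c = 0; c < w; ++c) {
        const bool keep = lower ? (c - r <= diagonal) : (c - r >= diagonal);
        const int64_t i = (b * h + r) * w + c;
        if (keep) out[i] = x[i];
      }
    }
  }
  return out;
}

int main() {
  // Two 2x2 slices; tril with diagonal=0 zeroes the (0,1) entry of each slice.
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<float> ref = TrilTriuNaive(x, /*n=*/2, /*h=*/2, /*w=*/2,
                                         /*diagonal=*/0, /*lower=*/true);
  for (float v : ref) std::cout << v << " ";  // 1 0 3 4 5 0 7 8
  std::cout << "\n";
  return 0;
}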
