Skip to content

Commit 3fcc27f

Browse files
authored
add_dequantize_log test=develop (#9669)
1 parent 1eea6b2 commit 3fcc27f

File tree

8 files changed

+215
-3
lines changed

8 files changed

+215
-3
lines changed

lite/kernels/arm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ add_kernel(scatter_compute_arm ARM extra SRCS scatter_compute.cc)
7575
add_kernel(sequence_expand_as_compute_arm ARM extra SRCS sequence_expand_as_compute.cc)
7676
add_kernel(matmul_v2_compute ARM extra SRCS matmul_v2_compute.cc)
7777
add_kernel(sum_compute ARM extra SRCS sum_compute.cc)
78-
78+
add_kernel(dequantize_log_compute ARM extra SRCS dequantize_log_compute.cc)
7979

8080
# for OCR specific
8181
add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/arm/dequantize_log_compute.h"
16+
#include <set>
17+
#include <string>
18+
#include <vector>
19+
20+
namespace paddle {
21+
namespace lite {
22+
namespace kernels {
23+
namespace arm {
24+
25+
template <typename T>
26+
void DequantizeLogCompute<T>::Run() {
27+
auto& param = Param<operators::QuantizeLogParam>();
28+
auto x = param.X;
29+
auto dict = param.Dict;
30+
auto output = param.Out;
31+
const float* dict_data = dict->template data<float>();
32+
const T* input_data = x->template data<T>();
33+
float* output_data = output->template mutable_data<float>();
34+
int ind = x->numel();
35+
for (size_t i = 0; i < (unsigned)ind; i++) {
36+
if (input_data[i] < 0) {
37+
output_data[i] = -dict_data[input_data[i] + 128];
38+
} else {
39+
output_data[i] = dict_data[input_data[i]];
40+
}
41+
}
42+
}
43+
44+
} // namespace arm
45+
} // namespace kernels
46+
} // namespace lite
47+
} // namespace paddle
48+
49+
REGISTER_LITE_KERNEL(dequantize_log,
50+
kARM,
51+
kInt8,
52+
kNCHW,
53+
paddle::lite::kernels::arm::DequantizeLogCompute<int8_t>,
54+
def)
55+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
56+
.BindInput("Dict", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
57+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
58+
.BindPaddleOpVersion("dequantize_log", 1)
59+
.Finalize();
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <stdint.h>
17+
#include "lite/core/kernel.h"
18+
#include "lite/core/op_registry.h"
19+
20+
namespace paddle {
21+
namespace lite {
22+
namespace kernels {
23+
namespace arm {
24+
25+
template <typename T>
26+
class DequantizeLogCompute : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
27+
public:
28+
void Run() override;
29+
30+
virtual ~DequantizeLogCompute() = default;
31+
32+
private:
33+
};
34+
35+
} // namespace arm
36+
} // namespace kernels
37+
} // namespace lite
38+
} // namespace paddle

lite/kernels/arm/lookup_table_compute.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ void LookupTableCompute<T_W, T_IDS>::Run() {
7070
memcpy(dout + i * row_width, table_data, row_width * sizeof(float));
7171
}
7272
#else
73-
auto table_data = w->template data<float>();
73+
auto table_data = w->template data<T_W>();
7474
memcpy(dout + i * row_width,
7575
table_data + ids_int * row_width,
76-
row_width * sizeof(float));
76+
row_width * sizeof(T_W));
7777
#endif
7878
}
7979
}
@@ -87,6 +87,8 @@ void LookupTableCompute<T_W, T_IDS>::Run() {
8787

8888
using LookupTableFloatInt64 =
8989
paddle::lite::kernels::arm::LookupTableCompute<float, int64_t>;
90+
using LookupTableInt8Int64 =
91+
paddle::lite::kernels::arm::LookupTableCompute<int8_t, int64_t>;
9092
using LookupTableFloatInt32 =
9193
paddle::lite::kernels::arm::LookupTableCompute<float, int32_t>;
9294

@@ -105,6 +107,14 @@ REGISTER_LITE_KERNEL(
105107
.BindPaddleOpVersion("lookup_table_v2", 1)
106108
.Finalize();
107109

110+
REGISTER_LITE_KERNEL(
111+
lookup_table_v2, kARM, kAny, kNCHW, LookupTableInt8Int64, int8_int64)
112+
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
113+
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
114+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
115+
.BindPaddleOpVersion("lookup_table_v2", 1)
116+
.Finalize();
117+
108118
REGISTER_LITE_KERNEL(
109119
lookup_table, kARM, kAny, kNCHW, LookupTableFloatInt32, float_int32)
110120
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})

lite/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ add_operator(fake_channel_wise_quantize_dequantize_abs_max_op extra SRCS fake_ch
108108
add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc)
109109
add_operator(quantize_linear_op extra SRCS quantize_linear_op.cc)
110110
add_operator(dequantize_linear_op extra SRCS dequantize_linear_op.cc)
111+
add_operator(dequantize_log_op extra SRCS dequantize_log_op.cc)
111112
add_operator(split_lod_tensor_op_lite extra SRCS split_lod_tensor_op.cc)
112113
add_operator(merge_lod_tensor_op_lite extra SRCS merge_lod_tensor_op.cc)
113114
add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc)

lite/operators/dequantize_log_op.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/operators/dequantize_log_op.h"
16+
#include "lite/core/op_registry.h"
17+
18+
namespace paddle {
19+
namespace lite {
20+
namespace operators {
21+
22+
bool DequantizeLogOpLite::CheckShape() const {
23+
CHECK_OR_FALSE(param_.X);
24+
CHECK_OR_FALSE(param_.Out);
25+
return true;
26+
}
27+
28+
bool DequantizeLogOpLite::InferShape() {
29+
lite::DDim x_dims = param_.X->dims();
30+
param_.Out->Resize(x_dims);
31+
32+
return true;
33+
}
34+
} // namespace operators
35+
} // namespace lite
36+
} // namespace paddle
37+
38+
REGISTER_LITE_OP(dequantize_log, paddle::lite::operators::DequantizeLogOpLite);

lite/operators/dequantize_log_op.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include <vector>
19+
#include "lite/core/kernel.h"
20+
#include "lite/core/op_lite.h"
21+
#include "lite/core/scope.h"
22+
#include "lite/core/tensor.h"
23+
#include "lite/operators/op_params.h"
24+
#include "lite/utils/all.h"
25+
26+
namespace paddle {
27+
namespace lite {
28+
namespace operators {
29+
30+
class DequantizeLogOpLite : public OpLite {
31+
public:
32+
DequantizeLogOpLite() {}
33+
34+
explicit DequantizeLogOpLite(const std::string &type) : OpLite(type) {}
35+
36+
bool CheckShape() const override;
37+
38+
bool InferShape() override;
39+
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
40+
auto x = op_desc.Input("X").front();
41+
auto dict = op_desc.Input("Dict").front();
42+
auto out = op_desc.Output("Out").front();
43+
44+
param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
45+
param_.Dict = scope->FindVar(dict)->GetMutable<lite::Tensor>();
46+
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
47+
return true;
48+
}
49+
50+
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
51+
52+
std::string DebugString() const override { return "dequantize_log"; }
53+
54+
private:
55+
mutable QuantizeLogParam param_;
56+
};
57+
58+
} // namespace operators
59+
} // namespace lite
60+
} // namespace paddle

lite/operators/op_params.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,12 @@ struct QuantizeLinearParam : ParamBase {
667667
int bit_length;
668668
};
669669

670+
struct QuantizeLogParam : ParamBase {
671+
const lite::Tensor* X{};
672+
const lite::Tensor* Dict{};
673+
lite::Tensor* Out{};
674+
};
675+
670676
/// ----------------------- sgd operators ----------------------
671677
struct SGDParam : ParamBase {
672678
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};

0 commit comments

Comments
 (0)