Skip to content

Commit 574883b

Browse files
authored
[cherry-pick][XPU] add xpu gru kernel (#5687) (#5730)
1 parent b4088cc commit 574883b

File tree

3 files changed

+112
-0
lines changed

3 files changed

+112
-0
lines changed

lite/kernels/xpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
add_kernel(scale_compute_xpu XPU basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
add_kernel(dropout_compute_xpu XPU basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps})
add_kernel(matmul_compute_xpu XPU basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps})
# GRU sequence kernel for XPU (implemented via xdnn::gru_unit_int16).
add_kernel(gru_compute_xpu XPU basic SRCS gru_compute.cc DEPS ${lite_kernel_deps})
add_kernel(stack_compute_xpu XPU basic SRCS stack_compute.cc DEPS ${lite_kernel_deps})
add_kernel(slice_compute_xpu XPU basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(cast_compute_xpu XPU basic SRCS cast_compute.cc DEPS ${lite_kernel_deps})

lite/kernels/xpu/gru_compute.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/kernels/xpu/gru_compute.h"

// Register the float/NCHW GRU kernel for the XPU target under the
// variant name "def". All bound tensors are declared to live in XPU
// device memory (TARGET(kXPU)), so the framework inserts host<->device
// copies as needed around this kernel.
REGISTER_LITE_KERNEL(
    gru, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::GRUCompute, def)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
    // H0 is bound for op-signature completeness; the kernel itself
    // requires it to be null (see GRUCompute::Run).
    .BindInput("H0", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
    // BatchGate/BatchResetHiddenPrev/BatchHidden are bound but their
    // LoD is not populated by this kernel (only Hidden's is).
    .BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kXPU))})
    .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kXPU))})
    .Finalize();

lite/kernels/xpu/gru_compute.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
// or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#pragma once
17+
18+
#include "lite/backends/xpu/xpu_header_sitter.h"
19+
#include "lite/core/kernel.h"
20+
#include "lite/core/op_registry.h"
21+
22+
namespace paddle {
23+
namespace lite {
24+
namespace kernels {
25+
namespace xpu {
26+
27+
// GRU kernel for the XPU target (float precision).
//
// Processes a LoD-batched input by running each sequence through XPU's
// fused int16 GRU primitive (xdnn::gru_unit_int16), one sequence at a
// time, advancing the input/output pointers by each sequence's length.
class GRUCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
  // Quantization abs-max scales for the two weight sub-matrices
  // ([D, 2D] update/reset gates and [D, D] candidate). -1 signals
  // "unknown", letting xdnn derive them itself on first use.
  float weight_s1_abs_max = -1;
  float weight_s2_abs_max = -1;

 public:
  void Run() override {
    auto& ctx = this->ctx_->As<XPUContext>();
    auto& param = *param_.get_mutable<operators::GRUParam>();

    bool origin_mode = param.origin_mode;
    bool is_reverse = param.is_reverse;

    auto* input = param.input;
    const float* input_data = input->data<float>();
    auto* h0 = param.h0;
    // This implementation always starts from a zero initial hidden
    // state; an explicit H0 tensor is not supported on XPU.
    // (Was CHECK_EQ with C-style (void*) casts — same failure behavior.)
    CHECK(h0 == nullptr) << "h0 should be nullptr for XPU";

    auto* weight = param.weight;
    const float* weight_data = weight->data<float>();
    auto* bias = param.bias;
    const float* bias_data = bias->data<float>();

    auto* hidden = param.hidden;
    float* hidden_ptr = hidden->mutable_data<float>(TARGET(kXPU));
    const auto& hidden_dims = hidden->dims();
    int frame_size = hidden_dims[1];  // hidden width D

    // Level-0 LoD delimits the packed sequences: sequence i occupies
    // rows [input_lod[i], input_lod[i + 1]).
    auto& input_lod = input->lod()[0];
    int batch_size = input_lod.size() - 1;
    for (int i = 0; i < batch_size; i++) {
      int cur_seq_len = input_lod[i + 1] - input_lod[i];
      // const_cast is required because the xdnn C API takes non-const
      // float*; the buffers are not modified conceptually as inputs.
      int ret = xdnn::gru_unit_int16(
          ctx.GetRawContext(),                // Context *ctx,
          cur_seq_len,                        // int seq_len,
          frame_size,                         // int frame_size,
          is_reverse,                         // bool is_reverse,
          origin_mode,                        // bool origin_mode,
          const_cast<float*>(input_data),     // float *input, [seq_len, 3D]
          const_cast<float*>(weight_data),    // float *weight, [D, 3D]
          weight_s1_abs_max,                  // float& weight_s1_abs_max
          weight_s2_abs_max,                  // float& weight_s2_abs_max
          const_cast<float*>(bias_data),      // float *bias, [1, 3D]
          hidden_ptr);                        // float *hidden, [seq_len, D]
      CHECK_EQ(ret, 0) << "call xdnn::gru_unit_int16 failed!";
      // Advance past this sequence: input carries 3 gates per frame.
      input_data += cur_seq_len * 3 * frame_size;
      hidden_ptr += cur_seq_len * frame_size;
    }
    // batch_gate, batch_reset_hidden_prev lod not set
    hidden->set_lod(input->lod());
  }

  virtual ~GRUCompute() = default;
};
80+
81+
} // namespace xpu
82+
} // namespace kernels
83+
} // namespace lite
84+
} // namespace paddle

0 commit comments

Comments
 (0)