Skip to content

Commit b2e9776

Browse files
authored
[cherry-pick] support bigru on ARM and X86. (#7212)
1 parent 28578c7 commit b2e9776

39 files changed

+3165
-1168
lines changed

lite/api/cxx_api.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,9 @@ void Predictor::ClearTensorArray(
525525
for (size_t var_idx = 0; var_idx < block->VarsSize(); var_idx++) {
526526
const cpp::VarDesc *var = block->GetVar<cpp::VarDesc>(var_idx);
527527
CHECK(var);
528-
if (var->GetType() == lite::VarDataType::LOD_TENSOR_ARRAY) {
528+
529+
auto tmp = program_->exec_scope()->FindVar(var->Name());
530+
if (tmp->IsType<std::vector<Tensor>>()) {
529531
std::vector<Tensor> *tensor_array_var =
530532
program_->exec_scope()->FindMutableTensorList(var->Name());
531533
CHECK(tensor_array_var);

lite/api/light_api.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,9 @@ void LightPredictor::ClearTensorArray(
400400
for (size_t var_idx = 0; var_idx < block->VarsSize(); var_idx++) {
401401
const cpp::VarDesc* var = block->GetVar<cpp::VarDesc>(var_idx);
402402
CHECK(var);
403-
if (var->GetType() == lite::VarDataType::LOD_TENSOR_ARRAY) {
403+
404+
auto tmp = program_->exec_scope()->FindVar(var->Name());
405+
if (tmp->IsType<std::vector<Tensor>>()) {
404406
std::vector<Tensor>* tensor_array_var =
405407
program_->exec_scope()->FindMutableTensorList(var->Name());
406408
CHECK(tensor_array_var);

lite/backends/arm/math/gru.h

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "lite/backends/arm/math/sgemm.h"
18+
#ifdef LITE_WITH_ARM
19+
#include <arm_neon.h>
20+
#endif
21+
22+
namespace paddle {
23+
namespace lite {
24+
namespace arm {
25+
namespace math {
26+
27+
// Pointer bundle describing one GRU step's inputs/outputs.
// All pointers are borrowed (non-owning); the caller manages the buffers.
// Layout convention (established by compute_kernel below): gate_value holds
// 3 * frame_size values per batch item, ordered [reset | update | cell].
template <typename T>
struct RNNGRUValue {
  const T* gate_weight;        // weights for the reset/update gates
  const T* state_weight;       // weights applied to the previous hidden state
  const T* reset_bias;         // bias added to reset_output_value
  T* gate_value;               // in/out: pre-activation gate buffer (3 * frame_size per batch)
  T* reset_output_value;       // in/out: prev_out x state_weight^T, then gated by reset gate
  T* output_value;             // out: new hidden state (frame_size per batch)
  const T* prev_out_value;     // previous hidden state; may be null for the first step
};
37+
38+
template <typename T>
39+
void rnn_activation(const T* din,
40+
T* dout,
41+
int size,
42+
lite_api::ActivationType act_type,
43+
int threads) {
44+
switch (act_type) {
45+
case lite_api::ActivationType::kSigmoid:
46+
act_sigmoid(din, dout, size, threads);
47+
break;
48+
case lite_api::ActivationType::kSigmoid_v2:
49+
act_sigmoid(din, dout, size, threads);
50+
break;
51+
case lite_api::ActivationType::kTanh:
52+
act_tanh(din, dout, size, threads);
53+
break;
54+
case lite_api::ActivationType::kTanh_v2:
55+
act_tanh(din, dout, size, threads);
56+
break;
57+
case lite_api::ActivationType::kRelu:
58+
act_relu(din, dout, size, threads);
59+
break;
60+
default:
61+
LOG(FATAL) << "unsupport activation type:" << static_cast<int>(act_type);
62+
break;
63+
}
64+
}
65+
66+
template <typename T>
67+
void compute_kernel(RNNGRUValue<T> value,
68+
int frame_size,
69+
int batch_size,
70+
lite_api::ActivationType active_node,
71+
lite_api::ActivationType active_gate) {
72+
auto value_reset_gate = value.gate_value;
73+
auto value_update_gate = value.gate_value + frame_size;
74+
auto value_reset_output = value.reset_output_value;
75+
auto value_reset_bias = value.reset_bias;
76+
auto cell_state_value = value.gate_value + 2 * frame_size;
77+
auto value_output = value.output_value;
78+
auto value_prev_out = value.prev_out_value;
79+
80+
for (int b = 0; b < batch_size; b++) {
81+
rnn_activation(value_reset_gate,
82+
value_reset_gate,
83+
frame_size,
84+
lite_api::ActivationType::kSigmoid_v2,
85+
1);
86+
rnn_activation(value_update_gate,
87+
value_update_gate,
88+
frame_size,
89+
lite_api::ActivationType::kSigmoid_v2,
90+
1);
91+
92+
for (int i = 0; i < frame_size; i++) {
93+
value_reset_output[i] =
94+
(value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i];
95+
cell_state_value[i] += value_reset_output[i];
96+
}
97+
98+
rnn_activation(cell_state_value,
99+
cell_state_value,
100+
frame_size,
101+
lite_api::ActivationType::kTanh_v2,
102+
1);
103+
104+
if (value.prev_out_value) {
105+
for (int i = 0; i < frame_size; i++) {
106+
value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] +
107+
value_update_gate[i] * value_prev_out[i];
108+
}
109+
} else {
110+
for (int i = 0; i < frame_size; i++) {
111+
value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i];
112+
}
113+
}
114+
115+
value_reset_gate += frame_size * 3;
116+
value_update_gate += frame_size * 3;
117+
value_reset_output += frame_size;
118+
cell_state_value += frame_size * 3;
119+
value_output += frame_size;
120+
if (value.prev_out_value) {
121+
value_prev_out += frame_size;
122+
}
123+
}
124+
}
125+
126+
// NEON-vectorized float specialization of the GRU step (see the generic
// template above for the math). Processes 4 lanes at a time with a scalar
// tail loop for frame_size % 4 remainders.
// NOTE(review): uses arm_neon intrinsics unconditionally, while the
// <arm_neon.h> include above is guarded by LITE_WITH_ARM — presumably this
// header is only compiled for ARM builds; confirm against the build config.
// NOTE(review): as in the generic version, `active_node` / `active_gate`
// are unused; activations are hard-wired to sigmoid/tanh.
template <>
void compute_kernel<float>(RNNGRUValue<float> value,
                           int frame_size,
                           int batch_size,
                           lite_api::ActivationType active_node,
                           lite_api::ActivationType active_gate) {
  // Gate blocks are interleaved per batch item: [reset | update | cell].
  auto value_reset_gate = value.gate_value;
  auto value_update_gate = value.gate_value + frame_size;
  auto value_reset_output = value.reset_output_value;
  auto value_reset_bias = value.reset_bias;
  auto cell_state_value = value.gate_value + 2 * frame_size;
  auto value_output = value.output_value;
  auto value_prev_out = value.prev_out_value;
  int i = 0;
  float32x4_t vec_one = vdupq_n_f32(1.f);

  for (int b = 0; b < batch_size; b++) {
    // Sigmoid on both gates, in place.
    rnn_activation(value_reset_gate,
                   value_reset_gate,
                   frame_size,
                   lite_api::ActivationType::kSigmoid_v2,
                   1);
    rnn_activation(value_update_gate,
                   value_update_gate,
                   frame_size,
                   lite_api::ActivationType::kSigmoid_v2,
                   1);

    // reset_out = (reset_out + bias) * reset_gate; cell_state += reset_out.
    for (i = 0; i + 3 < frame_size; i += 4) {
      float32x4_t vec_out = vld1q_f32(value_reset_output + i);
      float32x4_t vec_reset = vld1q_f32(value_reset_gate + i);
      float32x4_t vec_bias = vld1q_f32(value_reset_bias + i);
      vec_out = vmulq_f32(vaddq_f32(vec_out, vec_bias), vec_reset);
      vst1q_f32(value_reset_output + i, vec_out);
      vst1q_f32(cell_state_value + i,
                vaddq_f32(vec_out, vld1q_f32(cell_state_value + i)));
    }
    // Scalar tail for the last frame_size % 4 elements.
    for (; i < frame_size; i++) {
      value_reset_output[i] =
          (value_reset_output[i] + value_reset_bias[i]) * value_reset_gate[i];
      cell_state_value[i] += value_reset_output[i];
    }

    // Candidate state: tanh, in place.
    rnn_activation(cell_state_value,
                   cell_state_value,
                   frame_size,
                   lite_api::ActivationType::kTanh_v2,
                   1);

    if (value.prev_out_value) {
      // out = (1 - u) * c + u * prev_out  (vmlaq: fused multiply-add).
      for (i = 0; i + 3 < frame_size; i += 4) {
        float32x4_t vec_vug = vld1q_f32(value_update_gate + i);
        float32x4_t vec_vpo = vld1q_f32(value_prev_out + i);
        float32x4_t vec_csv = vld1q_f32(cell_state_value + i);
        vec_vpo = vmulq_f32(vec_vug, vec_vpo);
        float32x4_t vec_out =
            vmlaq_f32(vec_vpo, vsubq_f32(vec_one, vec_vug), vec_csv);
        vst1q_f32(value_output + i, vec_out);
      }
      // Scalar tail.
      for (; i < frame_size; i++) {
        value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i] +
                          value_update_gate[i] * value_prev_out[i];
      }
    } else {
      // No previous state (first step): out = (1 - u) * c.
      for (i = 0; i + 3 < frame_size; i += 4) {
        float32x4_t vec_vug = vld1q_f32(value_update_gate + i);
        float32x4_t vec_csv = vld1q_f32(cell_state_value + i);
        float32x4_t vec_out = vmulq_f32(vsubq_f32(vec_one, vec_vug), vec_csv);
        vst1q_f32(value_output + i, vec_out);
      }
      // Scalar tail.
      for (; i < frame_size; i++) {
        value_output[i] = (1.f - value_update_gate[i]) * cell_state_value[i];
      }
    }

    // Advance all cursors to the next batch item; interleaved gate blocks
    // step by 3 * frame_size.
    value_reset_gate += frame_size * 3;
    value_update_gate += frame_size * 3;
    value_reset_output += frame_size;
    cell_state_value += frame_size * 3;
    value_output += frame_size;
    if (value.prev_out_value) {
      value_prev_out += frame_size;
    }
  }
}
211+
212+
// Functor wrapping one full GRU unit step:
// 1) If a previous hidden state exists, compute
//      reset_output_value = prev_out_value x state_weight^T
//    via sgemm (transB = true, alpha = 1, beta = 0, no bias/activation).
// 2) Run compute_kernel to apply gates and produce output_value.
template <typename T>
struct RnnGruUnitFunctorV2 {
  static void compute(ARMContext* ctx,
                      RNNGRUValue<T> value,
                      int frame_size,
                      int batch_size,
                      lite_api::ActivationType active_node,
                      lite_api::ActivationType active_gate) {
    if (value.prev_out_value) {
      // sgemm requires an ActivationParam; disable fused activation.
      operators::ActivationParam act_param;
      act_param.has_active = false;
      // C[batch_size x frame_size] = A[batch_size x frame_size]
      //                              * B^T[frame_size x frame_size]
      lite::arm::math::sgemm(false,                    // is_transA
                             true,                     // is_transB
                             batch_size,               // M
                             frame_size,               // N
                             frame_size,               // K
                             1.f,                      // alpha
                             value.prev_out_value,     // A
                             frame_size,               // lda
                             value.state_weight,       // B
                             frame_size,               // ldb
                             0.f,                      // beta (overwrite C)
                             value.reset_output_value, // C
                             frame_size,               // ldc
                             nullptr,                  // bias
                             false,                    // has_bias
                             act_param,
                             ctx);
    }
    compute_kernel(value, frame_size, batch_size, active_node, active_gate);
  }
};
244+
245+
} // namespace math
246+
} // namespace arm
247+
} // namespace lite
248+
} // namespace paddle

lite/backends/arm/math/lstm.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ void add_bias_rowwise(Tensor* input,
3636
i_data += width;
3737
}
3838
}
39+
3940
void vector_dot(
4041
float* out, const float* in, const float* v1, int size, const float* v2) {
4142
int loop = size >> 2;

0 commit comments

Comments
 (0)