@@ -16,7 +16,10 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
+#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
+#include "paddle/fluid/operators/math/detail/gru_kernel.h"
 #include "paddle/fluid/operators/math/gru_compute.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
@@ -94,6 +97,7 @@ class GRUKernel : public framework::OpKernel<T> {
         context.Attr<std::string>("activation"));
     auto active_gate = math::detail::GetActivationType(
         context.Attr<std::string>("gate_activation"));
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
@@ -105,9 +109,27 @@ class GRUKernel : public framework::OpKernel<T> {
       gru_value.output_value = hidden_t.data<T>();
       gru_value.gate_value = gate_t.data<T>();
       gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
-      math::GRUUnitFunctor<DeviceContext, T>::compute(
-          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
-          active_gate);
+      if (gru_value.prev_out_value) {
+        blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1,
+                  gru_value.prev_out_value, frame_size, gru_value.gate_weight,
+                  frame_size * 2, 1, gru_value.gate_value, frame_size * 3);
+      }
+
+      math::detail::forward_reset_output(
+          math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
+          cur_batch_size, active_gate);
+
+      if (gru_value.prev_out_value) {
+        blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1,
+                  gru_value.reset_output_value, frame_size,
+                  gru_value.state_weight, frame_size, 1,
+                  gru_value.gate_value + frame_size * 2, frame_size * 3);
+      }
+
+      math::detail::forward_final_output(
+          math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
+          cur_batch_size, active_node);
+
       gru_value.prev_out_value = gru_value.output_value;
     }
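
Note on the change (editorial sketch, not part of the commit): the diff replaces the single math::GRUUnitFunctor<DeviceContext, T>::compute call with its three constituent stages, driven by a Blas handle that is now obtained once via math::GetBlas before the per-timestep loop. The first GEMM accumulates the update/reset gate pre-activations ([u | r] += h_prev * W_gate, hence N = frame_size * 2 and ldc = frame_size * 3), forward_reset_output activates the gates and forms r .* h_prev, the second GEMM accumulates the candidate pre-activation into the third column block (gate_value + frame_size * 2), and forward_final_output produces the hidden state. The loop-based sketch below illustrates the same arithmetic; the function name, the float/std::vector scaffolding, and the sigmoid/tanh activation choices are illustrative assumptions, not code from the repository, and it follows the h = (1 - u) .* h_prev + u .* c convention of Paddle's gru_finalOutput.

#include <cmath>
#include <vector>

// Illustrative stand-in for one inlined GRU step (hypothetical helper, not
// Paddle code). gate is laid out row-major as [update | reset | candidate]
// per example, mirroring gru_value.gate_value with ldc = frame_size * 3,
// and is assumed to already hold the input projections on entry.
void GruStepSketch(const std::vector<float>& h_prev,   // batch x frame
                   const std::vector<float>& w_gate,   // frame x 2*frame
                   const std::vector<float>& w_state,  // frame x frame
                   std::vector<float>* gate,           // batch x 3*frame
                   std::vector<float>* reset_out,      // batch x frame
                   std::vector<float>* h,              // batch x frame
                   int batch, int frame) {
  auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
  for (int b = 0; b < batch; ++b) {
    float* g = &(*gate)[b * 3 * frame];
    const float* hp = &h_prev[b * frame];
    // First GEMM: [u | r] += h_prev * w_gate (M = batch, N = 2*frame, K = frame).
    for (int j = 0; j < 2 * frame; ++j)
      for (int k = 0; k < frame; ++k) g[j] += hp[k] * w_gate[k * 2 * frame + j];
    // forward_reset_output: activate both gates, then reset_out = r .* h_prev.
    float* ro = &(*reset_out)[b * frame];
    for (int j = 0; j < frame; ++j) {
      g[j] = sigmoid(g[j]);                  // update gate u
      g[frame + j] = sigmoid(g[frame + j]);  // reset gate r
      ro[j] = g[frame + j] * hp[j];
    }
    // Second GEMM: candidate += reset_out * w_state, written into the third
    // column block, matching gru_value.gate_value + frame_size * 2 above.
    for (int j = 0; j < frame; ++j)
      for (int k = 0; k < frame; ++k)
        g[2 * frame + j] += ro[k] * w_state[k * frame + j];
    // forward_final_output: c = tanh(candidate), h = (1 - u) .* h_prev + u .* c.
    for (int j = 0; j < frame; ++j) {
      g[2 * frame + j] = std::tanh(g[2 * frame + j]);
      (*h)[b * frame + j] = (1.0f - g[j]) * hp[j] + g[j] * g[2 * frame + j];
    }
  }
}

Inlining the functor this way lets the kernel reuse one Blas object across all timesteps instead of re-acquiring it inside math::GRUUnitFunctor on every call, which is presumably the motivation for the change.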