@@ -98,6 +98,23 @@ class GRUKernel : public framework::OpKernel<T> {
     auto active_gate = math::detail::GetActivationType(
         context.Attr<std::string>("gate_activation"));
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+    // TODO(TJ): make a class, make one pack
+    T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                     frame_size * 2 /*width of weight*/,
+                                     frame_size /*height of height*/);
+    PADDLE_ENFORCE(packed_gate);
+    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
+                   frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
+                   packed_gate);
+    T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                      frame_size /*width of weight*/,
+                                      frame_size /*height of height*/);
+    PADDLE_ENFORCE(packed_state);
+    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
+                   frame_size, T(1.0), gru_value.state_weight, frame_size,
+                   packed_state);
+
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
@@ -110,20 +127,21 @@ class GRUKernel : public framework::OpKernel<T> {
       gru_value.gate_value = gate_t.data<T>();
       gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
       if (gru_value.prev_out_value) {
-        blas.GEMM(false, false, cur_batch_size, frame_size * 2, frame_size, 1,
-                  gru_value.prev_out_value, frame_size, gru_value.gate_weight,
-                  frame_size * 2, 1, gru_value.gate_value, frame_size * 3);
+        blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size,
+                          frame_size * 2, frame_size, gru_value.prev_out_value,
+                          frame_size, packed_gate, frame_size * 2, T(1),
+                          gru_value.gate_value, frame_size * 3);
       }
 
       math::detail::forward_reset_output(
           math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
           cur_batch_size, active_gate);
 
       if (gru_value.prev_out_value) {
-        blas.GEMM(false, false, cur_batch_size, frame_size, frame_size, 1,
-                  gru_value.reset_output_value, frame_size,
-                  gru_value.state_weight, frame_size, 1,
-                  gru_value.gate_value + frame_size * 2, frame_size * 3);
+        blas.GEMM_COMPUTE(
+            CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
+            gru_value.reset_output_value, frame_size, packed_state, frame_size,
+            T(1), gru_value.gate_value + frame_size * 2, frame_size * 3);
       }
 
       math::detail::forward_final_output(
@@ -132,6 +150,8 @@ class GRUKernel : public framework::OpKernel<T> {
 
       gru_value.prev_out_value = gru_value.output_value;
     }
+    blas.GEMM_FREE(packed_gate);
+    blas.GEMM_FREE(packed_state);
 
     math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batch_hidden->set_lod(batch_gate->lod());
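
Note: the change replaces the generic per-step `blas.GEMM` calls with packed GEMM. The recurrent weights (`gate_weight`, `state_weight`) are constant across the batch loop, so they are reordered once into the BLAS library's internal layout (`GEMM_ALLOC` + `GEMM_PACK`), reused by every `GEMM_COMPUTE` inside the loop, and released with `GEMM_FREE`. These wrappers map onto Intel MKL's packed-GEMM API; below is a minimal standalone sketch of the same pack-once / compute-many pattern written directly against MKL. The sizes and step count are illustrative, not taken from the patch.

```cpp
#include <mkl.h>

#include <vector>

int main() {
  // Illustrative dimensions only (hypothetical, not from the patch).
  const MKL_INT batch = 32;       // M: rows of the activation input
  const MKL_INT frame_size = 64;  // N == K: square recurrent weight

  std::vector<float> weight(frame_size * frame_size, 0.1f);  // B, fixed
  std::vector<float> input(batch * frame_size, 1.0f);        // A, per step
  std::vector<float> out(batch * frame_size, 0.0f);          // C, accumulated

  // Pack once: allocate MKL's opaque buffer for the B operand and reorder
  // the weight into it. Note that alpha (1.0f) is folded in at pack time.
  float* packed_b =
      cblas_sgemm_alloc(CblasBMatrix, batch, frame_size, frame_size);
  cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans, batch,
                   frame_size, frame_size, 1.0f, weight.data(), frame_size,
                   packed_b);

  // Compute many: reuse the packed weight across time steps. CblasPacked
  // marks the B operand as pre-packed; only beta is passed here.
  for (int step = 0; step < 10; ++step) {
    cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked, batch,
                        frame_size, frame_size, input.data(), frame_size,
                        packed_b, frame_size, 1.0f, out.data(), frame_size);
  }

  cblas_sgemm_free(packed_b);
  return 0;
}
```

One consequence of this API shape: alpha is supplied at pack time and `cblas_sgemm_compute` takes only beta, which is why the patch passes `T(1.0)` to `GEMM_PACK` and `T(1)` as beta to `GEMM_COMPUTE`.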