@@ -110,8 +110,19 @@ MLI_FORCE_INLINE void gru_cell_prepare_and_run(
 
     const int w_ch_out_mem_stride_from_tensors[] = {(int)weights_in->mem_stride[KRNL_RNN_W_IN_ELEMS_DIM],
                                                      (int)weights_out->mem_stride[KRNL_RNN_W_IN_ELEMS_DIM]};
-    const int w_ch_out_mem_strides[] = {(w_ch_out_mem_stride_from_tensors[0] != 0) ? w_ch_out_mem_stride_from_tensors[0] : gru_out_elements,
-                                        (w_ch_out_mem_stride_from_tensors[1] != 0) ? w_ch_out_mem_stride_from_tensors[1] : gru_out_elements};
+
+    const int w_gate_mem_stride_from_tensors[] = {(int)weights_in->mem_stride[0],
+                                                  (int)weights_out->mem_stride[0]};
+
+    const int w_ch_out_mem_strides[] = {(w_ch_out_mem_stride_from_tensors[0] != 0)
+                                        ? w_ch_out_mem_stride_from_tensors[0] : gru_out_elements,
+                                        (w_ch_out_mem_stride_from_tensors[1] != 0)
+                                        ? w_ch_out_mem_stride_from_tensors[1] : gru_out_elements};
+
+    const int w_gate_mem_strides[] = {(w_gate_mem_stride_from_tensors[0] != 0)
+                                      ? w_gate_mem_stride_from_tensors[0] : gru_out_elements * inputs_elements[0],
+                                      (w_gate_mem_stride_from_tensors[1] != 0)
+                                      ? w_gate_mem_stride_from_tensors[1] : gru_out_elements * inputs_elements[1]};
 
     // Particular subtensors of intermediate tensor (mli_tensor.mem_stride[] should be zero and cannot be left uninitialized)
     mli_tensor reset_gate = {{ 0 }}, update_gate = {{ 0 }}, new_gate = {{ 0 }}; // Various gates to control info flow
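This hunk adds a second stride array alongside the existing per-channel-out strides: when a weight tensor's mem_stride entry is zero, the per-channel-out stride falls back to gru_out_elements and the new per-gate stride falls back to gru_out_elements * inputs_elements[i]. The stand-alone sketch below illustrates that zero-stride fallback for a stacked [num_gates, in_elements, out_elements] weight layout; the struct and helper names are hypothetical, not part of the MLI API.

```cpp
#include <cstdio>

// Hypothetical stand-in for an MLI-style weight view; not library code.
struct WeightView {
    int shape[3];       // [num_gates, in_elements, out_elements]
    int mem_stride[3];  // 0 means "derive the dense stride from the shape"
};

// Stride between two gate planes: explicit value if set, dense fallback otherwise.
static int gate_stride(const WeightView &w) {
    return (w.mem_stride[0] != 0) ? w.mem_stride[0] : w.shape[1] * w.shape[2];
}

// Stride between rows of one gate's [in, out] matrix.
static int ch_out_stride(const WeightView &w) {
    return (w.mem_stride[1] != 0) ? w.mem_stride[1] : w.shape[2];
}

int main() {
    WeightView w = {{3, 8, 16}, {0, 0, 0}};  // three stacked gates, densely packed
    std::printf("gate stride = %d, ch_out stride = %d\n", gate_stride(w), ch_out_stride(w));
    return 0;
}
```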
@@ -123,13 +134,29 @@ MLI_FORCE_INLINE void gru_cell_prepare_and_run(
     mli_hlp_point_to_subtensor(&ir_tensor, &iterator, &update_gate); iterator.start_coord[0]++;
     mli_hlp_point_to_subtensor(&ir_tensor, &iterator, &reset_gate); iterator.start_coord[0]++;
     mli_hlp_point_to_subtensor(&ir_tensor, &iterator, &new_gate); iterator.start_coord[0]++;
-
-    mli_hlp_point_to_subtensor(weights_in, &weight_iterator, &w_in_new_g);
-    mli_hlp_point_to_subtensor(weights_out, &weight_iterator, &w_out_new_g);
     mli_hlp_point_to_subtensor(bias, &weight_iterator, &b_new_g);
 
-    const MLI_PTR(w_T) w_new_g_ptr[] = {mli_prv_tensor_data_ptr<MLI_PTR(w_T)>(&w_in_new_g),
-                                        mli_prv_tensor_data_ptr<MLI_PTR(w_T)>(&w_out_new_g)};
+    w_in_new_g.data = weights_in->data;
+    w_in_new_g.rank = 2;
+    w_in_new_g.shape[0] = weights_in->shape[1];
+    w_in_new_g.shape[1] = weights_in->shape[2];
+    w_in_new_g.el_params = weights_in->el_params;
+    w_in_new_g.el_type = weights_in->el_type;
+    mli_prv_tensor_inc_data_ptr<w_T*>(&w_in_new_g, num_gates * w_gate_mem_strides[0]);
+
+    w_out_new_g.data = weights_out->data;
+    w_out_new_g.rank = 2;
+    w_out_new_g.shape[0] = weights_out->shape[1];
+    w_out_new_g.shape[1] = weights_out->shape[2];
+    w_out_new_g.el_params = weights_out->el_params;
+    w_out_new_g.el_type = weights_out->el_type;
+    mli_prv_tensor_inc_data_ptr<w_T*>(&w_out_new_g, num_gates * w_gate_mem_strides[1]);
+
+    const MLI_PTR(w_T) w_new_g_ptr[] = {
+        mli_prv_tensor_data_ptr<MLI_PTR(w_T)>(weights_in) + num_gates * w_gate_mem_strides[0],
+        mli_prv_tensor_data_ptr<MLI_PTR(w_T)>(weights_out) + num_gates * w_gate_mem_strides[1]
+    };
+
     const MLI_PTR(b_T) b_new_g_ptr = mli_prv_tensor_data_ptr<MLI_PTR(b_T)>(&b_new_g);
 
     mli_tensor rnn_out = {{ 0 }};
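In place of mli_hlp_point_to_subtensor, this hunk builds the new-gate weight views by hand: it copies the data container from the full weight tensor, reshapes the view to the trailing 2-D [in, out] plane, and advances the data pointer by num_gates * w_gate_mem_strides[i] so the view starts at the new-gate plane. The sketch below shows the same addressing with plain pointers on a densely packed buffer; every name in it is illustrative rather than an MLI symbol.

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Densely packed stacked weights: num_gates planes of in_elems x out_elems.
    const int num_gates = 3, in_elems = 2, out_elems = 4;
    const int gate_stride = in_elems * out_elems;  // dense fallback from the earlier hunk
    std::vector<float> weights(num_gates * gate_stride);
    for (std::size_t i = 0; i < weights.size(); ++i) weights[i] = static_cast<float>(i);

    // Point at one gate's plane by offsetting the base pointer, which is what
    // advancing the subtensor data pointer by gate_idx * gate_stride achieves.
    const int gate_idx = num_gates - 1;  // e.g. the last stacked gate
    const float *gate_plane = weights.data() + gate_idx * gate_stride;

    // gate_plane[r * out_elems + c] addresses element (r, c) of that gate's matrix.
    std::printf("gate %d starts at value %.0f\n", gate_idx, gate_plane[0]);
    return 0;
}
```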
@@ -172,7 +199,7 @@ MLI_FORCE_INLINE void gru_cell_prepare_and_run(
     // =======================================
     mli::krn::rnn_dense_op_stacked<io_T, w_T, b_T, acc_T, quant_T>(
         inputs_ptr, weights, bias, num_gates, num_inputs, inputs_elements,
-        in_to_out_params, w_ch_out_mem_strides, &ir_tensor);
+        in_to_out_params, w_ch_out_mem_strides, w_gate_mem_strides, &ir_tensor);
 
     // Step 2: Applying non-linearity
     // =======================================
@@ -256,7 +283,7 @@ MLI_FORCE_INLINE void gru_cell_prepare_and_run(
     mli::krn::eltwise_prepare_and_run<io_T, ELTWISE_MUL, /* convert*/ asym>(&new_gate, &update_gate, &temp);
     mli::krn::eltwise_prepare_and_run<io_T, ELTWISE_ADD, /* convert*/ asym>(&temp, &current_out, &rnn_out);
 
-    current_hidden.data.mem.void_p = rnn_out.data.mem.void_p;
+    current_hidden.data = rnn_out.data;
     current_hidden.el_params = rnn_out.el_params;
 
     // Step 6: Update pointers and tensors for next batch
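The final hunk assigns the whole data container instead of copying only its raw void pointer, so any sibling fields of that pointer (such as a buffer capacity) stay in sync with it. A minimal sketch of the difference, using a made-up container type since the exact mli_data_container layout is not shown in this diff:

```cpp
#include <cstdio>

// Made-up container type for illustration; the real MLI container may differ.
struct DataContainer {
    unsigned capacity;             // buffer size in bytes
    union { void *void_p; } mem;   // raw pointer to the buffer
};

int main() {
    int buf[8] = {0};
    DataContainer src = {sizeof(buf), {buf}};
    DataContainer dst = {0, {nullptr}};

    dst.mem.void_p = src.mem.void_p;  // pointer-only copy: capacity stays stale (0)
    std::printf("pointer-only copy: capacity = %u\n", dst.capacity);

    dst = src;                        // whole-container copy: every field travels
    std::printf("full copy:         capacity = %u\n", dst.capacity);
    return 0;
}
```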