Skip to content

Commit 30a4584

Browse files
Hakim7267dzakhar
authored andcommitted
kernel should not change the cell tensor shape
1 parent 8a60337 commit 30a4584

File tree

2 files changed

+8
-28
lines changed

2 files changed

+8
-28
lines changed

lib/src/kernels/common/impl/mli_krn_lstm_cell_ref.h

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,19 +113,11 @@ MLI_FORCE_INLINE void lstm_cell_prepare_and_run(
113113

114114
mli_tensor rnn_out;
115115
rnn_out.data = out->data;
116-
rnn_out.rank = 2;
117-
rnn_out.shape[0] = 1;
118-
rnn_out.shape[1] = lstm_out_elements;
119-
rnn_out.mem_stride[0] = rnn_out.shape[1];
120-
rnn_out.mem_stride[1] = 1;
116+
rnn_out.rank = 1;
117+
rnn_out.shape[0] = lstm_out_elements;
118+
rnn_out.mem_stride[0] = 1;
121119
rnn_out.el_type = in->el_type;
122120

123-
cell->rank = 2;
124-
cell->shape[0] = forget_gate.shape[0];
125-
cell->shape[1] = forget_gate.shape[1];
126-
cell->mem_stride[0] = forget_gate.mem_stride[0];
127-
cell->mem_stride[1] = forget_gate.mem_stride[1];
128-
129121
for (int timestep = 0; timestep < seq_len; timestep++) {
130122

131123
// Step 1: Applying Dense
@@ -173,9 +165,7 @@ MLI_FORCE_INLINE void lstm_cell_prepare_and_run(
173165
temp.data = rnn_out.data;
174166
temp.rank = rnn_out.rank;
175167
temp.shape[0] = rnn_out.shape[0];
176-
temp.shape[1] = rnn_out.shape[1];
177168
temp.mem_stride[0] = rnn_out.mem_stride[0];
178-
temp.mem_stride[1] = rnn_out.mem_stride[1];
179169
temp.el_type = rnn_out.el_type;
180170
temp.el_params = out->el_params;
181171

@@ -249,4 +239,4 @@ MLI_FORCE_INLINE void lstm_cell_prepare_and_run(
249239
} // namespace mli
250240
} // namespace krn
251241

252-
#endif //_MLI_KRN_LSTM_CELL_REF_H_
242+
#endif //_MLI_KRN_LSTM_CELL_REF_H_

lib/src/kernels/common/impl/mli_krn_lstm_cell_vdsp.h

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,11 @@ MLI_FORCE_INLINE void lstm_cell_prepare_and_run(
123123

124124
mli_tensor rnn_out;
125125
rnn_out.data = out->data;
126-
rnn_out.rank = 2;
127-
rnn_out.shape[0] = 1;
128-
rnn_out.shape[1] = lstm_out_elements;
129-
rnn_out.mem_stride[0] = rnn_out.shape[1];
130-
rnn_out.mem_stride[1] = 1;
126+
rnn_out.rank = 1;
127+
rnn_out.shape[0] = lstm_out_elements;
128+
rnn_out.mem_stride[0] = 1;
131129
rnn_out.el_type = in->el_type;
132130

133-
cell->rank = 2;
134-
cell->shape[0] = tmp_gate.shape[0];
135-
cell->shape[1] = tmp_gate.shape[1];
136-
cell->mem_stride[0] = tmp_gate.mem_stride[0];
137-
cell->mem_stride[1] = tmp_gate.mem_stride[1];
138-
139131
struct s8asym_quant_params out_params_sigm;
140132
struct s8asym_quant_params out_params_tanh;
141133

@@ -236,9 +228,7 @@ MLI_FORCE_INLINE void lstm_cell_prepare_and_run(
236228
temp.data = rnn_out.data;
237229
temp.rank = rnn_out.rank;
238230
temp.shape[0] = rnn_out.shape[0];
239-
temp.shape[1] = rnn_out.shape[1];
240231
temp.mem_stride[0] = rnn_out.mem_stride[0];
241-
temp.mem_stride[1] = rnn_out.mem_stride[1];
242232
temp.el_type = rnn_out.el_type;
243233
temp.el_params = out->el_params;
244234

@@ -314,4 +304,4 @@ MLI_FORCE_INLINE void lstm_cell_prepare_and_run(
314304
} // namespace mli
315305
} // namespace krn
316306

317-
#endif //_MLI_KRN_LSTM_CELL_VDSP_H_
307+
#endif //_MLI_KRN_LSTM_CELL_VDSP_H_

0 commit comments

Comments
 (0)