
Commit a6568b9

Dmitry Naumkin authored and JaccovG committed
Patch for rnn_dense_op scales iteration and memstride weights iteration
1 parent b8105de commit a6568b9

File tree: 11 files changed, +524 / -185 lines


lib/src/bricks/impl/mli_krn_rnn_dense_op_ref.h

Lines changed: 42 additions & 17 deletions
@@ -32,6 +32,38 @@ static inline void adjust_weights_dim_for_rnn_dense(s8asym_quant_specific_params
     params->weight_dim = -1;
 }
 
+static inline void adjust_weights_scale_for_rnn_dense(
+        fx_quant_specific_params* params,
+        fx_quant_specific_params* initial_params) {
+    return;
+}
+
+static inline void adjust_weights_scale_for_rnn_dense(
+        s8asym_quant_specific_params* params,
+        s8asym_quant_specific_params* initial_params) {
+    if (initial_params->weight_dim != -1) {
+        params->weight_scales++;
+        params->weight_shifts++;
+    }
+}
+
+static inline void adjust_weights_scale_back_for_rnn_dense(
+        fx_quant_specific_params* params,
+        fx_quant_specific_params* initial_params,
+        int gates) {
+    return;
+}
+
+static inline void adjust_weights_scale_back_for_rnn_dense(
+        s8asym_quant_specific_params* params,
+        s8asym_quant_specific_params* initial_params,
+        int gates) {
+    if (initial_params->weight_dim != -1) {
+        params->weight_scales -= gates;
+        params->weight_shifts -= gates;
+    }
+}
+
 template <typename io_T, typename w_T, typename b_T, typename acc_T, typename quant_T>
 static inline void rnn_dense_op_stacked(
         const MLI_PTR (io_T) * inputs_ptr,
@@ -42,6 +74,7 @@ static inline void rnn_dense_op_stacked(
         const int * inputs_elements,
         quant_T * in_to_out_quant_params,
         const int * w_ch_out_mem_strides,
+        const int * w_gate_mem_strides,
         mli_tensor * out) {
 
     constexpr bool asym = std::is_same<quant_T, s8asym_quant_specific_params>::value;
@@ -50,20 +83,15 @@ static inline void rnn_dense_op_stacked(
     mli_minmax_t val_limit = mli_prv_get_relu_limits<io_T, asym>(&relu_none, out);
 
     const MLI_PTR (w_T) weights_ptr[MLI_RNN_MAX_INPUT];
+    quant_T initial_params[MLI_RNN_MAX_INPUT];
     uint32_t weights_shift[MLI_RNN_MAX_INPUT];
 
-    const int16_t * weights_scales[MLI_RNN_MAX_INPUT];
-    const int8_t * weights_scale_frac_bits[MLI_RNN_MAX_INPUT];
-
     int out_elements = mli_prv_count_elem_num_part(bias, 1);
 
     for(int idx = 0; idx < inputs_num; ++idx) {
         weights_ptr[idx] = mli_prv_tensor_data_ptr<MLI_PTR (w_T)>(weights[idx]);
-        weights_shift[idx] = mli_prv_count_elem_num_part(weights[idx], 1);
-
-        weights_scales[idx] = weights[idx]->el_params.sa.scale.mem.pi16;
-        weights_scale_frac_bits[idx] = weights[idx]->el_params.sa.scale_frac_bits.mem.pi8;
-
+        weights_shift[idx] = w_gate_mem_strides[idx];
+        initial_params[idx] = in_to_out_quant_params[idx];
         adjust_weights_dim_for_rnn_dense(&in_to_out_quant_params[idx]);
     }
 
@@ -76,22 +104,19 @@ static inline void rnn_dense_op_stacked(
                 out_elements, w_ch_out_mem_strides, in_to_out_quant_params,
                 (io_T)val_limit.min, (io_T)val_limit.max);
 
-        for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx)
+        for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx) {
             weights_ptr[weight_idx] += weights_shift[weight_idx];
+            adjust_weights_scale_for_rnn_dense(&in_to_out_quant_params[weight_idx], &initial_params[weight_idx]);
+        }
 
         bias_ptr += out_elements;
         dense_out_ptr += out_elements;
-
-        if (asym) {
-            for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx) {
-                weights_scales[weight_idx]++;
-                weights_scale_frac_bits[weight_idx]++;
-            }
-        }
     }
 
-    for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx)
+    for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx) {
         weights_ptr[weight_idx] -= gates_num * weights_shift[weight_idx];
+        adjust_weights_scale_back_for_rnn_dense(&in_to_out_quant_params[weight_idx], &initial_params[weight_idx], gates_num);
+    }
 
     bias_ptr -= gates_num * out_elements;
     dense_out_ptr -= gates_num * out_elements;
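The ref implementation above swaps the old runtime `if (asym)` branch for overload resolution on the quantization-parameter type. Below is a minimal standalone sketch of that pattern, not MLI code: the struct and function names (fx_params_sketch, s8asym_params_sketch, adjust_scale_sketch) are invented stand-ins for fx_quant_specific_params / s8asym_quant_specific_params and the adjust_weights_scale helpers, so the snippet compiles on its own.

```cpp
#include <cstdint>
#include <cstdio>

// Invented stand-in for fx_quant_specific_params: no per-channel scale data.
struct fx_params_sketch {
    int shift;
};

// Invented stand-in for s8asym_quant_specific_params: per-gate scale/shift arrays.
struct s8asym_params_sketch {
    const int16_t* weight_scales;
    const int8_t*  weight_shifts;
    int weight_dim;  // -1 means per-tensor quantization, otherwise per-channel
};

// fx overload: nothing to advance, the call compiles away.
inline void adjust_scale_sketch(fx_params_sketch*, const fx_params_sketch*) {}

// sa8 overload: if the original tensor had per-channel (per-gate) scales,
// step the scale/shift pointers to the next gate.
inline void adjust_scale_sketch(s8asym_params_sketch* p, const s8asym_params_sketch* initial) {
    if (initial->weight_dim != -1) {
        p->weight_scales++;
        p->weight_shifts++;
    }
}

inline void adjust_scale_back_sketch(fx_params_sketch*, const fx_params_sketch*, int) {}

inline void adjust_scale_back_sketch(s8asym_params_sketch* p, const s8asym_params_sketch* initial, int gates) {
    if (initial->weight_dim != -1) {
        p->weight_scales -= gates;
        p->weight_shifts -= gates;
    }
}

int main() {
    const int16_t scales[3] = {100, 200, 300};  // one scale per gate (made-up values)
    const int8_t  shifts[3] = {10, 11, 12};
    s8asym_params_sketch params{scales, shifts, /*weight_dim=*/0};
    const s8asym_params_sketch initial = params;  // snapshot, as rnn_dense_op_stacked now keeps

    const int gates_num = 3;
    for (int gate = 0; gate < gates_num; ++gate) {
        std::printf("gate %d: scale %d, shift %d\n", gate, params.weight_scales[0], params.weight_shifts[0]);
        adjust_scale_sketch(&params, &initial);               // advance to the next gate's scale
    }
    adjust_scale_back_sketch(&params, &initial, gates_num);   // rewind, as the loop after the gate loop does
    return 0;
}
```

Because the fx overload is an empty inline function, the per-gate scale bookkeeping disappears for fx types at compile time, which is the effect the removed `if (asym)` block achieved at run time.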

lib/src/bricks/impl/mli_krn_rnn_dense_op_vdsp.h

Lines changed: 42 additions & 17 deletions
@@ -29,6 +29,38 @@ static inline void adjust_weights_dim_for_rnn_dense(s8asym_quant_specific_params
     params->weight_dim = -1;
 }
 
+static inline void adjust_weights_scale_for_rnn_dense(
+        fx_quant_specific_params* params,
+        fx_quant_specific_params* initial_params) {
+    return;
+}
+
+static inline void adjust_weights_scale_for_rnn_dense(
+        s8asym_quant_specific_params* params,
+        s8asym_quant_specific_params* initial_params) {
+    if (initial_params->weight_dim != -1) {
+        params->weight_scales++;
+        params->weight_shifts++;
+    }
+}
+
+static inline void adjust_weights_scale_back_for_rnn_dense(
+        fx_quant_specific_params* params,
+        fx_quant_specific_params* initial_params,
+        int gates) {
+    return;
+}
+
+static inline void adjust_weights_scale_back_for_rnn_dense(
+        s8asym_quant_specific_params* params,
+        s8asym_quant_specific_params* initial_params,
+        int gates) {
+    if (initial_params->weight_dim != -1) {
+        params->weight_scales -= gates;
+        params->weight_shifts -= gates;
+    }
+}
+
 template <typename io_T, typename w_T, typename b_T, typename acc_T, typename quant_T>
 static inline void rnn_dense_op_stacked(
         const MLI_PTR (io_T) * inputs_ptr,
@@ -39,6 +71,7 @@ static inline void rnn_dense_op_stacked(
         const int * inputs_elements,
         quant_T * in_to_out_quant_params,
         const int * w_ch_out_mem_strides,
+        const int * w_gate_mem_strides,
         mli_tensor * out) {
 
     constexpr bool asym = std::is_same<quant_T, s8asym_quant_specific_params>::value;
@@ -47,20 +80,15 @@ static inline void rnn_dense_op_stacked(
     mli_minmax_t val_limit = mli_prv_get_relu_limits<io_T, asym>(&relu_none, out);
 
     const MLI_PTR (w_T) weights_ptr[MLI_RNN_MAX_INPUT];
+    quant_T initial_params[MLI_RNN_MAX_INPUT];
     uint32_t weights_shift[MLI_RNN_MAX_INPUT];
 
-    const int16_t * weights_scales[MLI_RNN_MAX_INPUT];
-    const int8_t * weights_scale_frac_bits[MLI_RNN_MAX_INPUT];
-
     int out_elements = mli_prv_count_elem_num_part(bias, 1);
 
     for(int idx = 0; idx < inputs_num; ++idx) {
         weights_ptr[idx] = mli_prv_tensor_data_ptr<MLI_PTR (w_T)>(weights[idx]);
-        weights_shift[idx] = mli_prv_count_elem_num_part(weights[idx], 1);
-
-        weights_scales[idx] = weights[idx]->el_params.sa.scale.mem.pi16;
-        weights_scale_frac_bits[idx] = weights[idx]->el_params.sa.scale_frac_bits.mem.pi8;
-
+        weights_shift[idx] = w_gate_mem_strides[idx];
+        initial_params[idx] = in_to_out_quant_params[idx];
         adjust_weights_dim_for_rnn_dense(&in_to_out_quant_params[idx]);
     }
 
@@ -73,22 +101,19 @@ static inline void rnn_dense_op_stacked(
                 out_elements, w_ch_out_mem_strides, in_to_out_quant_params,
                 (io_T)val_limit.min, (io_T)val_limit.max);
 
-        for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx)
+        for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx) {
             weights_ptr[weight_idx] += weights_shift[weight_idx];
+            adjust_weights_scale_for_rnn_dense(&in_to_out_quant_params[weight_idx], &initial_params[weight_idx]);
+        }
 
         bias_ptr += out_elements;
        dense_out_ptr += out_elements;
-
-        if (asym) {
-            for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx) {
-                weights_scales[weight_idx]++;
-                weights_scale_frac_bits[weight_idx]++;
-            }
-        }
     }
 
-    for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx)
+    for (int weight_idx = 0; weight_idx < inputs_num; ++weight_idx) {
         weights_ptr[weight_idx] -= gates_num * weights_shift[weight_idx];
+        adjust_weights_scale_back_for_rnn_dense(&in_to_out_quant_params[weight_idx], &initial_params[weight_idx], gates_num);
+    }
 
     bias_ptr -= gates_num * out_elements;
     dense_out_ptr -= gates_num * out_elements;
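The vdsp variant above carries the same change as the ref variant. The complementary piece in both is that the per-gate weight pointer now advances by a caller-supplied memory stride (w_gate_mem_strides) rather than by a counted number of elements, so gates that are padded apart in memory are still addressed correctly. A self-contained sketch with made-up sizes, none of which are library names:

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int gates_num   = 3;
    const int gate_elems  = 4;  // elements actually used per gate (made-up)
    const int gate_stride = 6;  // per-gate mem stride; larger than gate_elems models padding between gates

    // One flat weights buffer holding gates_num gates, gate_stride elements apart.
    std::vector<float> weights(gates_num * gate_stride, 0.0f);
    for (int g = 0; g < gates_num; ++g)
        for (int e = 0; e < gate_elems; ++e)
            weights[g * gate_stride + e] = static_cast<float>(g * 100 + e);

    const float* weights_ptr = weights.data();
    for (int g = 0; g < gates_num; ++g) {
        std::printf("gate %d starts with weight %.0f\n", g, weights_ptr[0]);
        weights_ptr += gate_stride;            // advance by the gate stride, not by gate_elems
    }
    weights_ptr -= gates_num * gate_stride;    // rewind after the gate loop, mirroring the kernel
    std::printf("rewound to gate 0, weight %.0f\n", weights_ptr[0]);
    return 0;
}
```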

lib/src/bricks/mli_krn_rnn_dense_op_decl.h

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,7 @@ static MLI_FORCE_INLINE void rnn_dense_op_stacked(
         const int * inputs_elements,
         quant_T * in_to_out_quant_params,
         const int * w_ch_out_mem_strides,
+        const int * w_gate_mem_strides,
         mli_tensor * out);
 
 } // namespace ref
@@ -107,6 +108,7 @@ static MLI_FORCE_INLINE void rnn_dense_op_stacked(
         const int * inputs_elements,
         quant_T * in_to_out_quant_params,
         const int * w_ch_out_mem_strides,
+        const int * w_gate_mem_strides,
         mli_tensor * out);
 
 } // namespace vdsp

lib/src/kernels/common/mli_krn_gru_cell.h

Lines changed: 36 additions & 9 deletions
@@ -110,8 +110,19 @@ MLI_FORCE_INLINE void gru_cell_prepare_and_run(
 
     const int w_ch_out_mem_stride_from_tensors[] = {(int)weights_in->mem_stride[KRNL_RNN_W_IN_ELEMS_DIM],
                                                     (int)weights_out->mem_stride[KRNL_RNN_W_IN_ELEMS_DIM]};
-    const int w_ch_out_mem_strides[] = {(w_ch_out_mem_stride_from_tensors[0] != 0) ? w_ch_out_mem_stride_from_tensors[0] : gru_out_elements,
-                                        (w_ch_out_mem_stride_from_tensors[1] != 0) ? w_ch_out_mem_stride_from_tensors[1] : gru_out_elements};
+
+    const int w_gate_mem_stride_from_tensors[] = {(int)weights_in->mem_stride[0],
+                                                  (int)weights_out->mem_stride[0]};
+
+    const int w_ch_out_mem_strides[] = {(w_ch_out_mem_stride_from_tensors[0] != 0)
+            ? w_ch_out_mem_stride_from_tensors[0] : gru_out_elements,
+            (w_ch_out_mem_stride_from_tensors[1] != 0)
+            ? w_ch_out_mem_stride_from_tensors[1] : gru_out_elements};
+
+    const int w_gate_mem_strides[] = {(w_gate_mem_stride_from_tensors[0] != 0)
+            ? w_gate_mem_stride_from_tensors[0] : gru_out_elements * inputs_elements[0],
+            (w_gate_mem_stride_from_tensors[1] != 0)
+            ? w_gate_mem_stride_from_tensors[1] : gru_out_elements * inputs_elements[1]};
 
     // Paricular subtensors of intermediate tensor (mli_tensor.mem_stride[] should be zero and cannot be left uninitialized)
     mli_tensor reset_gate = {{ 0 }}, update_gate = {{ 0 }}, new_gate = {{ 0 }}; // Various gates to control info flow
@@ -123,13 +134,29 @@ MLI_FORCE_INLINE void gru_cell_prepare_and_run(
     mli_hlp_point_to_subtensor(&ir_tensor, &iterator, &update_gate); iterator.start_coord[0]++;
     mli_hlp_point_to_subtensor(&ir_tensor, &iterator, &reset_gate); iterator.start_coord[0]++;
     mli_hlp_point_to_subtensor(&ir_tensor, &iterator, &new_gate); iterator.start_coord[0]++;
-
-    mli_hlp_point_to_subtensor(weights_in, &weight_iterator, &w_in_new_g);
-    mli_hlp_point_to_subtensor(weights_out, &weight_iterator, &w_out_new_g);
     mli_hlp_point_to_subtensor(bias, &weight_iterator, &b_new_g);
 
-    const MLI_PTR (w_T) w_new_g_ptr[] = {mli_prv_tensor_data_ptr<MLI_PTR (w_T)>(&w_in_new_g),
-                                         mli_prv_tensor_data_ptr<MLI_PTR (w_T)>(&w_out_new_g)};
+    w_in_new_g.data = weights_in->data;
+    w_in_new_g.rank = 2;
+    w_in_new_g.shape[0] = weights_in->shape[1];
+    w_in_new_g.shape[1] = weights_in->shape[2];
+    w_in_new_g.el_params = weights_in->el_params;
+    w_in_new_g.el_type = weights_in->el_type;
+    mli_prv_tensor_inc_data_ptr<w_T*>(&w_in_new_g, num_gates * w_gate_mem_strides[0]);
+
+    w_out_new_g.data = weights_out->data;
+    w_out_new_g.rank = 2;
+    w_out_new_g.shape[0] = weights_out->shape[1];
+    w_out_new_g.shape[1] = weights_out->shape[2];
+    w_out_new_g.el_params = weights_out->el_params;
+    w_out_new_g.el_type = weights_out->el_type;
+    mli_prv_tensor_inc_data_ptr<w_T*>(&w_out_new_g, num_gates * w_gate_mem_strides[1]);
+
+    const MLI_PTR (w_T) w_new_g_ptr[] = {
+        mli_prv_tensor_data_ptr<MLI_PTR (w_T)>(weights_in) + num_gates * w_gate_mem_strides[0],
+        mli_prv_tensor_data_ptr<MLI_PTR (w_T)>(weights_out) + num_gates * w_gate_mem_strides[1]
+    };
+
     const MLI_PTR (b_T) b_new_g_ptr = mli_prv_tensor_data_ptr<MLI_PTR (b_T)>(&b_new_g);
 
     mli_tensor rnn_out = {{ 0 }};
@@ -172,7 +199,7 @@ MLI_FORCE_INLINE void gru_cell_prepare_and_run(
     //=======================================
     mli::krn::rnn_dense_op_stacked<io_T, w_T, b_T, acc_T, quant_T>(
         inputs_ptr, weights, bias, num_gates, num_inputs, inputs_elements,
-        in_to_out_params, w_ch_out_mem_strides, &ir_tensor);
+        in_to_out_params, w_ch_out_mem_strides, w_gate_mem_strides, &ir_tensor);
 
     // Step 2: Applying non-linearity
     //=======================================
@@ -256,7 +283,7 @@ MLI_FORCE_INLINE void gru_cell_prepare_and_run(
     mli::krn::eltwise_prepare_and_run<io_T, ELTWISE_MUL, /*convert*/ asym>(&new_gate, &update_gate, &temp);
     mli::krn::eltwise_prepare_and_run<io_T, ELTWISE_ADD, /*convert*/ asym>(&temp, &current_out, &rnn_out);
 
-    current_hidden.data.mem.void_p = rnn_out.data.mem.void_p;
+    current_hidden.data = rnn_out.data;
     current_hidden.el_params = rnn_out.el_params;
 
     // Step 6: Update pointers and tensors for next batch
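A quick numeric illustration of the stride selection introduced in the GRU hunks above (all values are made up; only the naming echoes the kernel). When the outermost mem_stride of a weight tensor is zero, the gate stride falls back to a dense [gates][in][out] layout, and the new-gate weights are then located a whole number of gate strides from the start of the tensor, which is what replaces the removed mli_hlp_point_to_subtensor() lookups:

```cpp
#include <cstdio>

int main() {
    // Hypothetical GRU sizes; the real values come from the weight and bias tensors.
    const int num_gates    = 2;    // gates laid out before the "new" gate in this example
    const int in_elements  = 8;    // inputs_elements[0] in the kernel
    const int out_elements = 16;   // gru_out_elements in the kernel

    // mem_stride[0] as it would come from the tensor; 0 means "contiguous, derive it".
    const int stride_from_tensor = 0;

    // Fallback mirrors the patch: a contiguous gate occupies out_elements * in_elements values.
    const int w_gate_mem_stride = (stride_from_tensor != 0)
            ? stride_from_tensor
            : out_elements * in_elements;

    // The new-gate weights start num_gates gate-strides past the start of the weights data.
    const int new_gate_offset = num_gates * w_gate_mem_stride;

    std::printf("gate stride = %d elements, new-gate offset = %d elements\n",
                w_gate_mem_stride, new_gate_offset);
    return 0;
}
```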

lib/src/kernels/common/mli_krn_lstm_cell.h

Lines changed: 14 additions & 3 deletions
@@ -90,8 +90,19 @@ MLI_FORCE_INLINE void lstm_cell_prepare_and_run(
 
     const int w_ch_out_mem_stride_from_tensors[] = {(int)weights_in->mem_stride[KRNL_RNN_W_IN_ELEMS_DIM],
                                                     (int)weights_out->mem_stride[KRNL_RNN_W_IN_ELEMS_DIM]};
-    const int w_ch_out_mem_strides[] = {(w_ch_out_mem_stride_from_tensors[0] != 0) ? w_ch_out_mem_stride_from_tensors[0] : lstm_out_elements,
-                                        (w_ch_out_mem_stride_from_tensors[1] != 0) ? w_ch_out_mem_stride_from_tensors[1] : lstm_out_elements};
+
+    const int w_gate_mem_stride_from_tensors[] = {(int)weights_in->mem_stride[0],
+                                                  (int)weights_out->mem_stride[0]};
+
+    const int w_ch_out_mem_strides[] = {(w_ch_out_mem_stride_from_tensors[0] != 0)
+            ? w_ch_out_mem_stride_from_tensors[0] : lstm_out_elements,
+            (w_ch_out_mem_stride_from_tensors[1] != 0)
+            ? w_ch_out_mem_stride_from_tensors[1] : lstm_out_elements};
+
+    const int w_gate_mem_strides[] = {(w_gate_mem_stride_from_tensors[0] != 0)
+            ? w_gate_mem_stride_from_tensors[0] : lstm_out_elements * inputs_elements[0],
+            (w_gate_mem_stride_from_tensors[1] != 0)
+            ? w_gate_mem_stride_from_tensors[1] : lstm_out_elements * inputs_elements[1]};
 
     // Paricular subtensors of intermediate tensor (mli_tensor.mem_stride[] should be zero and cannot be left uninitialized)
     mli_tensor in_gate = {{ 0 }}, forget_gate = {{ 0 }}, out_gate = {{ 0 }}; // Various gates to controll info flow
@@ -119,7 +130,7 @@ MLI_FORCE_INLINE void lstm_cell_prepare_and_run(
     //=======================================
     rnn_dense_op_stacked<io_T, w_T, b_T, acc_T, quant_T>(
         inputs_ptr, weights, bias, num_gates, num_inputs, inputs_elements,
-        in_to_out_params, w_ch_out_mem_strides, &ir_tensor);
+        in_to_out_params, w_ch_out_mem_strides, w_gate_mem_strides, &ir_tensor);
 
 
     // Step 2: Applying non-linearity
