Skip to content

Commit 1250806

Browse files
authored
Split normquant ptrs (#4)
* Fix: remove unused import
* Split set_ptrs into set_ptrs_conv and set_ptrs_norm_quant
1 parent a8765d1 commit 1250806

File tree

6 files changed

+62
-56
lines changed

6 files changed

+62
-56
lines changed

ne16/hal/ne16_task.c

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,18 @@ uint32_t ne16_pad_ptr(uint32_t ptr, const uint32_t width, uint32_t width_stride,
113113
return ptr - (padding_top * width + padding_left) * width_stride;
114114
}
115115

116-
void ne16_task_set_ptrs(ne16_task_t *task, uint32_t input_ptr, uint32_t w_in,
117-
uint32_t w_in_stride, uint8_t padding_top,
118-
uint8_t padding_left, uint32_t output_ptr,
119-
uint32_t weights_ptr, uint32_t scale_ptr,
120-
uint32_t shift_ptr, uint32_t bias_ptr) {
116+
void ne16_task_set_ptrs_conv(ne16_task_t *task, uint32_t input_ptr,
117+
uint32_t w_in, uint32_t w_in_stride,
118+
uint8_t padding_top, uint8_t padding_left,
119+
uint32_t output_ptr, uint32_t weights_ptr) {
121120
task->data.infeat_ptr =
122121
ne16_pad_ptr(input_ptr, w_in, w_in_stride, padding_top, padding_left);
123122
task->data.outfeat_ptr = output_ptr;
124123
task->data.weights_ptr = weights_ptr;
124+
}
125+
126+
void ne16_task_set_ptrs_norm_quant(ne16_task_t *task, uint32_t scale_ptr,
127+
uint32_t shift_ptr, uint32_t bias_ptr) {
125128
task->data.scale_ptr = scale_ptr;
126129
task->data.scale_shift_ptr = shift_ptr;
127130
task->data.scale_bias_ptr = bias_ptr;

ne16/hal/ne16_task.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,12 @@ uint32_t ne16_get_tile_padding(uint32_t padding, uint32_t i_height,
133133
uint32_t ne16_pad_ptr(uint32_t ptr, const uint32_t width,
134134
const uint32_t width_stride, const uint8_t padding_top,
135135
const uint8_t padding_left);
136-
void ne16_task_set_ptrs(ne16_task_t *task, uint32_t input_ptr, uint32_t w_in,
137-
uint32_t w_in_stride, uint8_t padding_top,
138-
uint8_t padding_left, uint32_t output_ptr,
139-
uint32_t weights_ptr, uint32_t scale_ptr,
140-
uint32_t shift_ptr, uint32_t bias_ptr);
136+
void ne16_task_set_ptrs_conv(ne16_task_t *task, uint32_t input_ptr,
137+
uint32_t w_in, uint32_t w_in_stride,
138+
uint8_t padding_top, uint8_t padding_left,
139+
uint32_t output_ptr, uint32_t weights_ptr);
140+
void ne16_task_set_ptrs_norm_quant(ne16_task_t *task, uint32_t scale_ptr,
141+
uint32_t shift_ptr, uint32_t bias_ptr);
141142
/** ne16_task_set_strides
142143
*
143144
* All the strides variables are strides between elements alongside that

neureka/hal/neureka_task.c

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,18 @@ uint32_t neureka_pad_ptr(uint32_t ptr, const uint32_t width,
126126
return ptr - (padding_top * width + padding_left) * width_stride;
127127
}
128128

129-
void neureka_task_set_ptrs(neureka_task_t *task, uint32_t input_ptr,
130-
uint32_t w_in, uint32_t w_in_stride,
131-
uint8_t padding_top, uint8_t padding_left,
132-
uint32_t output_ptr, uint32_t weights_ptr,
133-
uint32_t scale_ptr, uint32_t shift_ptr,
134-
uint32_t bias_ptr) {
129+
void neureka_task_set_ptrs_conv(neureka_task_t *task, uint32_t input_ptr,
130+
uint32_t w_in, uint32_t w_in_stride,
131+
uint8_t padding_top, uint8_t padding_left,
132+
uint32_t output_ptr, uint32_t weights_ptr) {
135133
task->data.infeat_ptr =
136134
neureka_pad_ptr(input_ptr, w_in, w_in_stride, padding_top, padding_left);
137135
task->data.outfeat_ptr = output_ptr;
138136
task->data.weights_ptr = weights_ptr;
137+
}
138+
139+
void neureka_task_set_ptrs_norm_quant(neureka_task_t *task, uint32_t scale_ptr,
140+
uint32_t shift_ptr, uint32_t bias_ptr) {
139141
task->data.scale_ptr = scale_ptr;
140142
task->data.scale_shift_ptr = shift_ptr;
141143
task->data.scale_bias_ptr = bias_ptr;

neureka/hal/neureka_task.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,12 @@ uint32_t neureka_get_tile_padding(uint32_t padding, uint32_t i_height,
142142
uint32_t neureka_pad_ptr(uint32_t ptr, const uint32_t width,
143143
const uint32_t width_stride, const uint8_t padding_top,
144144
const uint8_t padding_left);
145-
void neureka_task_set_ptrs(neureka_task_t *task, uint32_t input_ptr,
146-
uint32_t w_in, uint32_t w_in_stride,
147-
uint8_t padding_top, uint8_t padding_left,
148-
uint32_t output_ptr, uint32_t weights_ptr,
149-
uint32_t scale_ptr, uint32_t shift_ptr,
150-
uint32_t bias_ptr);
145+
void neureka_task_set_ptrs_conv(neureka_task_t *task, uint32_t input_ptr,
146+
uint32_t w_in, uint32_t w_in_stride,
147+
uint8_t padding_top, uint8_t padding_left,
148+
uint32_t output_ptr, uint32_t weights_ptr);
149+
void neureka_task_set_ptrs_norm_quant(neureka_task_t *task, uint32_t scale_ptr,
150+
uint32_t shift_ptr, uint32_t bias_ptr);
151151
/** neureka_task_set_strides
152152
*
153153
* All the strides variables are strides between elements alongside that

test/NeurekaMemoryLayout.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import numpy as np
2121
import numpy.typing as npt
2222

23-
from TestClasses import IntegerType
24-
2523

2624
class NeurekaMemoryLayout:
2725
_WEIGHT_BANDWIDTH = 256

test/app/src/nnx_layer.c

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@
3131

3232
typedef ne16_norm_mode_e nnx_norm_mode_e;
3333
typedef ne16_quant_t nnx_quant_t;
34+
typedef ne16_quant_function_e nnx_quant_function_e;
3435
typedef ne16_norm_t nnx_norm_t;
3536
typedef ne16_task_t nnx_task_t;
3637
typedef ne16_dev_t nnx_dev_t;
3738
typedef ne16_pulp_conf_t nnx_bsp_conf_t;
39+
typedef ne16_task_flag_e nnx_task_flag_e;
3840

3941
#define nnxTaskFlagTrue ne16TaskFlagTrue
4042
#define nnxTaskFlagFalse ne16TaskFlagFalse
@@ -46,7 +48,8 @@ typedef ne16_pulp_conf_t nnx_bsp_conf_t;
4648
#define nnx_task_set_weight_offset ne16_task_set_weight_offset
4749
#define nnx_task_set_dims ne16_task_set_dims
4850
#define nnx_task_set_dims_stride2x2 ne16_task_set_dims_stride2x2
49-
#define nnx_task_set_ptrs ne16_task_set_ptrs
51+
#define nnx_task_set_ptrs_conv ne16_task_set_ptrs_conv
52+
#define nnx_task_set_ptrs_norm_quant ne16_task_set_ptrs_norm_quant
5053

5154
#define NNX_GVSOC_LOG_LEVEL NE16_GVSOC_LOG_LEVEL_ALL
5255
#define NNX_GVSOC_LOG_FORMAT NE16_GVSOC_LOG_FORMAT_HEXADECIMAL
@@ -72,10 +75,12 @@ typedef ne16_pulp_conf_t nnx_bsp_conf_t;
7275

7376
typedef neureka_norm_mode_e nnx_norm_mode_e;
7477
typedef neureka_quant_t nnx_quant_t;
78+
typedef neureka_quant_function_e nnx_quant_function_e;
7579
typedef neureka_norm_t nnx_norm_t;
7680
typedef neureka_task_t nnx_task_t;
7781
typedef neureka_dev_t nnx_dev_t;
7882
typedef neureka_siracusa_conf_t nnx_bsp_conf_t;
83+
typedef neureka_task_flag_e nnx_task_flag_e;
7984

8085
#define nnxTaskFlagTrue neurekaTaskFlagTrue
8186
#define nnxTaskFlagFalse neurekaTaskFlagFalse
@@ -86,7 +91,8 @@ typedef neureka_siracusa_conf_t nnx_bsp_conf_t;
8691
#define nnx_task_set_norm_quant neureka_task_set_norm_quant
8792
#define nnx_task_set_weight_offset neureka_task_set_weight_offset
8893
#define nnx_task_set_dims neureka_task_set_dims
89-
#define nnx_task_set_ptrs neureka_task_set_ptrs
94+
#define nnx_task_set_ptrs_conv neureka_task_set_ptrs_conv
95+
#define nnx_task_set_ptrs_norm_quant neureka_task_set_ptrs_norm_quant
9096

9197
#define NNX_GVSOC_LOG_LEVEL NEUREKA_GVSOC_LOG_LEVEL_ALL
9298
#define NNX_GVSOC_LOG_FORMAT NEUREKA_GVSOC_LOG_FORMAT_HEXADECIMAL
@@ -120,24 +126,6 @@ static void task_prepare(nnx_task_t *task) {
120126
#endif
121127
nnx_task_set_bits(task, INPUT_BITS, OUTPUT_BITS, WEIGHT_BITS);
122128

123-
#if HAS_NORM_QUANT == 1
124-
#if SCALE_BITS == 8
125-
const nnx_norm_mode_e normMode = normMode8Bit;
126-
#elif SCALE_BITS == 32
127-
const nnx_norm_mode_e normMode = normMode32Bit;
128-
#endif
129-
130-
nnx_task_set_norm_quant(
131-
task,
132-
(nnx_quant_t){.shift_amount = OUTSHIFT,
133-
.function =
134-
HAS_RELU ? quantFunctionRelu : quantFunctionIdentity,
135-
.flag_rounding = nnxTaskFlagFalse},
136-
(nnx_norm_t){.mode = normMode,
137-
.flag_bias = HAS_BIAS ? nnxTaskFlagTrue : nnxTaskFlagFalse,
138-
.flag_shift = nnxTaskFlagFalse});
139-
#endif // HAS_NORM_QUANT
140-
141129
nnx_task_set_weight_offset(task, weightOffsetModeLayerWise, WEIGHT_OFFSET);
142130

143131
#ifdef NNX_NEUREKA
@@ -171,20 +159,34 @@ static void task_prepare(nnx_task_t *task) {
171159
PADDING_RIGHT);
172160
#endif
173161

174-
nnx_task_set_ptrs(task, (uint32_t)input, INPUT_WIDTH, w_in_stride,
175-
PADDING_TOP, PADDING_LEFT, (uint32_t)output,
176-
(uint32_t)weight,
162+
nnx_task_set_ptrs_conv(task, (uint32_t)input, INPUT_WIDTH, w_in_stride,
163+
PADDING_TOP, PADDING_LEFT, (uint32_t)output,
164+
(uint32_t)weight);
165+
177166
#if HAS_NORM_QUANT == 1
178-
(uint32_t)scale, NULL,
179-
#if HAS_BIAS == 1
180-
(uint32_t)bias
181-
#else
182-
NULL
183-
#endif
184-
#else
185-
NULL, NULL, NULL
167+
#if SCALE_BITS == 8
168+
const nnx_norm_mode_e normMode = normMode8Bit;
169+
#elif SCALE_BITS == 32
170+
const nnx_norm_mode_e normMode = normMode32Bit;
186171
#endif
187-
);
172+
173+
const nnx_task_flag_e flag_bias =
174+
HAS_BIAS ? nnxTaskFlagTrue : nnxTaskFlagFalse;
175+
const uint32_t bias_ptr = (uint32_t)(HAS_BIAS ? bias : NULL);
176+
177+
nnx_quant_function_e quant_function =
178+
HAS_RELU ? quantFunctionRelu : quantFunctionIdentity;
179+
180+
nnx_task_set_norm_quant(task,
181+
(nnx_quant_t){.shift_amount = OUTSHIFT,
182+
.function = quant_function,
183+
.flag_rounding = nnxTaskFlagFalse},
184+
(nnx_norm_t){.mode = normMode,
185+
.flag_bias = flag_bias,
186+
.flag_shift = nnxTaskFlagFalse});
187+
188+
nnx_task_set_ptrs_norm_quant(task, (uint32_t)scale, NULL, bias_ptr);
189+
#endif // HAS_NORM_QUANT
188190
}
189191

190192
static void task_execute(nnx_task_t *task) {

0 commit comments

Comments
 (0)