3131
3232typedef ne16_norm_mode_e nnx_norm_mode_e ;
3333typedef ne16_quant_t nnx_quant_t ;
34+ typedef ne16_quant_function_e nnx_quant_function_e ;
3435typedef ne16_norm_t nnx_norm_t ;
3536typedef ne16_task_t nnx_task_t ;
3637typedef ne16_dev_t nnx_dev_t ;
3738typedef ne16_pulp_conf_t nnx_bsp_conf_t ;
39+ typedef ne16_task_flag_e nnx_task_flag_e ;
3840
3941#define nnxTaskFlagTrue ne16TaskFlagTrue
4042#define nnxTaskFlagFalse ne16TaskFlagFalse
@@ -46,7 +48,8 @@ typedef ne16_pulp_conf_t nnx_bsp_conf_t;
4648#define nnx_task_set_weight_offset ne16_task_set_weight_offset
4749#define nnx_task_set_dims ne16_task_set_dims
4850#define nnx_task_set_dims_stride2x2 ne16_task_set_dims_stride2x2
49- #define nnx_task_set_ptrs ne16_task_set_ptrs
51+ #define nnx_task_set_ptrs_conv ne16_task_set_ptrs_conv
52+ #define nnx_task_set_ptrs_norm_quant ne16_task_set_ptrs_norm_quant
5053
5154#define NNX_GVSOC_LOG_LEVEL NE16_GVSOC_LOG_LEVEL_ALL
5255#define NNX_GVSOC_LOG_FORMAT NE16_GVSOC_LOG_FORMAT_HEXADECIMAL
@@ -72,10 +75,12 @@ typedef ne16_pulp_conf_t nnx_bsp_conf_t;
7275
7376typedef neureka_norm_mode_e nnx_norm_mode_e ;
7477typedef neureka_quant_t nnx_quant_t ;
78+ typedef neureka_quant_function_e nnx_quant_function_e ;
7579typedef neureka_norm_t nnx_norm_t ;
7680typedef neureka_task_t nnx_task_t ;
7781typedef neureka_dev_t nnx_dev_t ;
7882typedef neureka_siracusa_conf_t nnx_bsp_conf_t ;
83+ typedef neureka_task_flag_e nnx_task_flag_e ;
7984
8085#define nnxTaskFlagTrue neurekaTaskFlagTrue
8186#define nnxTaskFlagFalse neurekaTaskFlagFalse
@@ -86,7 +91,8 @@ typedef neureka_siracusa_conf_t nnx_bsp_conf_t;
8691#define nnx_task_set_norm_quant neureka_task_set_norm_quant
8792#define nnx_task_set_weight_offset neureka_task_set_weight_offset
8893#define nnx_task_set_dims neureka_task_set_dims
89- #define nnx_task_set_ptrs neureka_task_set_ptrs
94+ #define nnx_task_set_ptrs_conv neureka_task_set_ptrs_conv
95+ #define nnx_task_set_ptrs_norm_quant neureka_task_set_ptrs_norm_quant
9096
9197#define NNX_GVSOC_LOG_LEVEL NEUREKA_GVSOC_LOG_LEVEL_ALL
9298#define NNX_GVSOC_LOG_FORMAT NEUREKA_GVSOC_LOG_FORMAT_HEXADECIMAL
@@ -120,24 +126,6 @@ static void task_prepare(nnx_task_t *task) {
120126#endif
121127 nnx_task_set_bits (task , INPUT_BITS , OUTPUT_BITS , WEIGHT_BITS );
122128
123- #if HAS_NORM_QUANT == 1
124- #if SCALE_BITS == 8
125- const nnx_norm_mode_e normMode = normMode8Bit ;
126- #elif SCALE_BITS == 32
127- const nnx_norm_mode_e normMode = normMode32Bit ;
128- #endif
129-
130- nnx_task_set_norm_quant (
131- task ,
132- (nnx_quant_t ){.shift_amount = OUTSHIFT ,
133- .function =
134- HAS_RELU ? quantFunctionRelu : quantFunctionIdentity ,
135- .flag_rounding = nnxTaskFlagFalse },
136- (nnx_norm_t ){.mode = normMode ,
137- .flag_bias = HAS_BIAS ? nnxTaskFlagTrue : nnxTaskFlagFalse ,
138- .flag_shift = nnxTaskFlagFalse });
139- #endif // HAS_NORM_QUANT
140-
141129 nnx_task_set_weight_offset (task , weightOffsetModeLayerWise , WEIGHT_OFFSET );
142130
143131#ifdef NNX_NEUREKA
@@ -171,20 +159,34 @@ static void task_prepare(nnx_task_t *task) {
171159 PADDING_RIGHT );
172160#endif
173161
174- nnx_task_set_ptrs (task , (uint32_t )input , INPUT_WIDTH , w_in_stride ,
175- PADDING_TOP , PADDING_LEFT , (uint32_t )output ,
176- (uint32_t )weight ,
162+ nnx_task_set_ptrs_conv (task , (uint32_t )input , INPUT_WIDTH , w_in_stride ,
163+ PADDING_TOP , PADDING_LEFT , (uint32_t )output ,
164+ (uint32_t )weight );
165+
177166#if HAS_NORM_QUANT == 1
178- (uint32_t )scale , NULL ,
179- #if HAS_BIAS == 1
180- (uint32_t )bias
181- #else
182- NULL
183- #endif
184- #else
185- NULL , NULL , NULL
167+ #if SCALE_BITS == 8
168+ const nnx_norm_mode_e normMode = normMode8Bit ;
169+ #elif SCALE_BITS == 32
170+ const nnx_norm_mode_e normMode = normMode32Bit ;
186171#endif
187- );
172+
173+ const nnx_task_flag_e flag_bias =
174+ HAS_BIAS ? nnxTaskFlagTrue : nnxTaskFlagFalse ;
175+ const uint32_t bias_ptr = (uint32_t )(HAS_BIAS ? bias : NULL );
176+
177+ nnx_quant_function_e quant_function =
178+ HAS_RELU ? quantFunctionRelu : quantFunctionIdentity ;
179+
180+ nnx_task_set_norm_quant (task ,
181+ (nnx_quant_t ){.shift_amount = OUTSHIFT ,
182+ .function = quant_function ,
183+ .flag_rounding = nnxTaskFlagFalse },
184+ (nnx_norm_t ){.mode = normMode ,
185+ .flag_bias = flag_bias ,
186+ .flag_shift = nnxTaskFlagFalse });
187+
188+ nnx_task_set_ptrs_norm_quant (task , (uint32_t )scale , NULL , bias_ptr );
189+ #endif // HAS_NORM_QUANT
188190}
189191
190192static void task_execute (nnx_task_t * task ) {
0 commit comments