@@ -42,8 +42,8 @@ typedef enum {
4242
4343typedef struct ne16_norm_t {
4444 ne16_norm_mode_e mode ;
45- int flag_bias ;
46- int flag_shift ;
45+ ne16_task_flag_e flag_bias ;
46+ ne16_task_flag_e flag_shift ;
4747} ne16_norm_t ;
4848
4949typedef enum ne16_quant_mode_e {
@@ -59,9 +59,9 @@ typedef enum ne16_quant_function_e {
5959
6060typedef struct ne16_quant_t {
6161 // Shift amount must be in range 0x00-0x1F
62- unsigned shift_amount ;
62+ uint8_t shift_amount ;
6363 ne16_quant_function_e function ;
64- int flag_rounding ;
64+ ne16_task_flag_e flag_rounding ;
6565} ne16_quant_t ;
6666
6767typedef struct ne16_stride_t {
@@ -133,11 +133,12 @@ uint32_t ne16_get_tile_padding(uint32_t padding, uint32_t i_height,
133133uint32_t ne16_pad_ptr (uint32_t ptr , const uint32_t width ,
134134 const uint32_t width_stride , const uint8_t padding_top ,
135135 const uint8_t padding_left );
136- void ne16_task_set_ptrs (ne16_task_t * task , uint32_t input_ptr , uint32_t w_in ,
137- uint32_t w_in_stride , uint8_t padding_top ,
138- uint8_t padding_left , uint32_t output_ptr ,
139- uint32_t weights_ptr , uint32_t scale_ptr ,
140- uint32_t shift_ptr , uint32_t bias_ptr );
136+ void ne16_task_set_ptrs_conv (ne16_task_t * task , uint32_t input_ptr ,
137+ uint32_t w_in , uint32_t w_in_stride ,
138+ uint8_t padding_top , uint8_t padding_left ,
139+ uint32_t output_ptr , uint32_t weights_ptr );
140+ void ne16_task_set_ptrs_norm_quant (ne16_task_t * task , uint32_t scale_ptr ,
141+ uint32_t shift_ptr , uint32_t bias_ptr );
141142/** ne16_task_set_strides
142143 *
143144 * All the strides variables are strides between elements alongside that
@@ -157,8 +158,8 @@ void ne16_task_set_padding(ne16_task_t *task, const uint8_t top,
157158 const uint8_t bottom , const uint8_t left ,
158159 const uint8_t right , const uint8_t value );
159160void ne16_task_set_mask_filter (ne16_task_t * task , const uint8_t top ,
160- const uint8_t right , const uint8_t bottom ,
161- const uint8_t left );
161+ const uint8_t bottom , const uint8_t left ,
162+ const uint8_t right );
162163/** ne16_task_set_dims
163164 *
164165 * All the strides variables are strides between elements alongside that
@@ -172,8 +173,8 @@ void ne16_task_set_dims(ne16_task_t *task, const uint32_t w_in,
172173 const uint32_t h_out_stride ,
173174 const uint32_t w_out_stride , const uint8_t padding_top ,
174175 const uint8_t padding_bottom ,
175- const uint8_t padding_right ,
176- const uint8_t padding_left );
176+ const uint8_t padding_left ,
177+ const uint8_t padding_right );
177178/** ne16_task_set_dims_stride2x2
178179 *
179180 * All the strides variables are strides between elements alongside that
@@ -186,7 +187,7 @@ void ne16_task_set_dims_stride2x2(
186187 const uint32_t h_out , const uint32_t w_out , const uint32_t k_out ,
187188 const uint32_t h_out_stride , const uint32_t w_out_stride ,
188189 const uint8_t h_ker , const uint8_t w_ker , const uint8_t padding_top ,
189- const uint8_t padding_bottom , const uint8_t padding_right ,
190- const uint8_t padding_left );
190+ const uint8_t padding_bottom , const uint8_t padding_left ,
191+ const uint8_t padding_right );
191192
192193#endif // !__NE16_TASK_H__
0 commit comments