Skip to content

Commit 632c708

Browse files
committed
update argmax for sa32 output
1 parent 4b3f6eb commit 632c708

File tree

3 files changed

+24
-83
lines changed

3 files changed

+24
-83
lines changed

lib/src/kernels/diverse/impl/mli_krn_argmax_ref.h

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,10 @@ template <typename in_T>
197197
MLI_FORCE_INLINE void argmax_prepare_and_run(const mli_tensor *in, const mli_argmax_cfg *cfg, mli_tensor *out) {
198198

199199
/* Setting output tensor parameters based on user mli_argmax_cfg */
200-
if (out->el_type == MLI_EL_FX_8 || out->el_type == MLI_EL_FX_16) {
201-
out->el_params.fx.frac_bits = 0;
202-
}
203-
if (out->el_type == MLI_EL_SA_8 || out->el_type == MLI_EL_SA_32) {
204-
out->el_params.sa.scale.mem.i16 = 1;
205-
out->el_params.sa.zero_point.mem.i16 = 0;
206-
out->el_params.sa.scale_frac_bits.mem.i8 = 0;
207-
}
200+
out->el_params.sa.scale.mem.i16 = 1;
201+
out->el_params.sa.zero_point.mem.i16 = 0;
202+
out->el_params.sa.scale_frac_bits.mem.i8 = 0;
203+
out->el_type = MLI_EL_SA_32;
208204

209205
uint32_t dim_size = 1;
210206
if (cfg->axis >= 0)
@@ -214,13 +210,7 @@ MLI_FORCE_INLINE void argmax_prepare_and_run(const mli_tensor *in, const mli_arg
214210
out->rank = 2;
215211

216212
/* Running main argmax function */
217-
if (out->el_type == MLI_EL_FX_8 || out->el_type == MLI_EL_SA_8) {
218-
argmax<in_T, int8_t>(in, cfg->axis, cfg->topk, out);
219-
} else if (out->el_type == MLI_EL_FX_16) {
220-
argmax<in_T, int16_t>(in, cfg->axis, cfg->topk, out);
221-
} else if (out->el_type == MLI_EL_SA_32) {
222-
argmax<in_T, int32_t>(in, cfg->axis, cfg->topk, out);
223-
}
213+
argmax<in_T, int32_t>(in, cfg->axis, cfg->topk, out);
224214
}
225215

226216
#pragma MLI_CODE_SECTION_END()

lib/src/private/src/mli_check.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,16 +2337,10 @@ mli_status mli_chk_argmax(const mli_tensor *in, const mli_argmax_cfg *cfg, mli_t
23372337
if (MLI_CHECK(check_inner_most_dimension_is_one(in), "mem_stride of the innermost dimension for input tensor must be not more than 1."))
23382338
return MLI_STATUS_INCOMPATEBLE_TENSORS;
23392339

2340-
if (MLI_CHECK(out->el_type == MLI_EL_FX_8 || out->el_type == MLI_EL_FX_16 ||
2341-
out->el_type == MLI_EL_SA_8 || out->el_type == MLI_EL_SA_32, "Output el_type is invalid")) return MLI_STATUS_TYPE_MISMATCH;
2342-
2343-
if (MLI_CHECK(mli_prv_count_elem_num(in) <= mli_hlp_tensor_element_positive_limit(out),
2344-
"Chosen output type must be able to keep maximum index of element in flatten input tensor.")) return MLI_STATUS_TYPE_MISMATCH;
2345-
23462340
uint32_t dim_size = 1;
23472341
if (cfg->axis >= 0)
23482342
dim_size = in->shape[cfg->axis];
2349-
if (MLI_CHECK(out->data.capacity == cfg->topk * dim_size * mli_hlp_tensor_element_size(out), "Insufficient output buffer."))
2343+
if (MLI_CHECK(out->data.capacity == cfg->topk * dim_size * sizeof(int32_t), "Insufficient output buffer."))
23502344
return MLI_STATUS_NOT_ENGH_MEM;
23512345

23522346
if (in->el_type == MLI_EL_SA_8 || in->el_type == MLI_EL_SA_32)

user_tests/tests/mli_krn_argmax/tests_mli_krn_argmax.cc

Lines changed: 18 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -47,29 +47,10 @@ struct argmax_test_operands {
4747
// Checksums of test tensors for various MLI calculation modes.
4848
// When the developer has finished implementing the kernel and considers it correct, one needs to populate
4949
// proper checksums for tests in order to highlight any change which affects results.
50-
#if defined(CRC_RM_UP)
51-
// Shared CRC Results
52-
53-
const crc32_calc test_1_chksum_fx16_fx8 { 0x23A39F6B }, test_1_chksum_fx16_sa8 { 0x23A39F6B }, test_1_chksum_fx16_fx16 { 0x0841F4C2 },
54-
test_1_chksum_fx16_sa32 { 0x02C59977 }, test_1_chksum_sa8_fx8 { 0x6341BBF5 }, test_1_chksum_sa8_sa8 { 0x6341BBF5 },
55-
test_1_chksum_sa8_fx16 { 0x1FB6A8A5 }, test_1_chksum_sa8_sa32 { 0x4AA46E3F }, test_2_chksum_fx16_fx16 { 0xAD273965 },
56-
test_2_chksum_fx16_sa32 { 0xFF8FA926 }, test_3_chksum_fx16_fx8 { 0x88B589EE }, test_3_chksum_fx16_sa8 { 0x88B589EE },
57-
test_3_chksum_sa8_fx8 { 0x2D84A0F9 }, test_3_chksum_sa8_sa8 { 0x2D84A0F9 };
58-
#elif defined(CRC_RM_CONVERGENT)
59-
60-
const crc32_calc test_1_chksum_fx16_fx8 { 0x23A39F6B }, test_1_chksum_fx16_sa8 { 0x23A39F6B }, test_1_chksum_fx16_fx16 { 0x0841F4C2 },
61-
test_1_chksum_fx16_sa32 { 0x02C59977 }, test_1_chksum_sa8_fx8 { 0x6341BBF5 }, test_1_chksum_sa8_sa8 { 0x6341BBF5 },
62-
test_1_chksum_sa8_fx16 { 0x1FB6A8A5 }, test_1_chksum_sa8_sa32 { 0x4AA46E3F }, test_2_chksum_fx16_fx16 { 0xAD273965 },
63-
test_2_chksum_fx16_sa32 { 0xFF8FA926 }, test_3_chksum_fx16_fx8 { 0x88B589EE }, test_3_chksum_fx16_sa8 { 0x88B589EE },
64-
test_3_chksum_sa8_fx8 { 0x2D84A0F9 }, test_3_chksum_sa8_sa8 { 0x2D84A0F9 };
65-
#else // Not defined CRC_*
66-
67-
const crc32_calc test_1_chksum_fx16_fx8, test_1_chksum_fx16_sa8, test_1_chksum_fx16_fx16,
68-
test_1_chksum_fx16_sa32, test_1_chksum_sa8_fx8, test_1_chksum_sa8_sa8,
69-
test_1_chksum_sa8_fx16 , test_1_chksum_sa8_sa32, test_2_chksum_fx16_fx16,
70-
test_2_chksum_fx16_sa32, test_3_chksum_sa32_fx8, test_3_chksum_sa32_fp32,
71-
test_3_chksum_fx8_sa8, test_3_chksum_fx8_sa32;
72-
#endif
50+
51+
const crc32_calc test_1_chksum_sa8_sa32 { 0x4AA46E3F }, test_1_chksum_fx16_sa32 { 0x02C59977 },
52+
test_2_chksum_sa8_sa32 { 0xB1469CC8 }, test_2_chksum_fx16_sa32 { 0xFF8FA926 },
53+
test_3_chksum_sa8_sa32 { 0xCD9EBC45 }, test_3_chksum_fx16_sa32 { 0x5D2A6837 };
7354

7455
const quality_metrics thresholds_test_1_general{ quality_metrics::kPassValueMaxAbsErr, quality_metrics::kPassValueSnr,
7556
/* SNR_DB = */0.0f, quality_metrics::kPassValueQuantErrPerc };
@@ -81,48 +62,24 @@ const quality_metrics thresholds_test_2_3_general{ quality_metrics::kPassValueMa
8162
static const argmax_test_operands tests_list[] = {
8263

8364
// Basic functionality test
84-
{"Test 1 FX16 - FX8 (1 elem)", mli_krn_argmax_fx16,
85-
input_1_fx16, test_1_out_fx8, test_1_cfg,
86-
thresholds_test_1_general, test_1_chksum_fx16_fx8},
87-
{"Test 1 FX16 - SA8 (1 elem)", mli_krn_argmax_fx16,
88-
input_1_fx16, test_1_out_sa8, test_1_cfg,
89-
thresholds_test_1_general, test_1_chksum_fx16_sa8},
90-
{"Test 1 FX16 - FX16 (1 elem)", mli_krn_argmax_fx16,
91-
input_1_fx16, test_1_out_fx16, test_1_cfg,
92-
thresholds_test_1_general, test_1_chksum_fx16_fx16},
93-
{"Test 1 FX16 - SA32 (1 elem)", mli_krn_argmax_fx16,
94-
input_1_fx16, test_1_out_sa32, test_1_cfg,
95-
thresholds_test_1_general, test_1_chksum_fx16_sa32},
96-
{"Test 1 SA8 - FX8 (1 elem)", mli_krn_argmax_sa8,
97-
input_1_sa8, test_1_out_fx8, test_1_cfg,
98-
thresholds_test_1_general, test_1_chksum_sa8_fx8},
99-
{"Test 1 SA8 - SA8 (1 elem)", mli_krn_argmax_sa8,
100-
input_1_sa8, test_1_out_sa8, test_1_cfg,
101-
thresholds_test_1_general, test_1_chksum_sa8_sa8},
102-
{"Test 1 SA8 - FX16 (1 elem)", mli_krn_argmax_sa8,
103-
input_1_sa8, test_1_out_fx16, test_1_cfg,
104-
thresholds_test_1_general, test_1_chksum_sa8_fx16},
105-
{"Test 1 SA8 - SA32 (1 elem)", mli_krn_argmax_sa8,
65+
{"Test 1 SA8 - SA32 (1 elem)", mli_krn_argmax_sa8,
10666
input_1_sa8, test_1_out_sa32, test_1_cfg,
10767
thresholds_test_1_general, test_1_chksum_sa8_sa32},
108-
{"Test 2 FX16 - FX16 (144 elem)", mli_krn_argmax_fx16,
109-
input_2_fx16, test_2_out_fx16, test_2_cfg,
110-
thresholds_test_2_3_general, test_2_chksum_fx16_fx16},
111-
{"Test 2 FX16 - SA32 (144 elem)", mli_krn_argmax_fx16,
68+
{"Test 1 FX16 - SA32 (1 elem)", mli_krn_argmax_fx16,
69+
input_1_fx16, test_1_out_sa32, test_1_cfg,
70+
thresholds_test_1_general, test_1_chksum_fx16_sa32},
71+
{"Test 2 SA8 - SA32 (144 elem)", mli_krn_argmax_sa8,
72+
input_2_sa8, test_2_out_sa32, test_2_cfg,
73+
thresholds_test_2_3_general, test_2_chksum_sa8_sa32},
74+
{"Test 2 FX16 - SA32 (144 elem)", mli_krn_argmax_fx16,
11275
input_2_fx16, test_2_out_sa32, test_2_cfg,
11376
thresholds_test_2_3_general, test_2_chksum_fx16_sa32},
114-
{"Test 3 FX16 - FX8 (axis = 2)", mli_krn_argmax_fx16,
115-
input_3_fx16, test_3_out_fx8, test_3_cfg,
116-
thresholds_test_2_3_general, test_3_chksum_fx16_fx8},
117-
{"Test 3 FX16 - SA8 (axis = 2)", mli_krn_argmax_fx16,
118-
input_3_fx16, test_3_out_sa8, test_3_cfg,
119-
thresholds_test_2_3_general, test_3_chksum_fx16_sa8},
120-
{"Test 3 SA8 - FX8 (axis = 2)", mli_krn_argmax_sa8,
121-
input_3_sa8, test_3_out_fx8, test_3_cfg,
122-
thresholds_test_2_3_general, test_3_chksum_sa8_fx8},
123-
{"Test 3 SA8 - SA8 (axis = 2)", mli_krn_argmax_sa8,
124-
input_3_sa8, test_3_out_sa8, test_3_cfg,
125-
thresholds_test_2_3_general, test_3_chksum_sa8_sa8}
77+
{"Test 3 SA8 - SA32 (axis = 2)", mli_krn_argmax_sa8,
78+
input_3_sa8, test_3_out_sa32, test_3_cfg,
79+
thresholds_test_2_3_general, test_3_chksum_sa8_sa32},
80+
{"Test 3 FX16 - SA32 (axis = 2)", mli_krn_argmax_fx16,
81+
input_3_fx16, test_3_out_sa32, test_3_cfg,
82+
thresholds_test_2_3_general, test_3_chksum_fx16_sa32}
12683
};
12784

12885
constexpr int kMemSize = 10000;

0 commit comments

Comments
 (0)