Skip to content

Commit 5094f3c

Browse files
Hakim7267 authored and JaccovG committed
add out el_type and el_params checks for gru and lstm
1 parent c53549e commit 5094f3c

File tree

1 file changed

+25
-13
lines changed

1 file changed

+25
-13
lines changed

lib/src/private/src/mli_check.cc

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ bool mli_chk_containers_not_overlapped(const mli_data_container & data1, const m
251251
}
252252

253253
bool mli_chk_tensors_not_overlapped(const mli_tensor * t1, const mli_tensor * t2) {
254-
return mli_chk_containers_not_overlapped(t1->data, t2->data);
254+
return mli_chk_containers_not_overlapped(t1->data, t2->data);
255255
}
256256

257257
mli_status mli_chk_bias_frac_fx(const mli_tensor * in, const mli_tensor * weights, const mli_tensor * bias) {
@@ -2123,8 +2123,8 @@ mli_status mli_chk_rnn_dense_fx16 (
21232123
if (fail) return MLI_STATUS_TYPE_MISMATCH;
21242124

21252125
for (int idx = 0; idx < inputs_num; idx++) {
2126-
ret = MLI_CHECK_STATUS(mli_chk_out_frac_fx(in[idx], weights[idx], out), __func__);
2127-
if (ret != MLI_STATUS_OK) return ret;
2126+
ret = MLI_CHECK_STATUS(mli_chk_out_frac_fx(in[idx], weights[idx], out), __func__);
2127+
if (ret != MLI_STATUS_OK) return ret;
21282128
}
21292129

21302130
ret = MLI_CHECK_STATUS(mli_chk_bias_frac_fx(in[0], weights[0], bias), __func__);
@@ -2153,8 +2153,8 @@ mli_status mli_chk_rnn_dense_fx16_fx8_fx8 (
21532153
if (fail) return MLI_STATUS_TYPE_MISMATCH;
21542154

21552155
for (int idx = 0; idx < inputs_num; idx++) {
2156-
ret = MLI_CHECK_STATUS(mli_chk_out_frac_fx(in[idx], weights[idx], out), __func__);
2157-
if (ret != MLI_STATUS_OK) return ret;
2156+
ret = MLI_CHECK_STATUS(mli_chk_out_frac_fx(in[idx], weights[idx], out), __func__);
2157+
if (ret != MLI_STATUS_OK) return ret;
21582158
}
21592159

21602160
ret = MLI_CHECK_STATUS(mli_chk_bias_frac_fx(in[0], weights[0], bias), __func__);
@@ -2325,6 +2325,7 @@ mli_status mli_chk_lstm_cell_fx16 (
23252325
if (ret != MLI_STATUS_OK)
23262326
return ret;
23272327
if (MLI_CHECK(in->el_type == MLI_EL_FX_16, "Wrong input tensor type") ||
2328+
MLI_CHECK(out->el_type == MLI_EL_FX_16, "Wrong output tensor type") ||
23282329
MLI_CHECK(prev_out->el_type == MLI_EL_FX_16, "Wrong prev_out tensor type") ||
23292330
MLI_CHECK(weights_in->el_type == MLI_EL_FX_16, "Wrong weights_in tensor type") ||
23302331
MLI_CHECK(weights_out->el_type == MLI_EL_FX_16, "Wrong weights_out tensor type") ||
@@ -2354,6 +2355,7 @@ mli_status mli_chk_lstm_cell_fx16_fx8_fx8 (
23542355
if (ret != MLI_STATUS_OK)
23552356
return ret;
23562357
if (MLI_CHECK(in->el_type == MLI_EL_FX_16, "Wrong input tensor type") ||
2358+
MLI_CHECK(out->el_type == MLI_EL_FX_16, "Wrong output tensor type") ||
23572359
MLI_CHECK(prev_out->el_type == MLI_EL_FX_16, "Wrong prev_out tensor type") ||
23582360
MLI_CHECK(weights_in->el_type == MLI_EL_FX_8, "Wrong weights tensor type") ||
23592361
MLI_CHECK(weights_out->el_type == MLI_EL_FX_8, "Wrong weights tensor type") ||
@@ -2386,14 +2388,16 @@ mli_status mli_chk_lstm_cell_sa8_sa8_sa32(
23862388
bool fail = false;
23872389

23882390
if (MLI_CHECK(in->el_type == MLI_EL_SA_8, "Wrong input tensor type") ||
2389-
MLI_CHECK(prev_out->el_type == MLI_EL_SA_8, "Wrong prev_out tensor type") ||
2390-
MLI_CHECK(weights_in->el_type == MLI_EL_SA_8, "Wrong weights tensor type") ||
2391-
MLI_CHECK(weights_out->el_type == MLI_EL_SA_8, "Wrong weights tensor type") ||
2392-
MLI_CHECK(cell->el_type == MLI_EL_SA_8, "Wrong cell tensor type") ||
2393-
MLI_CHECK(bias->el_type == MLI_EL_SA_32, "Wrong bias tensor type"))
2391+
MLI_CHECK(out->el_type == MLI_EL_SA_8, "Wrong output tensor type") ||
2392+
MLI_CHECK(prev_out->el_type == MLI_EL_SA_8, "Wrong prev_out tensor type") ||
2393+
MLI_CHECK(weights_in->el_type == MLI_EL_SA_8, "Wrong weights tensor type") ||
2394+
MLI_CHECK(weights_out->el_type == MLI_EL_SA_8, "Wrong weights tensor type") ||
2395+
MLI_CHECK(cell->el_type == MLI_EL_SA_8, "Wrong cell tensor type") ||
2396+
MLI_CHECK(bias->el_type == MLI_EL_SA_32, "Wrong bias tensor type"))
23942397
return MLI_STATUS_TYPE_MISMATCH;
23952398

23962399
fail |= MLI_CHECK(in->el_params.sa.dim < 0, "Input tensor: Per-tensor quantization is expected");
2400+
fail |= MLI_CHECK(out->el_params.sa.dim < 0, "Output tensor: Per-tensor quantization is expected");
23972401
fail |= MLI_CHECK(prev_out->el_params.sa.dim < 0, "Prev out tensor: Per-tensor quantization is expected");
23982402
fail |= MLI_CHECK(cell->el_params.sa.dim < 0, "Cell tensor: Per-tensor quantization is expected");
23992403

@@ -2409,6 +2413,8 @@ mli_status mli_chk_lstm_cell_sa8_sa8_sa32(
24092413
if (fail) return MLI_STATUS_INCOMPATEBLE_TENSORS;
24102414
ret = MLI_CHECK_STATUS(mli_chk_tensor_quant_params(in, kZeroPointBitsByteRange), __func__);
24112415
if (ret != MLI_STATUS_OK) return ret;
2416+
ret = MLI_CHECK_STATUS(mli_chk_tensor_quant_params(out, kZeroPointBitsByteRange), __func__);
2417+
if (ret != MLI_STATUS_OK) return ret;
24122418
ret = MLI_CHECK_STATUS(mli_chk_tensor_quant_params(prev_out, kZeroPointBitsByteRange), __func__);
24132419
if (ret != MLI_STATUS_OK) return ret;
24142420
ret = MLI_CHECK_STATUS(mli_chk_tensor_quant_params(cell, kZeroPointBitsByteRange), __func__);
@@ -2532,6 +2538,7 @@ mli_status mli_chk_gru_cell_fx16 (
25322538
if (ret != MLI_STATUS_OK)
25332539
return ret;
25342540
if (MLI_CHECK(in->el_type == MLI_EL_FX_16, "Wrong input tensor type") ||
2541+
MLI_CHECK(out->el_type == MLI_EL_FX_16, "Wrong output tensor type") ||
25352542
MLI_CHECK(prev_out->el_type == MLI_EL_FX_16, "Wrong prev_out tensor type") ||
25362543
MLI_CHECK(weights_in->el_type == MLI_EL_FX_16, "Wrong weights_in tensor type") ||
25372544
MLI_CHECK(weights_out->el_type == MLI_EL_FX_16, "Wrong weights_out tensor type") ||
@@ -2559,6 +2566,7 @@ mli_status mli_chk_gru_cell_fx16_fx8_fx8 (
25592566
if (ret != MLI_STATUS_OK)
25602567
return ret;
25612568
if (MLI_CHECK(in->el_type == MLI_EL_FX_16, "Wrong input tensor type") ||
2569+
MLI_CHECK(out->el_type == MLI_EL_FX_16, "Wrong output tensor type") ||
25622570
MLI_CHECK(prev_out->el_type == MLI_EL_FX_16, "Wrong prev_out tensor type") ||
25632571
MLI_CHECK(weights_in->el_type == MLI_EL_FX_8, "Wrong weights tensor type") ||
25642572
MLI_CHECK(weights_out->el_type == MLI_EL_FX_8, "Wrong weights tensor type") ||
@@ -2590,13 +2598,15 @@ mli_status mli_chk_gru_cell_sa8_sa8_sa32(
25902598
bool fail = false;
25912599

25922600
if (MLI_CHECK(in->el_type == MLI_EL_SA_8, "Wrong input tensor type") ||
2601+
MLI_CHECK(out->el_type == MLI_EL_SA_8, "Wrong output tensor type") ||
25932602
MLI_CHECK(prev_out->el_type == MLI_EL_SA_8, "Wrong prev_out tensor type") ||
25942603
MLI_CHECK(weights_in->el_type == MLI_EL_SA_8, "Wrong weights tensor type") ||
25952604
MLI_CHECK(weights_out->el_type == MLI_EL_SA_8, "Wrong weights tensor type") ||
25962605
MLI_CHECK(bias->el_type == MLI_EL_SA_32, "Wrong bias tensor type"))
25972606
return MLI_STATUS_TYPE_MISMATCH;
25982607

25992608
fail |= MLI_CHECK(in->el_params.sa.dim < 0, "Input tensor: Per-tensor quantization is expected");
2609+
fail |= MLI_CHECK(out->el_params.sa.dim < 0, "Output tensor: Per-tensor quantization is expected");
26002610
fail |= MLI_CHECK(prev_out->el_params.sa.dim < 0, "Prev out tensor: Per-tensor quantization is expected");
26012611

26022612
if (weights_in->el_params.sa.dim < 0) {
@@ -2611,6 +2621,8 @@ mli_status mli_chk_gru_cell_sa8_sa8_sa32(
26112621
if (fail) return MLI_STATUS_INCOMPATEBLE_TENSORS;
26122622
ret = MLI_CHECK_STATUS(mli_chk_tensor_quant_params(in, kZeroPointBitsByteRange), __func__);
26132623
if (ret != MLI_STATUS_OK) return ret;
2624+
ret = MLI_CHECK_STATUS(mli_chk_tensor_quant_params(out, kZeroPointBitsByteRange), __func__);
2625+
if (ret != MLI_STATUS_OK) return ret;
26142626
ret = MLI_CHECK_STATUS(mli_chk_tensor_quant_params(prev_out, kZeroPointBitsByteRange), __func__);
26152627
if (ret != MLI_STATUS_OK) return ret;
26162628
ret = MLI_CHECK_STATUS(mli_chk_tensor_quant_params(weights_in, kZeroPointBitsZero), __func__);
@@ -2906,7 +2918,7 @@ mli_status mli_chk_argmax(const mli_tensor *in, const mli_argmax_cfg *cfg, mli_t
29062918
return MLI_STATUS_TYPE_MISMATCH;
29072919

29082920
if (MLI_CHECK(check_layout_is_contiguous(out), "Memory Layout of out tensor must be contiguous"))
2909-
return MLI_STATUS_INCOMPATEBLE_TENSORS;
2921+
return MLI_STATUS_INCOMPATEBLE_TENSORS;
29102922

29112923
// Check if cfg is valid
29122924
if (MLI_CHECK(cfg != NULL, "Bad cfg pointer")) return MLI_STATUS_BAD_FUNC_CFG;
@@ -2922,9 +2934,9 @@ mli_status mli_chk_argmax(const mli_tensor *in, const mli_argmax_cfg *cfg, mli_t
29222934
if (cfg->axis >= 0) {
29232935
dim_size = in->shape[cfg->axis];
29242936
uint32_t slice_size = in_size / dim_size;
2925-
fail |= MLI_CHECK(cfg->topk <= (int)slice_size, "For axis >= 0 topk must be less or equal to the total number of elements in s single slice across the specified axis");
2937+
fail |= MLI_CHECK(cfg->topk <= (int)slice_size, "For axis >= 0 topk must be less or equal to the total number of elements in s single slice across the specified axis");
29262938
} else {
2927-
fail |= MLI_CHECK(cfg->topk <= (int)in_size, "For axis < 0 topk must be less or equal to the total number of elements in in");
2939+
fail |= MLI_CHECK(cfg->topk <= (int)in_size, "For axis < 0 topk must be less or equal to the total number of elements in in");
29282940
}
29292941
if (fail) return MLI_STATUS_BAD_FUNC_CFG;
29302942

0 commit comments

Comments
 (0)