Skip to content

Commit fbe2c05

Browse files
Hakim7267JaccovG
authored andcommitted
fix permute checker function
1 parent c9f065b commit fbe2c05

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

lib/src/private/src/mli_check.cc

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2689,26 +2689,23 @@ mli_status mli_chk_permute_sa8 (const mli_tensor * in, const mli_permute_cfg * c
26892689
if (in->el_params.sa.dim >= 0) {
26902690
bool fail = false;
26912691
if (out->el_params.sa.zero_point.mem.pi16 == in->el_params.sa.zero_point.mem.pi16)
2692-
fail |= MLI_CHECK((out->el_params.sa.zero_point.mem.pi16 == in->el_params.sa.zero_point.mem.pi16) \
2693-
&& (out->el_params.sa.scale.mem.pi16 == in->el_params.sa.scale.mem.pi16) \
2692+
fail |= MLI_CHECK((out->el_params.sa.scale.mem.pi16 == in->el_params.sa.scale.mem.pi16) \
26942693
&& (out->el_params.sa.scale_frac_bits.mem.pi8 == in->el_params.sa.scale_frac_bits.mem.pi8),
26952694
"El_params data for out tensor wasn`t initialized in a consistent way");
2696-
if (out->el_params.sa.zero_point.mem.pi16 != in->el_params.sa.zero_point.mem.pi16)
2697-
fail |= MLI_CHECK((out->el_params.sa.zero_point.mem.pi16 != in->el_params.sa.zero_point.mem.pi16) \
2698-
&& (out->el_params.sa.scale.mem.pi16 != in->el_params.sa.scale.mem.pi16) \
2699-
&& (out->el_params.sa.scale_frac_bits.mem.pi8 != in->el_params.sa.scale_frac_bits.mem.pi8),
2695+
if (out->el_params.sa.zero_point.mem.pi16 != nullptr && out->el_params.sa.zero_point.mem.pi16 != in->el_params.sa.zero_point.mem.pi16)
2696+
fail |= MLI_CHECK((out->el_params.sa.scale.mem.pi16 != nullptr && out->el_params.sa.scale.mem.pi16 != in->el_params.sa.scale.mem.pi16) \
2697+
&& (out->el_params.sa.scale_frac_bits.mem.pi8 != nullptr && out->el_params.sa.scale_frac_bits.mem.pi8 != in->el_params.sa.scale_frac_bits.mem.pi8),
27002698
"El_params data for out tensor wasn`t initialized in a consistent way");
27012699
if (out->el_params.sa.zero_point.mem.pi16 == nullptr)
2702-
fail |= MLI_CHECK((out->el_params.sa.zero_point.mem.pi16 == nullptr) \
2703-
&& (out->el_params.sa.scale.mem.pi16 == nullptr) \
2700+
fail |= MLI_CHECK((out->el_params.sa.scale.mem.pi16 == nullptr) \
27042701
&& (out->el_params.sa.scale_frac_bits.mem.pi8 == nullptr),
27052702
"El_params data for out tensor wasn`t initialized in a consistent way");
27062703

27072704
if (!fail && out->el_params.sa.zero_point.mem.pi16 != in->el_params.sa.zero_point.mem.pi16 \
27082705
&& out->el_params.sa.zero_point.mem.pi16 != nullptr)
2709-
fail |= MLI_CHECK(out->el_params.sa.zero_point.capacity >= in->el_params.sa.zero_point.capacity \
2710-
&& out->el_params.sa.scale.capacity >= in->el_params.sa.scale.capacity \
2711-
&& out->el_params.sa.scale_frac_bits.capacity >= in->el_params.sa.scale_frac_bits.capacity,
2706+
fail |= MLI_CHECK(out->el_params.sa.zero_point.capacity >= (in->shape[in->el_params.sa.dim] * sizeof(int16_t)) \
2707+
&& out->el_params.sa.scale.capacity >= (in->shape[in->el_params.sa.dim] * sizeof(int16_t)) \
2708+
&& out->el_params.sa.scale_frac_bits.capacity >= (in->shape[in->el_params.sa.dim] * sizeof(int8_t)),
27122709
"Not enough memory allocated for quantization parameters");
27132710
if (fail) return MLI_STATUS_SPEC_PARAM_MISMATCH;
27142711
}
@@ -2837,6 +2834,16 @@ mli_status mli_chk_data_movement(const mli_tensor *in, const mli_mov_cfg_t *cfg,
28372834
&& (out->el_params.sa.scale_frac_bits.mem.pi8 == nullptr),
28382835
"El_params data for out tensor wasn`t initialized in a consistent way");
28392836

2837+
//check that the configurations are valid
2838+
for (uint32_t i=0; i < in->rank; i++) {
2839+
if (MLI_CHECK((cfg->size[i] + cfg->offset[i]) <= in->shape[i] + cfg->padding_pre[i] + cfg->padding_post[i],"Size is larger than padded input"))
2840+
return MLI_STATUS_BAD_FUNC_CFG;
2841+
if (MLI_CHECK(cfg->sub_sample_step[i] > 0,"sub_sample_step should be greater than 0"))
2842+
return MLI_STATUS_BAD_FUNC_CFG;
2843+
if (MLI_CHECK(cfg->perm_dim[i] < in->rank,"permute out of range"))
2844+
return MLI_STATUS_BAD_FUNC_CFG;
2845+
}
2846+
28402847
int32_t in_dim = in->el_params.sa.dim;
28412848
int32_t out_dim = 0;
28422849
for (int dim = 0; dim < (int)in->rank; dim++) {
@@ -2870,16 +2877,6 @@ mli_status mli_chk_data_movement(const mli_tensor *in, const mli_mov_cfg_t *cfg,
28702877

28712878
}
28722879

2873-
//check that the configurations are valid
2874-
for (uint32_t i=0; i < in->rank; i++) {
2875-
if (MLI_CHECK((cfg->size[i] + cfg->offset[i]) <= in->shape[i] + cfg->padding_pre[i] + cfg->padding_post[i],"Size is larger than padded input"))
2876-
return MLI_STATUS_BAD_FUNC_CFG;
2877-
if (MLI_CHECK(cfg->sub_sample_step[i] > 0,"sub_sample_step should be greater than 0"))
2878-
return MLI_STATUS_BAD_FUNC_CFG;
2879-
if (MLI_CHECK(cfg->perm_dim[i] < in->rank,"permute out of range"))
2880-
return MLI_STATUS_BAD_FUNC_CFG;
2881-
}
2882-
28832880
//check that input and output are not overlapped
28842881
if (MLI_CHECK(mli_chk_tensors_not_overlapped(in, out),"in and out buffer must not be overlapped")) {
28852882
return MLI_STATUS_INCOMPATEBLE_TENSORS;

0 commit comments

Comments
 (0)