diff --git a/c/src/ml-api-inference-single.c b/c/src/ml-api-inference-single.c index 03df6326..c6638f09 100644 --- a/c/src/ml-api-inference-single.c +++ b/c/src/ml-api-inference-single.c @@ -1997,6 +1997,30 @@ __ml_validate_model_file (const char *const *model, return ML_ERROR_NONE; } +/** + * @brief Internal helper to check if the file has one of the valid extensions. + * @return TRUE if valid, FALSE otherwise. + */ +static gboolean +_is_valid_extension (const char *filename, const char *const *valid_exts) +{ + const char *dot; + + if (!filename || !valid_exts) + return FALSE; + + dot = strrchr (filename, '.'); + if (!dot) + return FALSE; + + for (; *valid_exts; valid_exts++) { + if (g_ascii_strcasecmp (dot, *valid_exts) == 0) + return TRUE; + } + + return FALSE; +} + /** * @brief Validates the nnfw model file. * @since_tizen 5.5 @@ -2014,9 +2038,7 @@ _ml_validate_model_file (const char *const *model, int status = ML_ERROR_NONE; ml_nnfw_type_e detected = ML_NNFW_TYPE_ANY; gboolean is_dir = FALSE; - gchar *pos, *fw_name; - gchar **file_ext = NULL; - guint i; + gchar *fw_name; if (!nnfw) _ml_error_report_return (ML_ERROR_INVALID_PARAMETER, @@ -2060,19 +2082,6 @@ _ml_validate_model_file (const char *const *model, goto done; } - /* Handle mismatched case, check file extension. */ - file_ext = g_malloc0 (sizeof (char *) * (num_models + 1)); - for (i = 0; i < num_models; i++) { - if ((pos = strrchr (model[i], '.')) == NULL) { - _ml_error_report ("The given model [%d]=\"%s\" has invalid extension.", i, - model[i]); - status = ML_ERROR_INVALID_PARAMETER; - goto done; - } - - file_ext[i] = g_ascii_strdown (pos, -1); - } - /** @todo Make sure num_models is correct for each nnfw type */ switch (*nnfw) { case ML_NNFW_TYPE_NNFW: @@ -2101,11 +2110,10 @@ _ml_validate_model_file (const char *const *model, status = ML_ERROR_NOT_SUPPORTED; break; case ML_NNFW_TYPE_VD_AIFW: - if (!g_str_equal (file_ext[0], ".nb") && - !g_str_equal (file_ext[0], ".ncp") && - !g_str_equal (file_ext[0], ".tvn") && - !g_str_equal (file_ext[0], ".bin")) { - status = ML_ERROR_INVALID_PARAMETER; + { + const char *exts[] = { ".nb", ".ncp", ".tvn", ".bin", NULL }; + if (!_is_valid_extension (model[0], exts)) + status = ML_ERROR_INVALID_PARAMETER; } break; case ML_NNFW_TYPE_SNAP: @@ -2116,20 +2124,20 @@ _ml_validate_model_file (const char *const *model, /* SNAP requires multiple files, set supported if model file exists. */ break; case ML_NNFW_TYPE_ARMNN: - if (!g_str_equal (file_ext[0], ".caffemodel") && - !g_str_equal (file_ext[0], ".tflite") && - !g_str_equal (file_ext[0], ".pb") && - !g_str_equal (file_ext[0], ".prototxt")) { - _ml_error_report - ("ARMNN accepts .caffemodel, .tflite, .pb, and .prototxt files only. Please support correct file extension. You have specified: \"%s\"", - file_ext[0]); - status = ML_ERROR_INVALID_PARAMETER; + { + const char *exts[] = + { ".caffemodel", ".tflite", ".pb", ".prototxt", NULL }; + if (!_is_valid_extension (model[0], exts)) { + _ml_error_report ("Invalid extension for ARMNN: %s", model[0]); + status = ML_ERROR_INVALID_PARAMETER; + } } break; case ML_NNFW_TYPE_MXNET: - if (!g_str_equal (file_ext[0], ".params") && - !g_str_equal (file_ext[0], ".json")) { - status = ML_ERROR_INVALID_PARAMETER; + { + const char *exts[] = { ".params", ".json", NULL }; + if (!_is_valid_extension (model[0], exts)) + status = ML_ERROR_INVALID_PARAMETER; } break; default: @@ -2153,6 +2161,5 @@ _ml_validate_model_file (const char *const *model, model[0], num_models); } - g_strfreev (file_ext); return status; }