Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 40 additions & 33 deletions c/src/ml-api-inference-single.c
Original file line number Diff line number Diff line change
Expand Up @@ -1997,6 +1997,30 @@ __ml_validate_model_file (const char *const *model,
return ML_ERROR_NONE;
}

/**
* @brief Internal helper to check if the file has one of the valid extensions.
* @return TRUE if valid, FALSE otherwise.
*/
static gboolean
_is_valid_extension (const char *filename, const char *const *valid_exts)
{
const char *dot;

if (!filename || !valid_exts)
return FALSE;

dot = strrchr (filename, '.');
if (!dot)
return FALSE;

for (; *valid_exts; valid_exts++) {
if (g_ascii_strcasecmp (dot, *valid_exts) == 0)
return TRUE;
}

return FALSE;
}

/**
* @brief Validates the nnfw model file.
* @since_tizen 5.5
Expand All @@ -2014,9 +2038,7 @@ _ml_validate_model_file (const char *const *model,
int status = ML_ERROR_NONE;
ml_nnfw_type_e detected = ML_NNFW_TYPE_ANY;
gboolean is_dir = FALSE;
gchar *pos, *fw_name;
gchar **file_ext = NULL;
guint i;
gchar *fw_name;

if (!nnfw)
_ml_error_report_return (ML_ERROR_INVALID_PARAMETER,
Expand Down Expand Up @@ -2060,19 +2082,6 @@ _ml_validate_model_file (const char *const *model,
goto done;
}

/* Handle mismatched case, check file extension. */
file_ext = g_malloc0 (sizeof (char *) * (num_models + 1));
for (i = 0; i < num_models; i++) {
if ((pos = strrchr (model[i], '.')) == NULL) {
_ml_error_report ("The given model [%d]=\"%s\" has invalid extension.", i,
model[i]);
status = ML_ERROR_INVALID_PARAMETER;
goto done;
}

file_ext[i] = g_ascii_strdown (pos, -1);
}

/** @todo Make sure num_models is correct for each nnfw type */
switch (*nnfw) {
case ML_NNFW_TYPE_NNFW:
Expand Down Expand Up @@ -2101,11 +2110,10 @@ _ml_validate_model_file (const char *const *model,
status = ML_ERROR_NOT_SUPPORTED;
break;
case ML_NNFW_TYPE_VD_AIFW:
if (!g_str_equal (file_ext[0], ".nb") &&
!g_str_equal (file_ext[0], ".ncp") &&
!g_str_equal (file_ext[0], ".tvn") &&
!g_str_equal (file_ext[0], ".bin")) {
status = ML_ERROR_INVALID_PARAMETER;
{
const char *exts[] = { ".nb", ".ncp", ".tvn", ".bin", NULL };
if (!_is_valid_extension (model[0], exts))
status = ML_ERROR_INVALID_PARAMETER;
}
break;
case ML_NNFW_TYPE_SNAP:
Expand All @@ -2116,20 +2124,20 @@ _ml_validate_model_file (const char *const *model,
/* SNAP requires multiple files, set supported if model file exists. */
break;
case ML_NNFW_TYPE_ARMNN:
if (!g_str_equal (file_ext[0], ".caffemodel") &&
!g_str_equal (file_ext[0], ".tflite") &&
!g_str_equal (file_ext[0], ".pb") &&
!g_str_equal (file_ext[0], ".prototxt")) {
_ml_error_report
("ARMNN accepts .caffemodel, .tflite, .pb, and .prototxt files only. Please support correct file extension. You have specified: \"%s\"",
file_ext[0]);
status = ML_ERROR_INVALID_PARAMETER;
{
const char *exts[] =
{ ".caffemodel", ".tflite", ".pb", ".prototxt", NULL };
if (!_is_valid_extension (model[0], exts)) {
_ml_error_report ("Invalid extension for ARMNN: %s", model[0]);
status = ML_ERROR_INVALID_PARAMETER;
}
}
break;
case ML_NNFW_TYPE_MXNET:
if (!g_str_equal (file_ext[0], ".params") &&
!g_str_equal (file_ext[0], ".json")) {
status = ML_ERROR_INVALID_PARAMETER;
{
const char *exts[] = { ".params", ".json", NULL };
if (!_is_valid_extension (model[0], exts))
status = ML_ERROR_INVALID_PARAMETER;
}
break;
default:
Expand All @@ -2153,6 +2161,5 @@ _ml_validate_model_file (const char *const *model,
model[0], num_models);
}

g_strfreev (file_ext);
return status;
}
Loading