Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion clarifai/client/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
nodepool_id: str = None,
deployment_id: str = None,
deployment_user_id: str = None,
validate: bool = True,
**kwargs,
):
"""Initializes a Model object.
Expand All @@ -83,12 +84,18 @@ def __init__(
nodepool_id (str): Nodepool ID for runner selector.
deployment_id (str): Deployment ID for runner selector.
deployment_user_id (str): User ID to use for runner selector (organization or user). If not provided, defaults to PAT owner user_id.
validate (bool): Whether to validate that the model exists. Defaults to True. Set to False to skip validation (used internally when model existence is already confirmed).
**kwargs: Additional keyword arguments to be passed to the Model.

Raises:
UserError: If the model URL is invalid or the model does not exist.
"""
if url and model_id:
raise UserError("You can only specify one of url or model_id.")
if not url and not model_id:
raise UserError("You must specify one of url or model_id.")

original_url = url # Store the original URL for error messages
if url:
user_id, app_id, _, model_id, model_version_id = ClarifaiUrlHelper.split_clarifai_url(
url
Expand Down Expand Up @@ -130,6 +137,57 @@ def __init__(
deployment_user_id=deployment_user_id,
)

# Validate that the model exists (unless explicitly skipped)
if validate:
self._validate_model_exists(original_url)

def _validate_model_exists(self, original_url: str = None) -> None:
"""Validates that the model exists by making a GetModel request.

Args:
original_url (str): The original URL provided by the user (for error messages).

Raises:
UserError: If the model does not exist or cannot be accessed.
"""
try:
request = service_pb2.GetModelRequest(
user_app_id=self.user_app_id,
model_id=self.id,
version_id=self.model_info.model_version.id,
)
response = self._grpc_request(self.STUB.GetModel, request)

if response.status.code != status_code_pb2.SUCCESS:
# Model does not exist or cannot be accessed
if original_url:
error_msg = (
f"Model does not exist or cannot be accessed at URL: {original_url}\n"
f"Status: {response.status.description}\n"
f"Details: {response.status.details}"
)
else:
error_msg = (
f"Model '{self.id}' does not exist or cannot be accessed in app '{self.app_id}' "
f"for user '{self.user_id}'.\n"
f"Status: {response.status.description}\n"
f"Details: {response.status.details}"
)
raise UserError(error_msg)
except UserError:
# Re-raise UserError as-is
raise
except Exception as e:
# Handle unexpected errors during validation
if original_url:
error_msg = f"Failed to validate model at URL: {original_url}\nError: {str(e)}"
else:
error_msg = (
f"Failed to validate model '{self.id}' in app '{self.app_id}' "
f"for user '{self.user_id}'.\nError: {str(e)}"
)
raise UserError(error_msg)

@classmethod
def from_current_context(cls, **kwargs) -> 'Model':
from clarifai.urls.helper import ClarifaiUrlHelper
Expand Down Expand Up @@ -463,7 +521,7 @@ def create_version(self, **kwargs) -> 'Model':
dict_response = MessageToDict(response, preserving_proto_field_name=True)
kwargs = self.process_response_keys(dict_response['model'], 'model')

return Model(base_url=self.base, pat=self.pat, token=self.token, **kwargs)
return Model(base_url=self.base, pat=self.pat, token=self.token, validate=False, **kwargs)

def list_versions(
self, page_no: int = None, per_page: int = None
Expand Down Expand Up @@ -510,6 +568,7 @@ def list_versions(
yield Model.from_auth_helper(
auth=self.auth_helper,
model_id=self.id,
validate=False,
**dict(self.kwargs, model_version=model_version_info),
)

Expand Down Expand Up @@ -2075,6 +2134,7 @@ def stream_and_logging(
auth=self.auth_helper,
model_id=self.id,
model_version=dict(id=cache_uploading_info.get('model_version')),
validate=False,
)

def create_version_by_url(
Expand Down Expand Up @@ -2131,6 +2191,7 @@ def create_version_by_url(
auth=self.auth_helper,
model_id=self.id,
model_version=dict(id=response.model.model_version.id),
validate=False,
)

def patch_version(self, version_id: str, **kwargs) -> 'Model':
Expand Down Expand Up @@ -2161,4 +2222,5 @@ def patch_version(self, version_id: str, **kwargs) -> 'Model':
auth=self.auth_helper,
model_id=self.id,
model_version=dict(id=version_id),
validate=False,
)
2 changes: 1 addition & 1 deletion clarifai/workflows/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def is_dict_in_dict(d1: Dict, d2: Dict, ignore_keys: Set = None) -> bool:
if isinstance(v, dict):
if not isinstance(d2[k], dict):
return False
return is_dict_in_dict(d1[k], d2[k], None)
return is_dict_in_dict(v, d2[k], None)
elif v != d2[k]:
return False

Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[pytest]
markers =
maintainer_approval: marks tests that require maintainer approval to run
requires_secrets: marks tests that require API credentials (CLARIFAI_PAT)
89 changes: 89 additions & 0 deletions tests/test_model_init_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Tests for Model initialization validation."""

import os

import pytest

from clarifai.client.model import Model
from clarifai.errors import UserError

CLARIFAI_PAT = os.environ.get("CLARIFAI_PAT", "test_pat")
CLARIFAI_API_BASE = os.environ.get("CLARIFAI_API_BASE", "https://api.clarifai.com")

# Valid model for comparison
MAIN_APP_ID = "main"
MAIN_APP_USER_ID = "clarifai"
GENERAL_MODEL_ID = "aaa03c23b3724a16a56b629203edc62c"


class TestModelInitValidation:
"""Tests for Model constructor validation."""

@pytest.mark.requires_secrets
def test_valid_model_url(self):
"""Test that a valid model URL initializes successfully."""
url = f"https://clarifai.com/{MAIN_APP_USER_ID}/{MAIN_APP_ID}/models/{GENERAL_MODEL_ID}"
model = Model(url=url, pat=CLARIFAI_PAT, base_url=CLARIFAI_API_BASE)
assert model.id == GENERAL_MODEL_ID
assert model.user_id == MAIN_APP_USER_ID
assert model.app_id == MAIN_APP_ID

@pytest.mark.requires_secrets
def test_valid_model_id(self):
"""Test that a valid model_id initializes successfully."""
model = Model(
user_id=MAIN_APP_USER_ID,
app_id=MAIN_APP_ID,
model_id=GENERAL_MODEL_ID,
pat=CLARIFAI_PAT,
base_url=CLARIFAI_API_BASE,
)
assert model.id == GENERAL_MODEL_ID
assert model.user_id == MAIN_APP_USER_ID
assert model.app_id == MAIN_APP_ID

@pytest.mark.requires_secrets
def test_nonexistent_model_url(self):
"""Test that a non-existent model URL raises UserError with clear message."""
url = f"https://clarifai.com/{MAIN_APP_USER_ID}/{MAIN_APP_ID}/models/non-existent-model-xyz-123"
with pytest.raises(UserError) as exc_info:
Model(url=url, pat=CLARIFAI_PAT, base_url=CLARIFAI_API_BASE)

error_msg = str(exc_info.value)
# Check that the error message contains helpful information
assert "does not exist" in error_msg or "cannot be accessed" in error_msg
assert url in error_msg

@pytest.mark.requires_secrets
def test_nonexistent_model_id(self):
"""Test that a non-existent model_id raises UserError with clear message."""
with pytest.raises(UserError) as exc_info:
Model(
user_id=MAIN_APP_USER_ID,
app_id=MAIN_APP_ID,
model_id="non-existent-model-xyz-123",
pat=CLARIFAI_PAT,
base_url=CLARIFAI_API_BASE,
)

error_msg = str(exc_info.value)
# Check that the error message contains helpful information
assert "does not exist" in error_msg or "cannot be accessed" in error_msg
assert "non-existent-model-xyz-123" in error_msg

def test_missing_url_and_model_id(self):
"""Test that missing both url and model_id raises UserError."""
with pytest.raises(UserError) as exc_info:
Model(pat=CLARIFAI_PAT, base_url=CLARIFAI_API_BASE)

error_msg = str(exc_info.value)
assert "must specify one of url or model_id" in error_msg

def test_both_url_and_model_id(self):
"""Test that providing both url and model_id raises UserError."""
url = f"https://clarifai.com/{MAIN_APP_USER_ID}/{MAIN_APP_ID}/models/{GENERAL_MODEL_ID}"
with pytest.raises(UserError) as exc_info:
Model(url=url, model_id=GENERAL_MODEL_ID, pat=CLARIFAI_PAT, base_url=CLARIFAI_API_BASE)

error_msg = str(exc_info.value)
assert "only specify one of url or model_id" in error_msg
111 changes: 111 additions & 0 deletions tests/test_model_validation_mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Unit tests for Model initialization validation with mocks."""

from unittest.mock import MagicMock, patch

import pytest
from clarifai_grpc.grpc.api.status import status_code_pb2

from clarifai.client.model import Model
from clarifai.errors import UserError


class TestModelValidationWithMocks:
"""Test Model validation using mocks."""

@patch('clarifai.client.model.Model._validate_model_exists')
@patch('clarifai.client.model.BaseClient.__init__', return_value=None)
@patch('clarifai.client.model.Lister.__init__', return_value=None)
@patch('clarifai.client.model.Model._set_runner_selector')
def test_validation_called_by_default(
self, mock_runner_selector, mock_lister_init, mock_base_init, mock_validate
):
"""Test that validation is called by default when creating a Model."""
# Create Model with default validate=True
model = Model(model_id='test_model', user_id='test_user', app_id='test_app')

# Verify validation was called once
mock_validate.assert_called_once()

@patch('clarifai.client.model.Model._validate_model_exists')
@patch('clarifai.client.model.BaseClient.__init__', return_value=None)
@patch('clarifai.client.model.Lister.__init__', return_value=None)
@patch('clarifai.client.model.Model._set_runner_selector')
def test_validation_skipped_when_false(
self, mock_runner_selector, mock_lister_init, mock_base_init, mock_validate
):
"""Test that validation is skipped when validate=False."""
# Create Model with validate=False
model = Model(
model_id='test_model', user_id='test_user', app_id='test_app', validate=False
)

# Verify validation was NOT called
mock_validate.assert_not_called()

def test_validate_model_exists_with_success(self):
"""Test _validate_model_exists when model exists."""
model = MagicMock()
model.user_app_id = MagicMock()
model.id = 'test_model'
model.model_info.model_version.id = 'version1'

# Mock successful response
mock_response = MagicMock()
mock_response.status.code = status_code_pb2.SUCCESS

model._grpc_request = MagicMock(return_value=mock_response)

# Call the actual method
Model._validate_model_exists(model, original_url=None)

# Should not raise any exception
assert model._grpc_request.called

def test_validate_model_exists_with_failure(self):
"""Test _validate_model_exists when model does not exist."""
model = MagicMock()
model.user_app_id = MagicMock()
model.id = 'nonexistent_model'
model.app_id = 'test_app'
model.user_id = 'test_user'
model.model_info.model_version.id = ''

# Mock failure response
mock_response = MagicMock()
mock_response.status.code = status_code_pb2.MODEL_DOES_NOT_EXIST
mock_response.status.description = 'Model not found'
mock_response.status.details = 'The requested model does not exist'

model._grpc_request = MagicMock(return_value=mock_response)

# Should raise UserError
with pytest.raises(UserError) as exc_info:
Model._validate_model_exists(model, original_url=None)

error_msg = str(exc_info.value)
assert 'does not exist' in error_msg or 'cannot be accessed' in error_msg
assert 'nonexistent_model' in error_msg

def test_validate_model_exists_with_url(self):
"""Test _validate_model_exists error message includes URL when provided."""
model = MagicMock()
model.user_app_id = MagicMock()
model.id = 'test_model'
model.model_info.model_version.id = ''

# Mock failure response
mock_response = MagicMock()
mock_response.status.code = status_code_pb2.MODEL_DOES_NOT_EXIST
mock_response.status.description = 'Model not found'
mock_response.status.details = 'The requested model does not exist'

model._grpc_request = MagicMock(return_value=mock_response)

test_url = 'https://clarifai.com/test_user/test_app/models/test_model'

# Should raise UserError with URL in message
with pytest.raises(UserError) as exc_info:
Model._validate_model_exists(model, original_url=test_url)

error_msg = str(exc_info.value)
assert test_url in error_msg
Loading