diff --git a/clarifai/client/model.py b/clarifai/client/model.py index e43efabd1..1c434e229 100644 --- a/clarifai/client/model.py +++ b/clarifai/client/model.py @@ -67,6 +67,7 @@ def __init__( nodepool_id: str = None, deployment_id: str = None, deployment_user_id: str = None, + validate: bool = True, **kwargs, ): """Initializes a Model object. @@ -83,12 +84,18 @@ def __init__( nodepool_id (str): Nodepool ID for runner selector. deployment_id (str): Deployment ID for runner selector. deployment_user_id (str): User ID to use for runner selector (organization or user). If not provided, defaults to PAT owner user_id. + validate (bool): Whether to validate that the model exists. Defaults to True. Set to False to skip validation (used internally when model existence is already confirmed). **kwargs: Additional keyword arguments to be passed to the Model. + + Raises: + UserError: If the model URL is invalid or the model does not exist. """ if url and model_id: raise UserError("You can only specify one of url or model_id.") if not url and not model_id: raise UserError("You must specify one of url or model_id.") + + original_url = url # Store the original URL for error messages if url: user_id, app_id, _, model_id, model_version_id = ClarifaiUrlHelper.split_clarifai_url( url @@ -130,6 +137,57 @@ def __init__( deployment_user_id=deployment_user_id, ) + # Validate that the model exists (unless explicitly skipped) + if validate: + self._validate_model_exists(original_url) + + def _validate_model_exists(self, original_url: str = None) -> None: + """Validates that the model exists by making a GetModel request. + + Args: + original_url (str): The original URL provided by the user (for error messages). + + Raises: + UserError: If the model does not exist or cannot be accessed. + """ + try: + request = service_pb2.GetModelRequest( + user_app_id=self.user_app_id, + model_id=self.id, + version_id=self.model_info.model_version.id, + ) + response = self._grpc_request(self.STUB.GetModel, request) + + if response.status.code != status_code_pb2.SUCCESS: + # Model does not exist or cannot be accessed + if original_url: + error_msg = ( + f"Model does not exist or cannot be accessed at URL: {original_url}\n" + f"Status: {response.status.description}\n" + f"Details: {response.status.details}" + ) + else: + error_msg = ( + f"Model '{self.id}' does not exist or cannot be accessed in app '{self.app_id}' " + f"for user '{self.user_id}'.\n" + f"Status: {response.status.description}\n" + f"Details: {response.status.details}" + ) + raise UserError(error_msg) + except UserError: + # Re-raise UserError as-is + raise + except Exception as e: + # Handle unexpected errors during validation + if original_url: + error_msg = f"Failed to validate model at URL: {original_url}\nError: {str(e)}" + else: + error_msg = ( + f"Failed to validate model '{self.id}' in app '{self.app_id}' " + f"for user '{self.user_id}'.\nError: {str(e)}" + ) + raise UserError(error_msg) + @classmethod def from_current_context(cls, **kwargs) -> 'Model': from clarifai.urls.helper import ClarifaiUrlHelper @@ -463,7 +521,7 @@ def create_version(self, **kwargs) -> 'Model': dict_response = MessageToDict(response, preserving_proto_field_name=True) kwargs = self.process_response_keys(dict_response['model'], 'model') - return Model(base_url=self.base, pat=self.pat, token=self.token, **kwargs) + return Model(base_url=self.base, pat=self.pat, token=self.token, validate=False, **kwargs) def list_versions( self, page_no: int = None, per_page: int = None @@ -510,6 +568,7 @@ def list_versions( yield Model.from_auth_helper( auth=self.auth_helper, model_id=self.id, + validate=False, **dict(self.kwargs, model_version=model_version_info), ) @@ -2075,6 +2134,7 @@ def stream_and_logging( auth=self.auth_helper, model_id=self.id, model_version=dict(id=cache_uploading_info.get('model_version')), + validate=False, ) def create_version_by_url( @@ -2131,6 +2191,7 @@ def create_version_by_url( auth=self.auth_helper, model_id=self.id, model_version=dict(id=response.model.model_version.id), + validate=False, ) def patch_version(self, version_id: str, **kwargs) -> 'Model': @@ -2161,4 +2222,5 @@ def patch_version(self, version_id: str, **kwargs) -> 'Model': auth=self.auth_helper, model_id=self.id, model_version=dict(id=version_id), + validate=False, ) diff --git a/clarifai/workflows/utils.py b/clarifai/workflows/utils.py index 95e8c966a..78eb91e55 100644 --- a/clarifai/workflows/utils.py +++ b/clarifai/workflows/utils.py @@ -55,7 +55,7 @@ def is_dict_in_dict(d1: Dict, d2: Dict, ignore_keys: Set = None) -> bool: if isinstance(v, dict): if not isinstance(d2[k], dict): return False - return is_dict_in_dict(d1[k], d2[k], None) + return is_dict_in_dict(v, d2[k], None) elif v != d2[k]: return False diff --git a/pytest.ini b/pytest.ini index c6ad2dd75..f10f3e62e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,4 @@ [pytest] markers = maintainer_approval: marks tests that require maintainer approval to run + requires_secrets: marks tests that require API credentials (CLARIFAI_PAT) diff --git a/tests/test_model_init_validation.py b/tests/test_model_init_validation.py new file mode 100644 index 000000000..9dde024a5 --- /dev/null +++ b/tests/test_model_init_validation.py @@ -0,0 +1,89 @@ +"""Tests for Model initialization validation.""" + +import os + +import pytest + +from clarifai.client.model import Model +from clarifai.errors import UserError + +CLARIFAI_PAT = os.environ.get("CLARIFAI_PAT", "test_pat") +CLARIFAI_API_BASE = os.environ.get("CLARIFAI_API_BASE", "https://api.clarifai.com") + +# Valid model for comparison +MAIN_APP_ID = "main" +MAIN_APP_USER_ID = "clarifai" +GENERAL_MODEL_ID = "aaa03c23b3724a16a56b629203edc62c" + + +class TestModelInitValidation: + """Tests for Model constructor validation.""" + + @pytest.mark.requires_secrets + def test_valid_model_url(self): + """Test that a valid model URL initializes successfully.""" + url = f"https://clarifai.com/{MAIN_APP_USER_ID}/{MAIN_APP_ID}/models/{GENERAL_MODEL_ID}" + model = Model(url=url, pat=CLARIFAI_PAT, base_url=CLARIFAI_API_BASE) + assert model.id == GENERAL_MODEL_ID + assert model.user_id == MAIN_APP_USER_ID + assert model.app_id == MAIN_APP_ID + + @pytest.mark.requires_secrets + def test_valid_model_id(self): + """Test that a valid model_id initializes successfully.""" + model = Model( + user_id=MAIN_APP_USER_ID, + app_id=MAIN_APP_ID, + model_id=GENERAL_MODEL_ID, + pat=CLARIFAI_PAT, + base_url=CLARIFAI_API_BASE, + ) + assert model.id == GENERAL_MODEL_ID + assert model.user_id == MAIN_APP_USER_ID + assert model.app_id == MAIN_APP_ID + + @pytest.mark.requires_secrets + def test_nonexistent_model_url(self): + """Test that a non-existent model URL raises UserError with clear message.""" + url = f"https://clarifai.com/{MAIN_APP_USER_ID}/{MAIN_APP_ID}/models/non-existent-model-xyz-123" + with pytest.raises(UserError) as exc_info: + Model(url=url, pat=CLARIFAI_PAT, base_url=CLARIFAI_API_BASE) + + error_msg = str(exc_info.value) + # Check that the error message contains helpful information + assert "does not exist" in error_msg or "cannot be accessed" in error_msg + assert url in error_msg + + @pytest.mark.requires_secrets + def test_nonexistent_model_id(self): + """Test that a non-existent model_id raises UserError with clear message.""" + with pytest.raises(UserError) as exc_info: + Model( + user_id=MAIN_APP_USER_ID, + app_id=MAIN_APP_ID, + model_id="non-existent-model-xyz-123", + pat=CLARIFAI_PAT, + base_url=CLARIFAI_API_BASE, + ) + + error_msg = str(exc_info.value) + # Check that the error message contains helpful information + assert "does not exist" in error_msg or "cannot be accessed" in error_msg + assert "non-existent-model-xyz-123" in error_msg + + def test_missing_url_and_model_id(self): + """Test that missing both url and model_id raises UserError.""" + with pytest.raises(UserError) as exc_info: + Model(pat=CLARIFAI_PAT, base_url=CLARIFAI_API_BASE) + + error_msg = str(exc_info.value) + assert "must specify one of url or model_id" in error_msg + + def test_both_url_and_model_id(self): + """Test that providing both url and model_id raises UserError.""" + url = f"https://clarifai.com/{MAIN_APP_USER_ID}/{MAIN_APP_ID}/models/{GENERAL_MODEL_ID}" + with pytest.raises(UserError) as exc_info: + Model(url=url, model_id=GENERAL_MODEL_ID, pat=CLARIFAI_PAT, base_url=CLARIFAI_API_BASE) + + error_msg = str(exc_info.value) + assert "only specify one of url or model_id" in error_msg diff --git a/tests/test_model_validation_mocks.py b/tests/test_model_validation_mocks.py new file mode 100644 index 000000000..0bd7faeba --- /dev/null +++ b/tests/test_model_validation_mocks.py @@ -0,0 +1,111 @@ +"""Unit tests for Model initialization validation with mocks.""" + +from unittest.mock import MagicMock, patch + +import pytest +from clarifai_grpc.grpc.api.status import status_code_pb2 + +from clarifai.client.model import Model +from clarifai.errors import UserError + + +class TestModelValidationWithMocks: + """Test Model validation using mocks.""" + + @patch('clarifai.client.model.Model._validate_model_exists') + @patch('clarifai.client.model.BaseClient.__init__', return_value=None) + @patch('clarifai.client.model.Lister.__init__', return_value=None) + @patch('clarifai.client.model.Model._set_runner_selector') + def test_validation_called_by_default( + self, mock_runner_selector, mock_lister_init, mock_base_init, mock_validate + ): + """Test that validation is called by default when creating a Model.""" + # Create Model with default validate=True + model = Model(model_id='test_model', user_id='test_user', app_id='test_app') + + # Verify validation was called once + mock_validate.assert_called_once() + + @patch('clarifai.client.model.Model._validate_model_exists') + @patch('clarifai.client.model.BaseClient.__init__', return_value=None) + @patch('clarifai.client.model.Lister.__init__', return_value=None) + @patch('clarifai.client.model.Model._set_runner_selector') + def test_validation_skipped_when_false( + self, mock_runner_selector, mock_lister_init, mock_base_init, mock_validate + ): + """Test that validation is skipped when validate=False.""" + # Create Model with validate=False + model = Model( + model_id='test_model', user_id='test_user', app_id='test_app', validate=False + ) + + # Verify validation was NOT called + mock_validate.assert_not_called() + + def test_validate_model_exists_with_success(self): + """Test _validate_model_exists when model exists.""" + model = MagicMock() + model.user_app_id = MagicMock() + model.id = 'test_model' + model.model_info.model_version.id = 'version1' + + # Mock successful response + mock_response = MagicMock() + mock_response.status.code = status_code_pb2.SUCCESS + + model._grpc_request = MagicMock(return_value=mock_response) + + # Call the actual method + Model._validate_model_exists(model, original_url=None) + + # Should not raise any exception + assert model._grpc_request.called + + def test_validate_model_exists_with_failure(self): + """Test _validate_model_exists when model does not exist.""" + model = MagicMock() + model.user_app_id = MagicMock() + model.id = 'nonexistent_model' + model.app_id = 'test_app' + model.user_id = 'test_user' + model.model_info.model_version.id = '' + + # Mock failure response + mock_response = MagicMock() + mock_response.status.code = status_code_pb2.MODEL_DOES_NOT_EXIST + mock_response.status.description = 'Model not found' + mock_response.status.details = 'The requested model does not exist' + + model._grpc_request = MagicMock(return_value=mock_response) + + # Should raise UserError + with pytest.raises(UserError) as exc_info: + Model._validate_model_exists(model, original_url=None) + + error_msg = str(exc_info.value) + assert 'does not exist' in error_msg or 'cannot be accessed' in error_msg + assert 'nonexistent_model' in error_msg + + def test_validate_model_exists_with_url(self): + """Test _validate_model_exists error message includes URL when provided.""" + model = MagicMock() + model.user_app_id = MagicMock() + model.id = 'test_model' + model.model_info.model_version.id = '' + + # Mock failure response + mock_response = MagicMock() + mock_response.status.code = status_code_pb2.MODEL_DOES_NOT_EXIST + mock_response.status.description = 'Model not found' + mock_response.status.details = 'The requested model does not exist' + + model._grpc_request = MagicMock(return_value=mock_response) + + test_url = 'https://clarifai.com/test_user/test_app/models/test_model' + + # Should raise UserError with URL in message + with pytest.raises(UserError) as exc_info: + Model._validate_model_exists(model, original_url=test_url) + + error_msg = str(exc_info.value) + assert test_url in error_msg