diff --git a/recipes/test_connections_domain_namespace.py b/recipes/test_connections_domain_namespace.py deleted file mode 100644 index b184cdf56..000000000 --- a/recipes/test_connections_domain_namespace.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 -"""Demo script for the ConnectionsNamespace domain API. - -This script demonstrates how to use the new Connections domain namespace -to manage cloud storage connections in Kili projects. - -Note: This is a demonstration script. In real usage, you would need: -- Valid API credentials -- Existing cloud storage integrations -- Valid project IDs -""" - - -def demo_connections_namespace(): - """Demonstrate the Connections domain namespace functionality.""" - print("๐Ÿ”— Kili ConnectionsNamespace Demo") - print("=" * 50) - - # Initialize Kili client (would need real API key in practice) - print("\n1. Initializing Kili client...") - # kili = Kili(api_key="your-api-key-here") - - # For demo purposes, we'll show the API structure - print(" โœ“ Client initialized with connections namespace available") - print(" Access via: kili.connections or kili.connections (in non-legacy mode)") - - print("\n2. Available Operations:") - print(" ๐Ÿ“‹ list() - Query and list cloud storage connections") - print(" โž• add() - Connect cloud storage integration to project") - print(" ๐Ÿ”„ sync() - Synchronize connection with cloud storage") - - print("\n3. 
Example Usage Patterns:") - - print("\n ๐Ÿ“‹ List connections for a project:") - print(" ```python") - print(" connections = kili.connections.list(project_id='project_123')") - print(" print(f'Found {len(connections)} connections')") - print(" ```") - - print("\n โž• Add a new connection with filtering:") - print(" ```python") - print(" result = kili.connections.add(") - print(" project_id='project_123',") - print(" cloud_storage_integration_id='integration_456',") - print(" prefix='data/images/',") - print(" include=['*.jpg', '*.png'],") - print(" exclude=['**/temp/*']") - print(" )") - print(" connection_id = result['id']") - print(" ```") - - print("\n ๐Ÿ”„ Synchronize connection (with dry-run preview):") - print(" ```python") - print(" # Preview changes first") - print(" preview = kili.connections.sync(") - print(" connection_id='connection_789',") - print(" dry_run=True") - print(" )") - print(" ") - print(" # Apply changes") - print(" result = kili.connections.sync(") - print(" connection_id='connection_789',") - print(" delete_extraneous_files=False") - print(" )") - print(" print(f'Synchronized {result[\"numberOfAssets\"]} assets')") - print(" ```") - - print("\n4. Key Features:") - print(" ๐ŸŽฏ Simplified API focused on connections (vs general cloud storage)") - print(" ๐Ÿ›ก๏ธ Enhanced error handling with user-friendly messages") - print(" โœ… Input validation for required parameters") - print(" ๐Ÿ“Š Comprehensive type hints and documentation") - print(" ๐Ÿ”„ Lazy loading and memory optimizations via base class") - print(" ๐Ÿงช Dry-run support for safe synchronization testing") - - print("\n5. 
Integration Benefits:") - print(" โ€ข Clean separation: connections vs cloud storage integrations") - print(" โ€ข Consistent API patterns across all domain namespaces") - print(" โ€ข Better discoverability through focused namespace") - print(" โ€ข Enhanced user experience for cloud storage workflows") - - print("\nโœจ ConnectionsNamespace Demo Complete!") - print("=" * 50) - - -if __name__ == "__main__": - demo_connections_namespace() diff --git a/src/kili/data/__init__.py b/src/kili/data/__init__.py new file mode 100644 index 000000000..09f63103c --- /dev/null +++ b/src/kili/data/__init__.py @@ -0,0 +1 @@ +"""Data layer for repository pattern implementation.""" diff --git a/src/kili/domain_api/assets.py b/src/kili/domain_api/assets.py index cf2f01608..4c100098a 100644 --- a/src/kili/domain_api/assets.py +++ b/src/kili/domain_api/assets.py @@ -9,7 +9,6 @@ Dict, Generator, List, - Literal, Optional, Union, cast, @@ -30,6 +29,15 @@ from kili.domain.project import ProjectId, ProjectStep, WorkflowVersion from kili.domain.types import ListOrTuple from kili.domain_api.base import DomainNamespace +from kili.domain_v2.asset import ( + AssetCreateResponse, + AssetView, + WorkflowStepResponse, + validate_asset, + validate_asset_create_response, + validate_workflow_step_response, +) +from kili.domain_v2.project import IdListResponse, IdResponse from kili.presentation.client.helpers.common_validators import ( disable_tqdm_if_as_generator, ) @@ -72,7 +80,7 @@ def invalidate( asset_ids: Optional[List[str]] = None, external_ids: Optional[List[str]] = None, project_id: Optional[str] = None, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[WorkflowStepResponse]: """Send assets back to queue (invalidate current step). This method sends assets back to the queue, effectively invalidating their @@ -84,19 +92,23 @@ def invalidate( project_id: The project ID. Only required if `external_ids` argument is provided. 
Returns: - A dict object with the project `id` and the `asset_ids` of assets moved to queue. + A response object with the project `id` and the `asset_ids` of assets moved to queue. + Returns None if no assets have changed status. An error message if mutation failed. Examples: - >>> kili.assets.workflow.step.invalidate( + >>> result = kili.assets.workflow.step.invalidate( asset_ids=["ckg22d81r0jrg0885unmuswj8", "ckg22d81s0jrh0885pdxfd03n"] ) + >>> print(result.id) # Project ID + >>> print(result.asset_ids) # List of invalidated asset IDs """ - return self._assets_namespace.client.send_back_to_queue( + result = self._assets_namespace.client.send_back_to_queue( asset_ids=asset_ids, external_ids=external_ids, project_id=project_id, ) + return WorkflowStepResponse(validate_workflow_step_response(result)) if result else None @typechecked def next( @@ -104,7 +116,7 @@ def next( asset_ids: Optional[List[str]] = None, external_ids: Optional[List[str]] = None, project_id: Optional[str] = None, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[WorkflowStepResponse]: """Move assets to the next workflow step (typically review). This method moves assets to the next step in the workflow, typically @@ -116,20 +128,23 @@ def next( project_id: The project ID. Only required if `external_ids` argument is provided. Returns: - A dict object with the project `id` and the `asset_ids` of assets moved to review. - `None` if no assets have changed status (already had `TO_REVIEW` status for example). + A response object with the project `id` and the `asset_ids` of assets moved to review. + Returns None if no assets have changed status (already had `TO_REVIEW` status for example). An error message if mutation failed. 
Examples: - >>> kili.assets.workflow.step.next( + >>> result = kili.assets.workflow.step.next( asset_ids=["ckg22d81r0jrg0885unmuswj8", "ckg22d81s0jrh0885pdxfd03n"] ) + >>> print(result.id) # Project ID + >>> print(result.asset_ids) # List of assets moved to review """ - return self._assets_namespace.client.add_to_review( + result = self._assets_namespace.client.add_to_review( asset_ids=asset_ids, external_ids=external_ids, project_id=project_id, ) + return WorkflowStepResponse(validate_workflow_step_response(result)) if result else None class WorkflowNamespace: @@ -159,7 +174,7 @@ def assign( asset_ids: Optional[List[str]] = None, external_ids: Optional[List[str]] = None, project_id: Optional[str] = None, - ) -> List[Dict[str, Any]]: + ) -> IdListResponse: """Assign a list of assets to a list of labelers. Args: @@ -169,21 +184,23 @@ def assign( to_be_labeled_by_array: The array of list of labelers to assign per labelers (list of userIds). Returns: - A list of dictionaries with the asset ids. + A response object containing the list of assigned asset IDs. Examples: - >>> kili.assets.workflow.assign( + >>> result = kili.assets.workflow.assign( asset_ids=["ckg22d81r0jrg0885unmuswj8", "ckg22d81s0jrh0885pdxfd03n"], to_be_labeled_by_array=[['cm3yja6kv0i698697gcil9rtk','cm3yja6kv0i000000gcil9rtk'], ['cm3yja6kv0i698697gcil9rtk']] ) + >>> print(result.ids) # List of assigned asset IDs """ - return self._assets_namespace.client.assign_assets_to_labelers( + result = self._assets_namespace.client.assign_assets_to_labelers( asset_ids=asset_ids, external_ids=external_ids, project_id=project_id, to_be_labeled_by_array=to_be_labeled_by_array, ) + return IdListResponse(result) class ExternalIdsNamespace: @@ -204,7 +221,7 @@ def update( asset_ids: Optional[List[str]] = None, external_ids: Optional[List[str]] = None, project_id: Optional[str] = None, - ) -> List[Dict[Literal["id"], str]]: + ) -> IdListResponse: """Update the external IDs of one or more assets. 
Args: @@ -214,20 +231,22 @@ def update( project_id: The project ID. Only required if `external_ids` argument is provided. Returns: - A list of dictionaries with the asset ids. + A response object containing the list of updated asset IDs. Examples: - >>> kili.assets.external_ids.update( + >>> result = kili.assets.external_ids.update( new_external_ids=["asset1", "asset2"], asset_ids=["ckg22d81r0jrg0885unmuswj8", "ckg22d81s0jrh0885pdxfd03n"], ) + >>> print(result.ids) # List of updated asset IDs """ - return self._assets_namespace.client.change_asset_external_ids( + result = self._assets_namespace.client.change_asset_external_ids( new_external_ids=new_external_ids, asset_ids=asset_ids, external_ids=external_ids, project_id=project_id, ) + return IdListResponse(result) class MetadataNamespace: @@ -248,7 +267,7 @@ def add( project_id: str, asset_ids: Optional[List[str]] = None, external_ids: Optional[List[str]] = None, - ) -> List[Dict[Literal["id"], str]]: + ) -> IdListResponse: """Add metadata to assets without overriding existing metadata. Args: @@ -259,10 +278,10 @@ def add( external_ids: The external asset IDs to modify (if `asset_ids` is not already provided). Returns: - A list of dictionaries with the asset ids. + A response object containing the list of modified asset IDs. 
Examples: - >>> kili.assets.metadata.add( + >>> result = kili.assets.metadata.add( json_metadata=[ {"key1": "value1", "key2": "value2"}, {"key3": "value3"} @@ -270,13 +289,15 @@ def add( project_id="cm92to3cx012u7l0w6kij9qvx", asset_ids=["ckg22d81r0jrg0885unmuswj8", "ckg22d81s0jrh0885pdxfd03n"] ) + >>> print(result.ids) # List of modified asset IDs """ - return self._assets_namespace.client.add_metadata( + result = self._assets_namespace.client.add_metadata( json_metadata=json_metadata, project_id=project_id, asset_ids=asset_ids, external_ids=external_ids, ) + return IdListResponse(result) @typechecked def set( @@ -285,7 +306,7 @@ def set( project_id: str, asset_ids: Optional[List[str]] = None, external_ids: Optional[List[str]] = None, - ) -> List[Dict[Literal["id"], str]]: + ) -> IdListResponse: """Set metadata on assets, replacing any existing metadata. Args: @@ -296,10 +317,10 @@ def set( external_ids: The external asset IDs to modify (if `asset_ids` is not already provided). Returns: - A list of dictionaries with the asset ids. + A response object containing the list of modified asset IDs. 
Examples: - >>> kili.assets.metadata.set( + >>> result = kili.assets.metadata.set( json_metadata=[ {"key1": "value1", "key2": "value2"}, {"key3": "value3"} @@ -307,13 +328,15 @@ def set( project_id="cm92to3cx012u7l0w6kij9qvx", asset_ids=["ckg22d81r0jrg0885unmuswj8", "ckg22d81s0jrh0885pdxfd03n"] ) + >>> print(result.ids) # List of modified asset IDs """ - return self._assets_namespace.client.set_metadata( + result = self._assets_namespace.client.set_metadata( json_metadata=json_metadata, project_id=project_id, asset_ids=asset_ids, external_ids=external_ids, ) + return IdListResponse(result) class AssetsNamespace(DomainNamespace): @@ -551,7 +574,7 @@ def list( disable_tqdm: Optional[bool] = None, as_generator: bool = True, **kwargs, - ) -> Union[Generator[Dict, None, None], List[Dict], "pd.DataFrame"]: + ) -> Union[Generator[AssetView, None, None], List[AssetView], "pd.DataFrame"]: """List assets from a project. Args: @@ -645,7 +668,7 @@ def list( ) if as_generator: - return assets_gen + return (AssetView(validate_asset(item)) for item in assets_gen) assets_list = list(assets_gen) @@ -659,7 +682,7 @@ def list( "pandas not available, returning list instead", ImportWarning, stacklevel=2 ) - return assets_list + return [AssetView(validate_asset(item)) for item in assets_list] @typechecked def count( @@ -706,7 +729,7 @@ def create( from_csv: Optional[str] = None, csv_separator: str = ",", **kwargs, - ) -> Dict[Literal["id", "asset_ids"], Union[str, List[str]]]: + ) -> AssetCreateResponse: """Create assets in a project. Args: @@ -724,7 +747,7 @@ def create( **kwargs: Additional arguments Returns: - A dictionary with project id and list of created asset ids + A response object with project id and list of created asset ids. Examples: >>> # Create image assets @@ -732,6 +755,8 @@ def create( ... project_id="my_project", ... content_array=["https://example.com/image.png"] ... 
) + >>> print(result.id) # Project ID + >>> print(result.asset_ids) # List of created asset IDs >>> # Create assets with metadata >>> result = kili.assets.create( @@ -741,7 +766,7 @@ def create( ... ) """ # Call the legacy method directly through the client - return self.client.append_many_to_dataset( + result = self.client.append_many_to_dataset( project_id=project_id, content_array=content_array, multi_layer_content_array=multi_layer_content_array, @@ -755,6 +780,7 @@ def create( csv_separator=csv_separator, **kwargs, ) + return AssetCreateResponse(validate_asset_create_response(cast(Dict[str, Any], result))) @typechecked def delete( @@ -762,7 +788,7 @@ def delete( asset_ids: Optional[List[str]] = None, external_ids: Optional[List[str]] = None, project_id: Optional[str] = None, - ) -> Optional[Dict[Literal["id"], str]]: + ) -> Optional[IdResponse]: """Delete assets from a project. Args: @@ -771,13 +797,14 @@ def delete( project_id: The project ID. Only required if `external_ids` argument is provided Returns: - A dict object with the project `id` + A response object with the project `id`, or None if no deletion occurred. Examples: >>> # Delete assets by internal IDs >>> result = kili.assets.delete( ... asset_ids=["ckg22d81r0jrg0885unmuswj8", "ckg22d81s0jrh0885pdxfd03n"] ... ) + >>> print(result.id) # Project ID >>> # Delete assets by external IDs >>> result = kili.assets.delete( @@ -786,11 +813,12 @@ def delete( ... ) """ # Call the legacy method directly through the client - return self.client.delete_many_from_dataset( + result = self.client.delete_many_from_dataset( asset_ids=asset_ids, external_ids=external_ids, project_id=project_id, ) + return IdResponse(result) if result else None @typechecked def update( @@ -807,7 +835,7 @@ def update( is_used_for_consensus_array: Optional[List[bool]] = None, is_honeypot_array: Optional[List[bool]] = None, **kwargs, - ) -> List[Dict[Literal["id"], str]]: + ) -> IdListResponse: """Update the properties of one or more assets. 
Args: @@ -825,7 +853,7 @@ def update( **kwargs: Additional update parameters Returns: - A list of dictionaries with the asset ids + A response object containing the list of updated asset IDs. Examples: >>> # Update asset priorities and metadata @@ -834,6 +862,7 @@ def update( ... priorities=[1], ... json_metadatas=[{"updated": True}] ... ) + >>> print(result.ids) # List of updated asset IDs >>> # Update honeypot settings >>> result = kili.assets.update( @@ -843,7 +872,7 @@ def update( ... ) """ # Call the legacy method directly through the client - return self.client.update_properties_in_assets( + result = self.client.update_properties_in_assets( asset_ids=asset_ids, external_ids=external_ids, project_id=project_id, @@ -857,3 +886,4 @@ def update( is_honeypot_array=is_honeypot_array, **kwargs, ) + return IdListResponse(result) diff --git a/src/kili/domain_api/connections.py b/src/kili/domain_api/connections.py index 411328be4..359367f30 100644 --- a/src/kili/domain_api/connections.py +++ b/src/kili/domain_api/connections.py @@ -1,11 +1,13 @@ """Connections domain namespace for the Kili Python SDK.""" -from typing import Dict, Generator, Iterable, List, Literal, Optional, overload +from typing import Generator, Iterable, List, Literal, Optional, overload from typeguard import typechecked from kili.domain.types import ListOrTuple from kili.domain_api.base import DomainNamespace +from kili.domain_v2.connection import ConnectionView, validate_connection +from kili.domain_v2.project import IdResponse from kili.presentation.client.cloud_storage import CloudStorageClientMethods @@ -69,7 +71,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[ConnectionView, None, None]: ... @overload @@ -90,7 +92,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[ConnectionView]: ... 
@typechecked @@ -111,7 +113,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: bool = False, - ) -> Iterable[Dict]: + ) -> Iterable[ConnectionView]: """Get a generator or a list of cloud storage connections that match a set of criteria. This method provides a simplified interface for querying cloud storage connections, @@ -169,7 +171,7 @@ def list( ... ) """ # Access the legacy method directly by calling it from the mixin class - return CloudStorageClientMethods.cloud_storage_connections( + result = CloudStorageClientMethods.cloud_storage_connections( self.client, cloud_storage_connection_id=connection_id, cloud_storage_integration_id=cloud_storage_integration_id, @@ -181,6 +183,20 @@ def list( as_generator=as_generator, # pyright: ignore[reportGeneralTypeIssues] ) + # Wrap results with ConnectionView + if as_generator: + # Create intermediate generator - iter() makes result explicitly iterable + def _wrap_generator() -> Generator[ConnectionView, None, None]: + result_iter = iter(result) + for item in result_iter: + yield ConnectionView(validate_connection(item)) + + return _wrap_generator() + + # Convert to list - list() makes result explicitly iterable + result_list = list(result) + return [ConnectionView(validate_connection(item)) for item in result_list] + @typechecked def add( self, @@ -190,7 +206,7 @@ def add( prefix: Optional[str] = None, include: Optional[List[str]] = None, exclude: Optional[List[str]] = None, - ) -> Dict: + ) -> ConnectionView: """Connect a cloud storage integration to a project. This method creates a new connection between a cloud storage integration and a project, @@ -211,7 +227,7 @@ def add( Files matching any of these patterns will be excluded. Returns: - A dictionary containing the ID of the created connection. + A ConnectionView object representing the newly created connection. Raises: ValueError: If project_id or cloud_storage_integration_id are invalid. 
@@ -220,20 +236,20 @@ def add( Examples: >>> # Basic connection setup - >>> result = kili.connections.add( + >>> connection = kili.connections.add( ... project_id="project_123", ... cloud_storage_integration_id="integration_456" ... ) >>> # Connect with path prefix filter - >>> result = kili.connections.add( + >>> connection = kili.connections.add( ... project_id="project_123", ... cloud_storage_integration_id="integration_456", ... prefix="datasets/training/" ... ) >>> # Connect with include/exclude patterns - >>> result = kili.connections.add( + >>> connection = kili.connections.add( ... project_id="project_123", ... cloud_storage_integration_id="integration_456", ... include=["*.jpg", "*.png", "*.jpeg"], @@ -241,7 +257,7 @@ def add( ... ) >>> # Advanced filtering combination - >>> result = kili.connections.add( + >>> connection = kili.connections.add( ... project_id="project_123", ... cloud_storage_integration_id="integration_456", ... prefix="data/images/", @@ -249,8 +265,10 @@ def add( ... exclude=["*/thumbnails/*"] ... 
) - >>> # Access the connection ID - >>> connection_id = result["id"] + >>> # Access connection properties + >>> connection_id = connection.id + >>> num_assets = connection.number_of_assets + >>> project = connection.project_id """ # Validate input parameters if not project_id or not project_id.strip(): @@ -261,7 +279,7 @@ def add( # Access the legacy method directly by calling it from the mixin class try: - return CloudStorageClientMethods.add_cloud_storage_connection( + result = CloudStorageClientMethods.add_cloud_storage_connection( self.client, project_id=project_id, cloud_storage_integration_id=cloud_storage_integration_id, @@ -270,6 +288,7 @@ def add( include=include, exclude=exclude, ) + return ConnectionView(validate_connection(result)) except Exception as e: # Enhance error messaging for connection failures if "not found" in str(e).lower(): @@ -293,7 +312,7 @@ def sync( connection_id: str, delete_extraneous_files: bool = False, dry_run: bool = False, - ) -> Dict: + ) -> IdResponse: """Synchronize a cloud storage connection. This method synchronizes the specified cloud storage connection by computing @@ -308,8 +327,7 @@ def sync( Useful for previewing what changes would be made before applying them. Returns: - A dictionary containing connection information after synchronization, - including the number of assets and project ID. + An IdResponse object containing the connection ID after synchronization. Raises: ValueError: If connection_id is invalid or empty. @@ -333,9 +351,8 @@ def sync( ... dry_run=False ... 
) - >>> # Check results - >>> assets_count = result["numberOfAssets"] - >>> project_id = result["projectId"] + >>> # Access the connection ID + >>> connection_id = result.id """ # Validate input parameters if not connection_id or not connection_id.strip(): @@ -343,12 +360,13 @@ def sync( # Access the legacy method directly by calling it from the mixin class try: - return CloudStorageClientMethods.synchronize_cloud_storage_connection( + result = CloudStorageClientMethods.synchronize_cloud_storage_connection( self.client, cloud_storage_connection_id=connection_id, delete_extraneous_files=delete_extraneous_files, dry_run=dry_run, ) + return IdResponse(result) except Exception as e: # Enhanced error handling for synchronization failures if "not found" in str(e).lower(): diff --git a/src/kili/domain_api/integrations.py b/src/kili/domain_api/integrations.py index 3284192ee..72a3e25fa 100644 --- a/src/kili/domain_api/integrations.py +++ b/src/kili/domain_api/integrations.py @@ -1,12 +1,13 @@ """Integrations domain namespace for the Kili Python SDK.""" -from typing import Dict, Generator, Iterable, List, Literal, Optional, overload +from typing import Generator, Iterable, List, Literal, Optional, overload from typeguard import typechecked from kili.domain.cloud_storage import DataIntegrationPlatform, DataIntegrationStatus from kili.domain.types import ListOrTuple from kili.domain_api.base import DomainNamespace +from kili.domain_v2.integration import IntegrationView, validate_integration from kili.presentation.client.cloud_storage import CloudStorageClientMethods @@ -81,7 +82,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[IntegrationView, None, None]: ... @overload @@ -98,7 +99,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[IntegrationView]: ... 
@typechecked @@ -115,7 +116,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: bool = False, - ) -> Iterable[Dict]: + ) -> Iterable[IntegrationView]: """Get a generator or a list of cloud storage integrations that match a set of criteria. This method provides a simplified interface for querying cloud storage integrations, @@ -176,7 +177,7 @@ def list( ... ) """ # Access the legacy method directly by calling it from the mixin class - return CloudStorageClientMethods.cloud_storage_integrations( + result = CloudStorageClientMethods.cloud_storage_integrations( self.client, cloud_storage_integration_id=integration_id, name=name, @@ -190,6 +191,20 @@ def list( as_generator=as_generator, # pyright: ignore[reportGeneralTypeIssues] ) + # Wrap results with IntegrationView + if as_generator: + # Create intermediate generator - iter() makes result explicitly iterable + def _wrap_generator() -> Generator[IntegrationView, None, None]: + result_iter = iter(result) + for item in result_iter: + yield IntegrationView(validate_integration(item)) + + return _wrap_generator() + + # Convert to list - list() makes result explicitly iterable + result_list = list(result) + return [IntegrationView(validate_integration(item)) for item in result_list] + @typechecked def count( self, @@ -269,7 +284,7 @@ def create( s3_region: Optional[str] = None, s3_secret_key: Optional[str] = None, s3_session_token: Optional[str] = None, - ) -> Dict: + ) -> IntegrationView: """Create a new cloud storage integration. This method creates a new integration with external cloud storage providers, @@ -303,7 +318,7 @@ def create( s3_session_token: S3 session token for temporary credentials. Returns: - A dictionary containing the created integration information. + An IntegrationView with the created integration information. Raises: ValueError: If required parameters for the specified platform are missing. @@ -348,8 +363,11 @@ def create( ... allowed_paths=["/datasets", "/models"] ... 
) - >>> # Access the integration ID - >>> integration_id = result["id"] + >>> # Access integration properties + >>> integration_id = result.id + >>> integration_name = result.name + >>> integration_platform = result.platform + >>> integration_status = result.status """ # Validate input parameters if not name or not name.strip(): @@ -370,7 +388,7 @@ def create( # Access the legacy method directly by calling it from the mixin class try: - return CloudStorageClientMethods.create_cloud_storage_integration( + result = CloudStorageClientMethods.create_cloud_storage_integration( self.client, platform=platform, name=name, @@ -394,6 +412,7 @@ def create( s3_secret_key=s3_secret_key, s3_session_token=s3_session_token, ) + return IntegrationView(validate_integration(result)) except Exception as e: # Enhanced error handling for creation failures if "credential" in str(e).lower() or "authentication" in str(e).lower(): @@ -443,7 +462,7 @@ def update( s3_region: Optional[str] = None, s3_secret_key: Optional[str] = None, s3_session_token: Optional[str] = None, - ) -> Dict: + ) -> IntegrationView: """Update an existing cloud storage integration. This method allows you to modify the configuration of an existing cloud storage @@ -476,7 +495,7 @@ def update( s3_session_token: S3 session token for temporary credentials. Returns: - A dictionary containing the updated integration information. + An IntegrationView with the updated integration information. Raises: ValueError: If integration_id is invalid or empty. @@ -510,6 +529,12 @@ def update( ... integration_id="integration_123", ... azure_sas_token="sv=2020-08-04&ss=bfqt&srt=sco&sp=rwdlacupx&se=..." ... 
) + + >>> # Access updated integration properties + >>> integration_id = result.id + >>> integration_name = result.name + >>> integration_platform = result.platform + >>> integration_status = result.status """ # Validate input parameters if not integration_id or not integration_id.strip(): @@ -517,7 +542,7 @@ def update( # Access the legacy method directly by calling it from the mixin class try: - return CloudStorageClientMethods.update_cloud_storage_integration( + result = CloudStorageClientMethods.update_cloud_storage_integration( self.client, cloud_storage_integration_id=integration_id, allowed_paths=allowed_paths, @@ -543,6 +568,7 @@ def update( s3_session_token=s3_session_token, status=status, ) + return IntegrationView(validate_integration(result)) except Exception as e: # Enhanced error handling for update failures if "not found" in str(e).lower(): diff --git a/src/kili/domain_api/issues.py b/src/kili/domain_api/issues.py index 0f51203d4..41d8bd57d 100644 --- a/src/kili/domain_api/issues.py +++ b/src/kili/domain_api/issues.py @@ -5,7 +5,7 @@ """ from itertools import repeat -from typing import Any, Dict, Generator, List, Literal, Optional, Union, overload +from typing import Generator, List, Literal, Optional, Union, overload from typeguard import typechecked @@ -15,6 +15,8 @@ from kili.domain.project import ProjectId from kili.domain.types import ListOrTuple from kili.domain_api.base import DomainNamespace +from kili.domain_v2.issue import IssueView, validate_issue +from kili.domain_v2.project import IdListResponse, StatusResponse from kili.presentation.client.helpers.common_validators import ( assert_all_arrays_have_same_size, disable_tqdm_if_as_generator, @@ -87,7 +89,7 @@ def list( status: Optional[IssueStatus] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[IssueView, None, None]: ... 
@overload @@ -110,7 +112,7 @@ def list( status: Optional[IssueStatus] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[IssueView]: ... @typechecked @@ -133,7 +135,7 @@ def list( status: Optional[IssueStatus] = None, *, as_generator: bool = False, - ) -> Union[Generator[Dict, None, None], List[Dict]]: + ) -> Union[Generator[IssueView, None, None], List[IssueView]]: """Get a generator or a list of issues that match a set of criteria. !!! Info "Issues or Questions" @@ -156,7 +158,7 @@ def list( as_generator: If `True`, a generator on the issues is returned. Returns: - An iterable of issues objects represented as `dict`. + An iterable of IssueView objects. Raises: ValueError: If both `asset_id` and `asset_id_in` are provided. @@ -199,9 +201,12 @@ def list( issue_use_cases = IssueUseCases(self.gateway) issues_gen = issue_use_cases.list_issues(filters=filters, fields=fields, options=options) + # Wrap each issue dict with IssueView + issues_view_gen = (IssueView(validate_issue(issue)) for issue in issues_gen) + if as_generator: - return issues_gen - return list(issues_gen) + return issues_view_gen + return list(issues_view_gen) @typechecked def count( @@ -262,7 +267,7 @@ def create( label_id_array: List[str], object_mid_array: Optional[List[Optional[str]]] = None, text_array: Optional[List[Optional[str]]] = None, - ) -> List[Dict[Literal["id"], str]]: + ) -> IdListResponse: """Create issues for the specified labels. Args: @@ -272,7 +277,8 @@ def create( text_array: List of texts to associate to the issues. Returns: - A list of dictionaries with the `id` key of the created issues. + IdListResponse object containing the created issue IDs. + Access the IDs with `.ids` property. Raises: ValueError: If the input arrays have different sizes. @@ -284,6 +290,7 @@ def create( ... label_id_array=["label_123", "label_456"], ... text_array=["Issue with annotation", "Quality concern"] ... 
) + >>> print(result.ids) # ['issue_1', 'issue_2'] >>> # Create issues with object associations >>> result = kili.issues.create( @@ -292,6 +299,7 @@ def create( ... object_mid_array=["obj_mid_789"], ... text_array=["Object-specific issue"] ... ) + >>> issue_ids = result.ids # Access created issue IDs """ assert_all_arrays_have_same_size([label_id_array, object_mid_array, text_array]) @@ -306,10 +314,11 @@ def create( issue_use_cases = IssueUseCases(self.gateway) issue_ids = issue_use_cases.create_issues(project_id=ProjectId(project_id), issues=issues) - return [{"id": issue_id} for issue_id in issue_ids] + results = [{"id": issue_id} for issue_id in issue_ids] + return IdListResponse(results) @typechecked - def cancel(self, issue_ids: List[str]) -> List[Dict[str, Any]]: + def cancel(self, issue_ids: List[str]) -> List[StatusResponse]: """Cancel issues by setting their status to CANCELLED. This method provides a more intuitive interface than the generic `update_issue_status` @@ -320,19 +329,24 @@ def cancel(self, issue_ids: List[str]) -> List[Dict[str, Any]]: issue_ids: List of issue IDs to cancel. Returns: - List of dictionaries with the results of the status updates. + List of StatusResponse objects containing the results of the status updates. + Each response has `.id`, `.status`, `.success`, and `.error` properties. Raises: ValueError: If any issue ID is invalid or status transition is not allowed. Examples: >>> # Cancel single issue - >>> result = kili.issues.cancel(issue_ids=["issue_123"]) + >>> results = kili.issues.cancel(issue_ids=["issue_123"]) + >>> for result in results: + ... print(f"Issue {result.id}: {'success' if result.success else result.error}") >>> # Cancel multiple issues - >>> result = kili.issues.cancel( + >>> results = kili.issues.cancel( ... issue_ids=["issue_123", "issue_456", "issue_789"] ... 
) + >>> successful = [r.id for r in results if r.success] + >>> failed = [r.id for r in results if not r.success] """ issue_use_cases = IssueUseCases(self.gateway) results = [] @@ -348,10 +362,10 @@ def cancel(self, issue_ids: List[str]) -> List[Dict[str, Any]]: {"id": issue_id, "status": "CANCELLED", "success": False, "error": str(e)} ) - return results + return [StatusResponse(r) for r in results] @typechecked - def open(self, issue_ids: List[str]) -> List[Dict[str, Any]]: + def open(self, issue_ids: List[str]) -> List[StatusResponse]: """Open issues by setting their status to OPEN. This method provides a more intuitive interface than the generic `update_issue_status` @@ -362,19 +376,24 @@ def open(self, issue_ids: List[str]) -> List[Dict[str, Any]]: issue_ids: List of issue IDs to open. Returns: - List of dictionaries with the results of the status updates. + List of StatusResponse objects containing the results of the status updates. + Each response has `.id`, `.status`, `.success`, and `.error` properties. Raises: ValueError: If any issue ID is invalid or status transition is not allowed. Examples: >>> # Open single issue - >>> result = kili.issues.open(issue_ids=["issue_123"]) + >>> results = kili.issues.open(issue_ids=["issue_123"]) + >>> for result in results: + ... if result.success: + ... print(f"Successfully opened issue {result.id}") >>> # Reopen multiple issues - >>> result = kili.issues.open( + >>> results = kili.issues.open( ... issue_ids=["issue_123", "issue_456", "issue_789"] ... 
) + >>> print(f"Opened {sum(1 for r in results if r.success)} issues") """ issue_use_cases = IssueUseCases(self.gateway) results = [] @@ -390,10 +409,10 @@ def open(self, issue_ids: List[str]) -> List[Dict[str, Any]]: {"id": issue_id, "status": "OPEN", "success": False, "error": str(e)} ) - return results + return [StatusResponse(r) for r in results] @typechecked - def solve(self, issue_ids: List[str]) -> List[Dict[str, Any]]: + def solve(self, issue_ids: List[str]) -> List[StatusResponse]: """Solve issues by setting their status to SOLVED. This method provides a more intuitive interface than the generic `update_issue_status` @@ -404,19 +423,25 @@ def solve(self, issue_ids: List[str]) -> List[Dict[str, Any]]: issue_ids: List of issue IDs to solve. Returns: - List of dictionaries with the results of the status updates. + List of StatusResponse objects containing the results of the status updates. + Each response has `.id`, `.status`, `.success`, and `.error` properties. Raises: ValueError: If any issue ID is invalid or status transition is not allowed. Examples: >>> # Solve single issue - >>> result = kili.issues.solve(issue_ids=["issue_123"]) + >>> results = kili.issues.solve(issue_ids=["issue_123"]) + >>> if results[0].success: + ... print(f"Issue {results[0].id} resolved successfully") >>> # Solve multiple issues - >>> result = kili.issues.solve( + >>> results = kili.issues.solve( ... issue_ids=["issue_123", "issue_456", "issue_789"] ... ) + >>> errors = [(r.id, r.error) for r in results if not r.success] + >>> if errors: + ... 
print(f"Failed to solve: {errors}") """ issue_use_cases = IssueUseCases(self.gateway) results = [] @@ -432,7 +457,7 @@ def solve(self, issue_ids: List[str]) -> List[Dict[str, Any]]: {"id": issue_id, "status": "SOLVED", "success": False, "error": str(e)} ) - return results + return [StatusResponse(r) for r in results] def _validate_status_transition( self, issue_id: str, current_status: IssueStatus, new_status: IssueStatus diff --git a/src/kili/domain_api/labels.py b/src/kili/domain_api/labels.py index 6d135472c..5c302f70d 100644 --- a/src/kili/domain_api/labels.py +++ b/src/kili/domain_api/labels.py @@ -27,6 +27,8 @@ from kili.domain.label import LabelType from kili.domain.types import ListOrTuple from kili.domain_api.base import DomainNamespace +from kili.domain_v2.label import LabelExportResponse, LabelView, validate_label +from kili.domain_v2.project import IdListResponse, IdResponse from kili.services.export.types import CocoAnnotationModifier, LabelFormat, SplitOption from kili.utils.labels.parsing import ParsedLabel @@ -56,7 +58,7 @@ def create( asset_id_array: Optional[List[str]] = None, disable_tqdm: Optional[bool] = None, overwrite: bool = False, - ) -> Dict[Literal["id"], str]: + ) -> IdResponse: """Create predictions for specific assets. Args: @@ -71,10 +73,20 @@ def create( the same model name on the targeted assets. Returns: - A dictionary with the project `id`. + An IdResponse object containing the project ID. + + Example: + >>> response = kili.labels.predictions.create( + ... project_id="project_123", + ... external_id_array=["asset_1"], + ... json_response_array=[{"categories": [{"name": "CAR"}]}], + ... model_name="my_model" + ... 
) + >>> print(response.id) + 'project_123' """ # Call the client method directly to bypass namespace routing - return self._parent.client.create_predictions( + result = self._parent.client.create_predictions( project_id=project_id, external_id_array=external_id_array, model_name_array=model_name_array, @@ -84,6 +96,7 @@ def create( disable_tqdm=disable_tqdm, overwrite=overwrite, ) + return IdResponse(result) @overload def list( @@ -117,7 +130,7 @@ def list( category_search: Optional[str] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[LabelView, None, None]: ... @overload @@ -152,7 +165,7 @@ def list( category_search: Optional[str] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[LabelView]: ... @typechecked @@ -187,7 +200,7 @@ def list( category_search: Optional[str] = None, *, as_generator: bool = False, - ) -> Iterable[Dict]: + ) -> Iterable[LabelView]: """Get prediction labels from a project based on a set of criteria. This method is equivalent to the `labels()` method, but it only returns labels of type "PREDICTION". @@ -216,10 +229,10 @@ def list( category_search: Query to filter labels based on the content of their jsonResponse Returns: - An iterable of labels. + An iterable of LabelView objects. 
""" # Call the client method directly to bypass namespace routing - return self._parent.client.predictions( + results = self._parent.client.predictions( project_id=project_id, asset_id=asset_id, asset_status_in=asset_status_in, @@ -243,6 +256,11 @@ def list( as_generator=as_generator, # pyright: ignore[reportGeneralTypeIssues] ) + # Wrap each dict result with LabelView + if as_generator: + return (LabelView(validate_label(item)) for item in results) + return [LabelView(validate_label(item)) for item in results] + class InferencesNamespace: """Nested namespace for inference-related operations.""" @@ -287,7 +305,7 @@ def list( category_search: Optional[str] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[LabelView, None, None]: ... @overload @@ -322,7 +340,7 @@ def list( category_search: Optional[str] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[LabelView]: ... @typechecked @@ -357,7 +375,7 @@ def list( category_search: Optional[str] = None, *, as_generator: bool = False, - ) -> Iterable[Dict]: + ) -> Iterable[LabelView]: """Get inference labels from a project based on a set of criteria. This method is equivalent to the `labels()` method, but it only returns labels of type "INFERENCE". @@ -386,10 +404,10 @@ def list( category_search: Query to filter labels based on the content of their jsonResponse Returns: - An iterable of inference labels. + An iterable of LabelView objects. 
""" # Call the client method directly to bypass namespace routing - return self._parent.client.inferences( + results = self._parent.client.inferences( project_id=project_id, asset_id=asset_id, asset_status_in=asset_status_in, @@ -413,6 +431,11 @@ def list( as_generator=as_generator, # pyright: ignore[reportGeneralTypeIssues] ) + # Wrap each dict result with LabelView + if as_generator: + return (LabelView(validate_label(item)) for item in results) + return [LabelView(validate_label(item)) for item in results] + class HoneypotsNamespace: """Nested namespace for honeypot-related operations.""" @@ -432,7 +455,7 @@ def create( asset_external_id: Optional[str] = None, asset_id: Optional[str] = None, project_id: Optional[str] = None, - ) -> Dict: + ) -> LabelView: """Create honeypot for an asset. Uses the given `json_response` to create a `REVIEW` label. @@ -449,15 +472,26 @@ def create( Either provide `asset_id` or `asset_external_id` and `project_id`. Returns: - A dictionary-like object representing the created label. + A LabelView object representing the created honeypot label. + + Example: + >>> label = kili.labels.honeypots.create( + ... asset_id="asset_123", + ... json_response={"categories": [{"name": "CORRECT_ANSWER"}]} + ... ) + >>> print(label.id) + 'label_456' + >>> print(label.label_type) + 'REVIEW' """ # Call the client method directly to bypass namespace routing - return self._parent.client.create_honeypot( + result = self._parent.client.create_honeypot( json_response=json_response, asset_external_id=asset_external_id, asset_id=asset_id, project_id=project_id, ) + return LabelView(validate_label(result)) class EventsNamespace: @@ -602,7 +636,7 @@ def list( output_format: Literal["dict"] = "dict", *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[LabelView, None, None]: ... 
@overload @@ -642,7 +676,7 @@ def list( output_format: Literal["dict"] = "dict", *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[LabelView]: ... @overload @@ -762,7 +796,7 @@ def list( output_format: Literal["dict", "parsed_label"] = "dict", *, as_generator: bool = False, - ) -> Iterable[Union[Dict, ParsedLabel]]: + ) -> Iterable[Union[LabelView, ParsedLabel]]: """Get a label list or a label generator from a project based on a set of criteria. Args: @@ -790,14 +824,14 @@ def list( disable_tqdm: If `True`, the progress bar will be disabled. as_generator: If `True`, a generator on the labels is returned. category_search: Query to filter labels based on the content of their jsonResponse. - output_format: If `dict`, the output is an iterable of Python dictionaries. + output_format: If `dict`, the output is an iterable of LabelView objects. If `parsed_label`, the output is an iterable of parsed labels objects. Returns: - An iterable of labels. + An iterable of labels (LabelView for dict format, ParsedLabel for parsed_label format). """ # Use super() to bypass namespace routing and call the legacy method directly - return self.client.labels( + results = self.client.labels( project_id=project_id, asset_id=asset_id, asset_status_in=asset_status_in, @@ -824,6 +858,16 @@ def list( as_generator=as_generator, # pyright: ignore[reportGeneralTypeIssues] ) + # Wrap dict results with LabelView, keep ParsedLabel unchanged + if output_format == "parsed_label": + # Return ParsedLabel objects as-is + return results + + # Wrap dict results with LabelView + if as_generator: + return (LabelView(validate_label(item)) for item in results) + return [LabelView(validate_label(item)) for item in results] + @typechecked def count( self, @@ -908,7 +952,7 @@ def create( disable_tqdm: Optional[bool] = None, overwrite: bool = False, step_name: Optional[str] = None, - ) -> List[Dict[Literal["id"], str]]: + ) -> IdListResponse: """Create labels to assets. 
Args: @@ -929,10 +973,22 @@ def create( The label_type must match accordingly. Returns: - A list of dictionaries with the label ids. + An IdListResponse containing the created label IDs. + + Example: + >>> response = kili.labels.create( + ... asset_id_array=["asset_1", "asset_2"], + ... json_response_array=[ + ... {"categories": [{"name": "CAR"}]}, + ... {"categories": [{"name": "TRUCK"}]} + ... ], + ... project_id="project_123" + ... ) + >>> print(response.ids) + ['label_1', 'label_2'] """ # Use super() to bypass namespace routing and call the legacy method directly - return self.client.append_labels( + result = self.client.append_labels( asset_id_array=asset_id_array, json_response_array=json_response_array, author_id_array=author_id_array, @@ -945,6 +1001,7 @@ def create( overwrite=overwrite, step_name=step_name, ) + return IdListResponse(result) @typechecked def delete( @@ -982,13 +1039,13 @@ def export( normalized_coordinates: Optional[bool] = None, label_type_in: Optional[List[str]] = None, include_sent_back_labels: Optional[bool] = None, - ) -> Optional[List[Dict[str, Union[List[str], str]]]]: + ) -> Optional[LabelExportResponse]: """Export the project labels with the requested format into the requested output path. Args: project_id: Identifier of the project. filename: Relative or full path of the archive that will contain - the exported data. + the exported data. If None, returns export data in memory. fmt: Format of the exported labels. asset_ids: Optional list of the assets internal IDs from which to export the labels. layout: Layout of the exported files. "split" means there is one folder @@ -1012,10 +1069,28 @@ def export( include_sent_back_labels: If True, the export will include the labels that have been sent back. Returns: - Export information or None if export failed. + A LabelExportResponse object containing export information if filename is None, + otherwise None (data is saved to file). 
+ + Example: + >>> # Export to file + >>> kili.labels.export( + ... project_id="project_123", + ... filename="export.zip", + ... fmt="kili" + ... ) + >>> + >>> # Export to memory + >>> export_result = kili.labels.export( + ... project_id="project_123", + ... filename=None, + ... fmt="kili" + ... ) + >>> if export_result: + ... print(len(export_result.export_info)) """ # Use super() to bypass namespace routing and call the legacy method directly - return self.client.export_labels( + result = self.client.export_labels( project_id=project_id, filename=filename, fmt=fmt, @@ -1031,6 +1106,9 @@ def export( label_type_in=label_type_in, include_sent_back_labels=include_sent_back_labels, ) + if result is None: + return None + return LabelExportResponse(result) @typechecked def append( @@ -1046,7 +1124,7 @@ def append( disable_tqdm: Optional[bool] = None, overwrite: bool = False, step_name: Optional[str] = None, - ) -> List[Dict[Literal["id"], str]]: + ) -> IdListResponse: """Append labels to assets. This is an alias for the `create` method to maintain compatibility. @@ -1069,7 +1147,16 @@ def append( The label_type must match accordingly. Returns: - A list of dictionaries with the label ids. + An IdListResponse containing the created label IDs. + + Example: + >>> response = kili.labels.append( + ... asset_id_array=["asset_1"], + ... json_response_array=[{"categories": [{"name": "CAR"}]}], + ... project_id="project_123" + ... 
) + >>> print(response.ids) + ['label_1'] """ return self.create( asset_id_array=asset_id_array, diff --git a/src/kili/domain_api/notifications.py b/src/kili/domain_api/notifications.py index 85a4306c0..a1a516cd6 100644 --- a/src/kili/domain_api/notifications.py +++ b/src/kili/domain_api/notifications.py @@ -1,6 +1,6 @@ """Notifications domain namespace for the Kili Python SDK.""" -from typing import Dict, Generator, List, Literal, Optional, Union, overload +from typing import Generator, List, Literal, Optional, Union, overload from typeguard import typechecked @@ -9,6 +9,7 @@ from kili.domain.types import ListOrTuple from kili.domain.user import UserFilter, UserId from kili.domain_api.base import DomainNamespace +from kili.domain_v2.notification import NotificationView, validate_notification from kili.entrypoints.mutations.notification.queries import ( GQL_CREATE_NOTIFICATION, GQL_UPDATE_PROPERTIES_IN_NOTIFICATION, @@ -79,7 +80,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[NotificationView, None, None]: ... @overload @@ -94,7 +95,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[NotificationView]: ... @typechecked @@ -109,7 +110,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: bool = False, - ) -> Union[List[Dict], Generator[Dict, None, None]]: + ) -> Union[List[NotificationView], Generator[NotificationView, None, None]]: """List notifications matching the specified criteria. 
Args: @@ -167,8 +168,8 @@ def list( ) if as_generator: - return notifications_gen - return list(notifications_gen) + return (NotificationView(validate_notification(item)) for item in notifications_gen) + return [NotificationView(validate_notification(item)) for item in notifications_gen] @typechecked def count( @@ -214,7 +215,7 @@ def create( status: str, url: str, user_id: str, - ) -> Dict: + ) -> NotificationView: """Create a new notification. This method is currently only available for Kili administrators. @@ -226,16 +227,22 @@ def create( user_id: The ID of the user who should receive the notification. Returns: - A result dictionary indicating if the creation was successful. + A NotificationView with the created notification information. Examples: >>> # Create an info notification - >>> result = kili.notifications.create( + >>> notification = kili.notifications.create( ... message="Your project export is ready", ... status="info", ... url="/project/123/export", ... user_id="user_456" ... ) + >>> print(notification.id) + 'notif_789' + >>> print(notification.message) + 'Your project export is ready' + >>> print(notification.status) + 'info' """ # Access the mutations directly from the gateway's GraphQL client # This follows the pattern used in other domain namespaces @@ -250,8 +257,8 @@ def create( } result = self.gateway.graphql_client.execute(GQL_CREATE_NOTIFICATION, variables) - # Format result following the pattern from base operations - return result.get("data", {}) + notification_data = result["data"]["data"] + return NotificationView(validate_notification(notification_data)) @typechecked def update( @@ -262,7 +269,7 @@ def update( url: Optional[str] = None, progress: Optional[int] = None, task_id: Optional[str] = None, - ) -> Dict: + ) -> NotificationView: """Update an existing notification. This method is currently only available for Kili administrators. @@ -276,24 +283,30 @@ def update( task_id: Associated task ID for the notification. 
Returns: - A result dictionary indicating if the update was successful. + A NotificationView with the updated notification information. Examples: >>> # Mark notification as seen - >>> result = kili.notifications.update( + >>> notification = kili.notifications.update( ... notification_id="notif_123", ... has_been_seen=True ... ) + >>> print(notification.has_been_seen) + True >>> # Update notification status and URL - >>> result = kili.notifications.update( + >>> notification = kili.notifications.update( ... notification_id="notif_123", ... status="completed", ... url="/project/123/results" ... ) + >>> print(notification.status) + 'completed' + >>> print(notification.url) + '/project/123/results' >>> # Update progress for a long-running task - >>> result = kili.notifications.update( + >>> notification = kili.notifications.update( ... notification_id="notif_123", ... progress=75 ... ) @@ -310,5 +323,5 @@ def update( result = self.gateway.graphql_client.execute( GQL_UPDATE_PROPERTIES_IN_NOTIFICATION, variables ) - # Format result following the pattern from base operations - return result.get("data", {}) + notification_data = result["data"]["data"] + return NotificationView(validate_notification(notification_data)) diff --git a/src/kili/domain_api/organizations.py b/src/kili/domain_api/organizations.py index 622f65ef6..7710489f4 100644 --- a/src/kili/domain_api/organizations.py +++ b/src/kili/domain_api/organizations.py @@ -1,12 +1,18 @@ """Organizations domain namespace for the Kili Python SDK.""" from datetime import datetime -from typing import Dict, Generator, Iterable, List, Literal, Optional, overload +from typing import Generator, Iterable, List, Literal, Optional, overload from typeguard import typechecked from kili.domain.types import ListOrTuple from kili.domain_api.base import DomainNamespace +from kili.domain_v2.organization import ( + OrganizationMetricsView, + OrganizationView, + validate_organization, + validate_organization_metrics, +) from 
kili.presentation.client.organization import OrganizationClientMethods @@ -61,7 +67,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[OrganizationView, None, None]: ... @overload @@ -75,7 +81,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[OrganizationView]: ... @typechecked @@ -89,7 +95,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: bool = False, - ) -> Iterable[Dict]: + ) -> Iterable[OrganizationView]: """Get a generator or a list of organizations that match a set of criteria. Args: @@ -128,7 +134,7 @@ def list( ... ) """ # Access the legacy method directly by calling it from the mixin class - return OrganizationClientMethods.organizations( + result = OrganizationClientMethods.organizations( self.client, email=email, organization_id=organization_id, @@ -139,6 +145,20 @@ def list( as_generator=as_generator, # pyright: ignore[reportGeneralTypeIssues] ) + # Wrap results with OrganizationView + if as_generator: + # Create intermediate generator - iter() makes result explicitly iterable + def _wrap_generator() -> Generator[OrganizationView, None, None]: + result_iter = iter(result) + for item in result_iter: + yield OrganizationView(validate_organization(item)) + + return _wrap_generator() + + # Convert to list - list() makes result explicitly iterable + result_list = list(result) + return [OrganizationView(validate_organization(item)) for item in result_list] + @typechecked def count( self, @@ -182,7 +202,7 @@ def metrics( "numberOfHours", "numberOfLabeledAssets", ), - ) -> Dict: + ) -> OrganizationMetricsView: """Get organization metrics and analytics. This method provides access to organization-level analytics including @@ -198,7 +218,7 @@ def metrics( - numberOfLabeledAssets: Total number of labeled assets Returns: - A dictionary containing the requested metrics of the organization. 
+ A view object containing the requested metrics of the organization. Examples: >>> # Get default metrics for organization @@ -219,14 +239,16 @@ def metrics( ... ) >>> # Access specific metric values - >>> annotations_count = metrics["numberOfAnnotations"] - >>> hours_spent = metrics["numberOfHours"] + >>> annotations_count = metrics.number_of_annotations + >>> hours_spent = metrics.number_of_hours + >>> labeled_assets = metrics.number_of_labeled_assets """ # Access the legacy method directly by calling it from the mixin class - return OrganizationClientMethods.organization_metrics( + result = OrganizationClientMethods.organization_metrics( self.client, organization_id=organization_id, start_date=start_date, end_date=end_date, fields=fields, ) + return OrganizationMetricsView(validate_organization_metrics(result)) diff --git a/src/kili/domain_api/plugins.py b/src/kili/domain_api/plugins.py index 2f55f8f50..386663ceb 100644 --- a/src/kili/domain_api/plugins.py +++ b/src/kili/domain_api/plugins.py @@ -2,7 +2,7 @@ import json from datetime import datetime -from typing import Dict, List, Optional +from typing import List, Optional from typeguard import typechecked from typing_extensions import LiteralString @@ -15,6 +15,7 @@ ) from kili.domain.types import ListOrTuple from kili.domain_api.base import DomainNamespace +from kili.domain_v2.plugin import PluginView, validate_plugin from kili.services.plugins import ( PluginUploader, WebhookUploader, @@ -230,7 +231,7 @@ def webhooks(self) -> WebhooksNamespace: def list( self, fields: ListOrTuple[str] = ("name", "projectIds", "id", "createdAt", "updatedAt"), - ) -> List[Dict]: + ) -> List[PluginView]: """List all plugins from your organization. Args: @@ -254,9 +255,10 @@ def list( ... 'organizationId', 'archived' ... 
]) """ - return PluginQuery(self.gateway.graphql_client, self.gateway.http_client).list( + result = PluginQuery(self.gateway.graphql_client, self.gateway.http_client).list( fields=fields ) + return [PluginView(validate_plugin(item)) for item in result] @typechecked def status( diff --git a/src/kili/domain_api/projects.py b/src/kili/domain_api/projects.py index 1d1463355..1021dbd18 100644 --- a/src/kili/domain_api/projects.py +++ b/src/kili/domain_api/projects.py @@ -7,7 +7,6 @@ from functools import cached_property from typing import ( TYPE_CHECKING, - Any, Dict, Generator, Iterable, @@ -23,6 +22,17 @@ from kili.domain.project import ComplianceTag, InputType, WorkflowStepCreate, WorkflowStepUpdate from kili.domain.types import ListOrTuple from kili.domain_api.base import DomainNamespace +from kili.domain_v2.project import ( + IdResponse, + ProjectRoleView, + ProjectVersionView, + ProjectView, + WorkflowStepView, + validate_project, + validate_project_role, + validate_project_version, + validate_workflow_step, +) if TYPE_CHECKING: from kili.client import Kili as KiliLegacy @@ -40,7 +50,7 @@ def __init__(self, parent: "ProjectsNamespace") -> None: self._parent = parent @typechecked - def update(self, project_id: str, should_anonymize: bool = True) -> Dict[Literal["id"], str]: + def update(self, project_id: str, should_anonymize: bool = True) -> IdResponse: """Anonymize the project for the labelers and reviewers. Args: @@ -48,16 +58,16 @@ def update(self, project_id: str, should_anonymize: bool = True) -> Dict[Literal should_anonymize: The value to be applied. Defaults to `True`. Returns: - A dict with the id of the project which indicates if the mutation was successful, - or an error message. + An IdResponse with the project id indicating if the mutation was successful. 
Examples: >>> projects.anonymization.update(project_id=project_id) >>> projects.anonymization.update(project_id=project_id, should_anonymize=False) """ - return self._parent.client.update_project_anonymization( + result = self._parent.client.update_project_anonymization( project_id=project_id, should_anonymize=should_anonymize ) + return IdResponse(result) class UsersNamespace: @@ -77,7 +87,7 @@ def add( project_id: str, email: str, role: Literal["ADMIN", "TEAM_MANAGER", "REVIEWER", "LABELER"] = "LABELER", - ) -> Dict: + ) -> ProjectRoleView: """Add a user to a project. If the user does not exist in your organization, he/she is invited and added @@ -91,29 +101,31 @@ def add( role: The role of the user. Returns: - A dictionary with the project user information. + A ProjectRoleView with the project user information. Examples: >>> projects.users.add(project_id=project_id, email='john@doe.com') """ - return self._parent.client.append_to_roles( + result = self._parent.client.append_to_roles( project_id=project_id, user_email=email, role=role ) + return ProjectRoleView(validate_project_role(result)) @typechecked - def remove(self, role_id: str) -> Dict[Literal["id"], str]: + def remove(self, role_id: str) -> IdResponse: """Remove users by their role_id. Args: role_id: Identifier of the project user (not the ID of the user) Returns: - A dict with the project id. + An IdResponse with the project id. """ - return self._parent.client.delete_from_roles(role_id=role_id) + result = self._parent.client.delete_from_roles(role_id=role_id) + return IdResponse(result) @typechecked - def update(self, role_id: str, project_id: str, user_id: str, role: str) -> Dict: + def update(self, role_id: str, project_id: str, user_id: str, role: str) -> ProjectRoleView: """Update properties of a role. 
To be able to change someone's role, you must be either of: @@ -129,11 +141,12 @@ def update(self, role_id: str, project_id: str, user_id: str, role: str) -> Dict Possible choices are: `ADMIN`, `TEAM_MANAGER`, `REVIEWER`, `LABELER` Returns: - A dictionary with the project user information. + A ProjectRoleView with the project user information. """ - return self._parent.client.update_properties_in_role( + result = self._parent.client.update_properties_in_role( role_id=role_id, project_id=project_id, user_id=user_id, role=role ) + return ProjectRoleView(validate_project_role(result)) @overload def list( @@ -159,7 +172,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[ProjectRoleView, None, None]: ... @overload @@ -186,7 +199,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[ProjectRoleView]: ... @typechecked @@ -213,7 +226,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: bool = False, - ) -> Iterable[Dict]: + ) -> Iterable[ProjectRoleView]: """Get project users from projects that match a set of criteria. 
Args: @@ -257,18 +270,20 @@ def list( as_generator=as_generator, # pyright: ignore[reportGeneralTypeIssues] ) - # Extract roles from projects + # Extract roles from projects and wrap with ProjectRoleView if as_generator: def users_generator(): for project in projects: - yield from project.get("roles", []) + for role in project.get("roles", []): + yield ProjectRoleView(validate_project_role(role)) return users_generator() users = [] for project in projects: - users.extend(project.get("roles", [])) + for role in project.get("roles", []): + users.append(ProjectRoleView(validate_project_role(role))) return users @typechecked @@ -331,7 +346,7 @@ def __init__(self, parent: "WorkflowNamespace") -> None: self._parent = parent @typechecked - def list(self, project_id: str) -> List[Dict[str, Any]]: + def list(self, project_id: str) -> List[WorkflowStepView]: """Get steps in a project workflow. Args: @@ -340,7 +355,8 @@ def list(self, project_id: str) -> List[Dict[str, Any]]: Returns: A list with the steps of the project workflow. """ - return self._parent._parent.client.get_steps(project_id=project_id) # pylint: disable=protected-access + steps = self._parent._parent.client.get_steps(project_id=project_id) # pylint: disable=protected-access + return [WorkflowStepView(validate_workflow_step(step)) for step in steps] class WorkflowNamespace: @@ -371,7 +387,7 @@ def update( create_steps: Optional[List[WorkflowStepCreate]] = None, update_steps: Optional[List[WorkflowStepUpdate]] = None, delete_steps: Optional[List[str]] = None, - ) -> Dict[str, Any]: + ) -> IdResponse: """Update properties of a project workflow. Args: @@ -384,16 +400,16 @@ def update( delete_steps: List of step IDs to delete from the project workflow. Returns: - A dict with the changed properties which indicates if the mutation was successful, - else an error message. + An IdResponse with the project id indicating if the mutation was successful. 
""" - return self._parent.client.update_project_workflow( + result = self._parent.client.update_project_workflow( project_id=project_id, enforce_step_separation=enforce_step_separation, create_steps=create_steps, update_steps=update_steps, delete_steps=delete_steps, ) + return IdResponse(result) class VersionsNamespace: @@ -417,7 +433,7 @@ def get( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[ProjectVersionView, None, None]: ... @overload @@ -430,7 +446,7 @@ def get( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[ProjectVersionView]: ... @typechecked @@ -443,7 +459,7 @@ def get( disable_tqdm: Optional[bool] = None, *, as_generator: bool = False, - ) -> Iterable[Dict]: + ) -> Iterable[ProjectVersionView]: """Get a generator or a list of project versions respecting a set of criteria. Args: @@ -456,9 +472,9 @@ def get( as_generator: If `True`, a generator on the project versions is returned. Returns: - An iterable of dictionaries containing the project versions information. + An iterable of project version views. """ - return self._parent.client.project_version( + results = self._parent.client.project_version( project_id=project_id, first=first, skip=skip, @@ -467,6 +483,11 @@ def get( as_generator=as_generator, # pyright: ignore[reportGeneralTypeIssues] ) + # Wrap results with ProjectVersionView + if as_generator: + return (ProjectVersionView(validate_project_version(item)) for item in results) + return [ProjectVersionView(validate_project_version(item)) for item in results] + @typechecked def count(self, project_id: str) -> int: """Count the number of project versions. 
@@ -480,7 +501,7 @@ def count(self, project_id: str) -> int: return self._parent.client.count_project_versions(project_id=project_id) @typechecked - def update(self, project_version_id: str, content: Optional[str]) -> Dict: + def update(self, project_version_id: str, content: Optional[str]) -> ProjectVersionView: """Update properties of a project version. Args: @@ -488,7 +509,7 @@ def update(self, project_version_id: str, content: Optional[str]) -> Dict: content: Link to download the project version Returns: - A dictionary containing the updated project version. + A ProjectVersionView containing the updated project version. Examples: >>> projects.versions.update( @@ -496,9 +517,10 @@ def update(self, project_version_id: str, content: Optional[str]) -> Dict: content='test' ) """ - return self._parent.client.update_properties_in_project_version( + result = self._parent.client.update_properties_in_project_version( project_version_id=project_version_id, content=content ) + return ProjectVersionView(validate_project_version(result)) class ProjectsNamespace(DomainNamespace): @@ -586,7 +608,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[ProjectView, None, None]: ... @overload @@ -620,7 +642,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[ProjectView]: ... @typechecked @@ -654,7 +676,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: bool = False, - ) -> Iterable[Dict]: + ) -> Iterable[ProjectView]: """Get a generator or a list of projects that match a set of criteria. 
Args: @@ -685,7 +707,7 @@ def list( >>> # List all my projects >>> projects.list() """ - return self.client.projects( + results = self.client.projects( project_id=project_id, search_query=search_query, should_relaunch_kpi_computation=should_relaunch_kpi_computation, @@ -703,6 +725,11 @@ def list( as_generator=as_generator, # pyright: ignore[reportGeneralTypeIssues] ) + # Wrap results with ProjectView + if as_generator: + return (ProjectView(validate_project(item)) for item in results) + return [ProjectView(validate_project(item)) for item in results] + @typechecked def count( self, @@ -755,7 +782,7 @@ def create( tags: Optional[ListOrTuple[str]] = None, compliance_tags: Optional[ListOrTuple[ComplianceTag]] = None, from_demo_project: Optional[DemoProjectType] = None, - ) -> Dict[Literal["id"], str]: + ) -> IdResponse: """Create a project. Args: @@ -772,12 +799,12 @@ def create( from_demo_project: Demo project type to create from. Returns: - A dict with the id of the created project. + An IdResponse with the id of the created project. Examples: >>> projects.create(input_type='IMAGE', json_interface=json_interface, title='Example') """ - return self.client.create_project( + result = self.client.create_project( title=title, description=description, input_type=input_type, @@ -787,6 +814,7 @@ def create( compliance_tags=compliance_tags, from_demo_project=from_demo_project, ) + return IdResponse(result) @typechecked def update( @@ -811,7 +839,7 @@ def update( metadata_properties: Optional[dict] = None, seconds_to_label_before_auto_assign: Optional[int] = None, should_auto_assign: Optional[bool] = None, - ) -> Dict[str, Any]: + ) -> IdResponse: """Update properties of a project. Args: @@ -843,10 +871,9 @@ def update( should_auto_assign: If `True`, assets are automatically assigned to users when they start annotating. Returns: - A dict with the changed properties which indicates if the mutation was successful, - else an error message. 
+ An IdResponse with the project id indicating if the mutation was successful. """ - return self.client.update_properties_in_project( + result = self.client.update_properties_in_project( project_id=project_id, can_navigate_between_assets=can_navigate_between_assets, can_skip_asset=can_skip_asset, @@ -868,30 +895,33 @@ def update( seconds_to_label_before_auto_assign=seconds_to_label_before_auto_assign, should_auto_assign=should_auto_assign, ) + return IdResponse(result) @typechecked - def archive(self, project_id: str) -> Dict[Literal["id"], str]: + def archive(self, project_id: str) -> IdResponse: """Archive a project. Args: project_id: Identifier of the project. Returns: - A dict with the id of the project. + An IdResponse with the id of the project. """ - return self.client.archive_project(project_id=project_id) + result = self.client.archive_project(project_id=project_id) + return IdResponse(result) @typechecked - def unarchive(self, project_id: str) -> Dict[Literal["id"], str]: + def unarchive(self, project_id: str) -> IdResponse: """Unarchive a project. Args: project_id: Identifier of the project Returns: - A dict with the id of the project. + An IdResponse with the id of the project. """ - return self.client.unarchive_project(project_id=project_id) + result = self.client.unarchive_project(project_id=project_id) + return IdResponse(result) @typechecked def copy( @@ -940,13 +970,15 @@ def copy( ) @typechecked - def delete(self, project_id: str) -> str: + def delete(self, project_id: str) -> IdResponse: """Delete a project permanently. Args: project_id: Identifier of the project Returns: - A string with the deleted project id. + An IdResponse with the deleted project id. 
""" - return self.client.delete_project(project_id=project_id) + result = self.client.delete_project(project_id=project_id) + # delete_project returns a string ID, so wrap it in a dict + return IdResponse({"id": result}) diff --git a/src/kili/domain_api/tags.py b/src/kili/domain_api/tags.py index ca5ca17db..0013c8abb 100644 --- a/src/kili/domain_api/tags.py +++ b/src/kili/domain_api/tags.py @@ -1,6 +1,6 @@ """Tags domain namespace for the Kili Python SDK.""" -from typing import Dict, List, Literal, Optional +from typing import List, Optional, cast from typeguard import typechecked @@ -8,6 +8,8 @@ from kili.domain.tag import TagId from kili.domain.types import ListOrTuple from kili.domain_api.base import DomainNamespace +from kili.domain_v2.project import IdListResponse, IdResponse +from kili.domain_v2.tag import TagView, validate_tag from kili.use_cases.tag import TagUseCases @@ -69,7 +71,7 @@ def list( self, project_id: Optional[str] = None, fields: Optional[ListOrTuple[str]] = None, - ) -> List[Dict]: + ) -> List[TagView]: """List tags from the organization or a specific project. Args: @@ -95,18 +97,19 @@ def list( fields = ("id", "organizationId", "label", "checkedForProjects") tag_use_cases = TagUseCases(self.gateway) - return ( + result = ( tag_use_cases.get_tags_of_organization(fields=fields) if project_id is None else tag_use_cases.get_tags_of_project(project_id=ProjectId(project_id), fields=fields) ) + return [TagView(validate_tag(item)) for item in result] @typechecked def create( self, name: str, color: Optional[str] = None, - ) -> Dict[Literal["id"], str]: + ) -> IdResponse: """Create a new tag in the organization. This operation is organization-wide. @@ -117,17 +120,19 @@ def create( color: Color of the tag to create. If not provided, a default color will be used. Returns: - Dictionary with the ID of the created tag. + An IdResponse with the ID of the created tag. 
Examples: >>> # Create a simple tag >>> result = kili.tags.create(name="reviewed") + >>> print(result.id) >>> # Create a tag with a specific color >>> result = kili.tags.create(name="important", color="#ff0000") """ tag_use_cases = TagUseCases(self.gateway) - return tag_use_cases.create_tag(name, color) + result = tag_use_cases.create_tag(name, color) + return IdResponse(result) @typechecked def update( @@ -135,7 +140,7 @@ def update( new_name: str, tag_name: Optional[str] = None, tag_id: Optional[str] = None, - ) -> Dict[Literal["id"], str]: + ) -> IdResponse: """Update an existing tag. This operation is organization-wide. @@ -147,7 +152,7 @@ def update( new_name: New name for the tag. Returns: - Dictionary with the ID of the updated tag. + An IdResponse with the ID of the updated tag. Raises: ValueError: If neither tag_name nor tag_id is provided. @@ -155,6 +160,7 @@ def update( Examples: >>> # Update tag by name >>> result = kili.tags.update(new_name="new_name", tag_name="old_name") + >>> print(result.id) >>> # Update tag by ID (more precise) >>> result = kili.tags.update(new_name="new_name", tag_id="tag_id_123") @@ -165,17 +171,18 @@ def update( tag_use_cases = TagUseCases(self.gateway) if tag_id is None: # tag_name is guaranteed to be not None here due to validation above - resolved_tag_id = tag_use_cases.get_tag_ids_from_labels(labels=[tag_name])[0] # type: ignore[list-item] + resolved_tag_id = tag_use_cases.get_tag_ids_from_labels(labels=[cast(str, tag_name)])[0] else: resolved_tag_id = TagId(tag_id) - return { + result = { "id": str( tag_use_cases.update_tag( tag_id=resolved_tag_id, new_tag_name=new_name ).updated_tag_id ) } + return IdResponse(result) @typechecked def delete( @@ -212,7 +219,7 @@ def delete( tag_use_cases = TagUseCases(self.gateway) if tag_id is None: # tag_name is guaranteed to be not None here due to validation above - resolved_tag_id = tag_use_cases.get_tag_ids_from_labels(labels=[tag_name])[0] # type: ignore[list-item] + resolved_tag_id 
= tag_use_cases.get_tag_ids_from_labels(labels=[cast(str, tag_name)])[0] else: resolved_tag_id = TagId(tag_id) @@ -225,7 +232,7 @@ def assign( tags: Optional[ListOrTuple[str]] = None, tag_ids: Optional[ListOrTuple[str]] = None, disable_tqdm: Optional[bool] = None, - ) -> List[Dict[Literal["id"], str]]: + ) -> IdListResponse: """Assign tags to a project. This method replaces the legacy tag_project method with a more intuitive name. @@ -238,7 +245,7 @@ def assign( disable_tqdm: Whether to disable the progress bar. Returns: - List of dictionaries with the assigned tag IDs. + An IdListResponse with the assigned tag IDs. Raises: ValueError: If neither tags nor tag_ids is provided. @@ -249,6 +256,7 @@ def assign( ... project_id="my_project", ... tags=["important", "reviewed"] ... ) + >>> print(result.ids) >>> # Assign tags by ID >>> result = kili.tags.assign( @@ -263,7 +271,9 @@ def assign( if tag_ids is None: # tags is guaranteed to be not None here due to validation above - resolved_tag_ids = tag_use_cases.get_tag_ids_from_labels(labels=tags) # type: ignore[arg-type] + resolved_tag_ids = tag_use_cases.get_tag_ids_from_labels( + labels=cast(ListOrTuple[str], tags) + ) else: resolved_tag_ids = [TagId(tag_id) for tag_id in tag_ids] @@ -273,7 +283,8 @@ def assign( disable_tqdm=disable_tqdm, ) - return [{"id": str(tag_id)} for tag_id in assigned_tag_ids] + results = [{"id": str(tag_id)} for tag_id in assigned_tag_ids] + return IdListResponse(results) @typechecked def unassign( @@ -283,7 +294,7 @@ def unassign( tag_ids: Optional[ListOrTuple[str]] = None, all: Optional[bool] = None, # pylint: disable=redefined-builtin disable_tqdm: Optional[bool] = None, - ) -> List[Dict[Literal["id"], str]]: + ) -> IdListResponse: """Remove tags from a project. This method replaces the legacy untag_project method with a more intuitive name. @@ -296,7 +307,7 @@ def unassign( disable_tqdm: Whether to disable the progress bar. Returns: - List of dictionaries with the unassigned tag IDs. 
+ An IdListResponse with the unassigned tag IDs. Raises: ValueError: If exactly one of tags, tag_ids, or all must be provided. @@ -307,6 +318,7 @@ def unassign( ... project_id="my_project", ... tags=["old_tag", "obsolete"] ... ) + >>> print(result.ids) >>> # Remove specific tags by ID >>> result = kili.tags.unassign( @@ -346,4 +358,5 @@ def unassign( disable_tqdm=disable_tqdm, ) - return [{"id": str(tag_id)} for tag_id in unassigned_tag_ids] + results = [{"id": str(tag_id)} for tag_id in unassigned_tag_ids] + return IdListResponse(results) diff --git a/src/kili/domain_api/users.py b/src/kili/domain_api/users.py index ae56d30be..0e0f200c6 100644 --- a/src/kili/domain_api/users.py +++ b/src/kili/domain_api/users.py @@ -1,13 +1,15 @@ """Users domain namespace for the Kili Python SDK.""" import re -from typing import Dict, Generator, Iterable, List, Literal, Optional, overload +from typing import Generator, Iterable, List, Literal, Optional, overload from typeguard import typechecked from kili.core.enums import OrganizationRole from kili.domain.types import ListOrTuple from kili.domain_api.base import DomainNamespace +from kili.domain_v2.project import IdResponse +from kili.domain_v2.user import UserView, validate_user from kili.presentation.client.user import UserClientMethods @@ -76,7 +78,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[True], - ) -> Generator[Dict, None, None]: + ) -> Generator[UserView, None, None]: ... @overload @@ -91,7 +93,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: Literal[False] = False, - ) -> List[Dict]: + ) -> List[UserView]: ... @typechecked @@ -106,7 +108,7 @@ def list( disable_tqdm: Optional[bool] = None, *, as_generator: bool = False, - ) -> Iterable[Dict]: + ) -> Iterable[UserView]: """Get a generator or a list of users given a set of criteria. Args: @@ -136,7 +138,7 @@ def list( ... 
) """ # Access the legacy method directly by calling it from the mixin class - return UserClientMethods.users( + result = UserClientMethods.users( self.client, api_key=api_key, email=email, @@ -148,6 +150,20 @@ def list( as_generator=as_generator, # pyright: ignore[reportGeneralTypeIssues] ) + # Wrap results with UserView + if as_generator: + # Create intermediate generator - iter() makes result explicitly iterable + def _wrap_generator() -> Generator[UserView, None, None]: + result_iter = iter(result) + for item in result_iter: + yield UserView(validate_user(item)) + + return _wrap_generator() + + # Convert to list - list() makes result explicitly iterable + result_list = list(result) + return [UserView(validate_user(item)) for item in result_list] + @typechecked def count( self, @@ -188,7 +204,7 @@ def create( organization_role: OrganizationRole, firstname: Optional[str] = None, lastname: Optional[str] = None, - ) -> Dict[Literal["id"], str]: + ) -> IdResponse: """Add a user to your organization. Args: @@ -199,7 +215,7 @@ def create( lastname: Last name of the new user. Returns: - A dictionary with the id of the new user. + An IdResponse with the id of the new user. Raises: ValueError: If email format is invalid or password is weak. @@ -213,6 +229,7 @@ def create( ... firstname="John", ... lastname="Doe" ... ) + >>> print(result.id) >>> # Create a regular user >>> result = kili.users.create( @@ -220,6 +237,7 @@ def create( ... password="userpassword123", ... organization_role=OrganizationRole.USER ... 
) + >>> print(result.id) """ # Validate email format if not self._is_valid_email(email): @@ -231,7 +249,7 @@ def create( "Password must be at least 8 characters long and contain at least one letter and one number" ) - return UserClientMethods.create_user( + result = UserClientMethods.create_user( self.client, email=email, password=password, @@ -239,6 +257,7 @@ def create( firstname=firstname, lastname=lastname, ) + return IdResponse(result) @typechecked def update( @@ -249,7 +268,7 @@ def update( organization_id: Optional[str] = None, organization_role: Optional[OrganizationRole] = None, activated: Optional[bool] = None, - ) -> Dict[Literal["id"], str]: + ) -> IdResponse: """Update the properties of a user. Args: @@ -262,7 +281,7 @@ def update( activated: In case we want to deactivate a user, but keep it. Returns: - A dict with the user id. + An IdResponse with the user id. Raises: ValueError: If email format is invalid. @@ -274,24 +293,27 @@ def update( ... firstname="UpdatedFirstName", ... lastname="UpdatedLastName" ... ) + >>> print(result.id) >>> # Change user role >>> result = kili.users.update( ... email="user@example.com", ... organization_role=OrganizationRole.ADMIN ... ) + >>> print(result.id) >>> # Deactivate user >>> result = kili.users.update( ... email="user@example.com", ... activated=False ... ) + >>> print(result.id) """ # Validate email format if not self._is_valid_email(email): raise ValueError(f"Invalid email format: {email}") - return UserClientMethods.update_properties_in_user( + result = UserClientMethods.update_properties_in_user( self.client, email=email, firstname=firstname, @@ -300,11 +322,12 @@ def update( organization_role=organization_role, activated=activated, ) + return IdResponse(result) @typechecked def update_password( self, email: str, old_password: str, new_password_1: str, new_password_2: str - ) -> Dict[Literal["id"], str]: + ) -> IdResponse: """Allow to modify the password that you use to connect to Kili. 
This resolver only works for on-premise installations without Auth0. @@ -317,7 +340,7 @@ def update_password( new_password_2: A confirmation field for the new password Returns: - A dict with the user id. + An IdResponse with the user id. Raises: ValueError: If validation fails for email, password confirmation, @@ -333,18 +356,20 @@ def update_password( ... new_password_1="newpassword456", ... new_password_2="newpassword456" ... ) + >>> print(result.id) """ # Enhanced security validation self._validate_password_update_request(email, old_password, new_password_1, new_password_2) try: - return UserClientMethods.update_password( + result = UserClientMethods.update_password( self.client, email=email, old_password=old_password, new_password_1=new_password_1, new_password_2=new_password_2, ) + return IdResponse(result) except Exception as e: # Enhanced error handling for authentication failures if "authentication" in str(e).lower() or "password" in str(e).lower(): diff --git a/src/kili/domain_v2/__init__.py b/src/kili/domain_v2/__init__.py new file mode 100644 index 000000000..6c1e8f17f --- /dev/null +++ b/src/kili/domain_v2/__init__.py @@ -0,0 +1,141 @@ +"""Domain v2: TypedDict-based domain contracts for Kili SDK. + +This module provides TypedDict-based domain contracts as a lightweight alternative +to dataclasses. These contracts enable better type safety while maintaining the +flexibility of dictionaries. 
+""" + +from .adapters import ContractValidator, DataFrameAdapter +from .asset import ( + AssetContract, + AssetCreateResponse, + AssetCreateResponseContract, + AssetView, + WorkflowStepResponse, + WorkflowStepResponseContract, + validate_asset, + validate_asset_create_response, + validate_workflow_step_response, +) +from .connection import ConnectionContract, ConnectionView, validate_connection +from .integration import IntegrationContract, IntegrationView, validate_integration +from .issue import IssueContract, IssueView, validate_issue +from .label import ( + LabelContract, + LabelExportResponse, + LabelView, + filter_labels_by_type, + sort_labels_by_created_at, + validate_label, +) +from .notification import NotificationContract, NotificationView, validate_notification +from .organization import ( + OrganizationContract, + OrganizationMetricsContract, + OrganizationMetricsView, + OrganizationView, + validate_organization, + validate_organization_metrics, +) +from .plugin import PluginContract, PluginView, validate_plugin +from .project import ( + IdListResponse, + IdResponse, + ProjectContract, + ProjectRoleView, + ProjectVersionContract, + ProjectVersionView, + ProjectView, + StatusResponse, + WorkflowStepView, + get_ordered_steps, + get_step_by_name, + validate_project, + validate_project_role, + validate_project_version, + validate_workflow_step, +) +from .tag import TagContract, TagView, validate_tag +from .user import ( + UserContract, + UserView, + filter_users_by_activated, + sort_users_by_email, + validate_user, +) + +__all__ = [ + # Asset + "AssetContract", + "AssetCreateResponse", + "AssetCreateResponseContract", + "AssetView", + "WorkflowStepResponse", + "WorkflowStepResponseContract", + "validate_asset", + "validate_asset_create_response", + "validate_workflow_step_response", + # Connection + "ConnectionContract", + "ConnectionView", + "validate_connection", + # Integration + "IntegrationContract", + "IntegrationView", + "validate_integration", + # 
Issue + "IssueContract", + "IssueView", + "validate_issue", + # Label + "LabelContract", + "LabelExportResponse", + "LabelView", + "filter_labels_by_type", + "sort_labels_by_created_at", + "validate_label", + # Notification + "NotificationContract", + "NotificationView", + "validate_notification", + # Organization + "OrganizationContract", + "OrganizationMetricsContract", + "OrganizationMetricsView", + "OrganizationView", + "validate_organization", + "validate_organization_metrics", + # Plugin + "PluginContract", + "PluginView", + "validate_plugin", + # Project + "IdListResponse", + "IdResponse", + "ProjectContract", + "ProjectRoleView", + "ProjectVersionContract", + "ProjectVersionView", + "ProjectView", + "StatusResponse", + "WorkflowStepView", + "get_ordered_steps", + "get_step_by_name", + "validate_project", + "validate_project_role", + "validate_project_version", + "validate_workflow_step", + # Tag + "TagContract", + "TagView", + "validate_tag", + # User + "UserContract", + "UserView", + "validate_user", + "filter_users_by_activated", + "sort_users_by_email", + # Adapters + "DataFrameAdapter", + "ContractValidator", +] diff --git a/src/kili/domain_v2/adapters.py b/src/kili/domain_v2/adapters.py new file mode 100644 index 000000000..24aa33141 --- /dev/null +++ b/src/kili/domain_v2/adapters.py @@ -0,0 +1,242 @@ +"""DataFrame adapter utilities for domain contracts. + +This module provides utilities to convert between domain contracts and +pandas DataFrames without mutating the original payloads. 
+""" + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar, Union + +try: + import pandas as pd + + PANDAS_AVAILABLE: bool = True +except ImportError: + PANDAS_AVAILABLE = False + if not TYPE_CHECKING: + pd = None # type: ignore[assignment] + +from .asset import AssetContract, AssetView, validate_asset +from .label import LabelContract, LabelView, validate_label +from .project import ProjectContract, ProjectView, validate_project +from .user import UserContract, UserView, validate_user + +T = TypeVar("T", AssetContract, LabelContract, ProjectContract, UserContract) +V = TypeVar("V", AssetView, LabelView, ProjectView, UserView) + + +class DataFrameAdapter: + """Adapter for converting domain contracts to/from DataFrames. + + This adapter provides methods to convert validated domain contracts + into pandas DataFrames and vice versa, without mutating the original data. + + Example: + >>> assets = [{"id": "1", "externalId": "asset-1"}, ...] + >>> adapter = DataFrameAdapter() + >>> df = adapter.to_dataframe(assets, AssetContract) + >>> contracts = adapter.from_dataframe(df, AssetContract) + """ + + def __init__(self) -> None: + """Initialize the DataFrame adapter.""" + if not PANDAS_AVAILABLE: + raise ImportError( + "pandas is required for DataFrame adapters. Install it with: pip install pandas" + ) + + @staticmethod + def to_dataframe( + contracts: List[T], + contract_type: Optional[Type[T]] = None, + validate: bool = True, + ) -> "pd.DataFrame": + """Convert a list of domain contracts to a DataFrame.
+ + Args: + contracts: List of domain contracts + contract_type: Type of contract for validation (optional) + validate: If True, validate each contract before conversion + + Returns: + pandas DataFrame with contract data + + Raises: + TypeError: If validation fails + ImportError: If pandas is not available + """ + if not PANDAS_AVAILABLE: + raise ImportError("pandas is required for DataFrame conversion") + + if not contracts: + if not PANDAS_AVAILABLE: + raise ImportError("pandas is required") + assert pd is not None # For type checker + return pd.DataFrame() + + # Validate contracts if requested + if validate and contract_type: + validators = { + AssetContract: validate_asset, + LabelContract: validate_label, + ProjectContract: validate_project, + UserContract: validate_user, + } + validator = validators.get(contract_type) + if validator: + contracts = [validator(c) for c in contracts] # type: ignore + + # Create DataFrame from contracts + # This creates a copy, not mutating original data + if not PANDAS_AVAILABLE: + raise ImportError("pandas is required") + assert pd is not None # For type checker + return pd.DataFrame(contracts) + + @staticmethod + def from_dataframe( + df: "pd.DataFrame", + contract_type: Type[T], + validate: bool = True, + ) -> List[T]: + """Convert a DataFrame to a list of domain contracts. 
+ + Args: + df: pandas DataFrame + contract_type: Type of contract to convert to + validate: If True, validate each contract after conversion + + Returns: + List of domain contracts + + Raises: + TypeError: If validation fails + ImportError: If pandas is not available + """ + if not PANDAS_AVAILABLE: + raise ImportError("pandas is required for DataFrame conversion") + + # Convert DataFrame to list of dictionaries + contracts = df.to_dict("records") + + # Validate contracts if requested + if validate: + validators = { + AssetContract: validate_asset, + LabelContract: validate_label, + ProjectContract: validate_project, + UserContract: validate_user, + } + validator = validators.get(contract_type) + if validator: + contracts = [validator(c) for c in contracts] # type: ignore + + return contracts # type: ignore + + @staticmethod + def wrap_contracts( + contracts: List[T], + view_type: Type[V], + ) -> List[V]: + """Wrap a list of contracts in view wrappers. + + Args: + contracts: List of domain contracts + view_type: Type of view wrapper + + Returns: + List of view wrappers + + Example: + >>> assets = [{"id": "1", ...}, {"id": "2", ...}] + >>> views = DataFrameAdapter.wrap_contracts(assets, AssetView) + >>> print(views[0].display_name) + """ + return [view_type(contract) for contract in contracts] # type: ignore + + @staticmethod + def unwrap_views(views: List[V]) -> List[Dict[str, Any]]: + """Unwrap view wrappers back to dictionaries. + + Args: + views: List of view wrappers + + Returns: + List of dictionaries + """ + return [view.to_dict() for view in views] # type: ignore[misc] + + +class ContractValidator: + """Validator for domain contracts. + + This class provides validation utilities for domain contracts, + including batch validation and error reporting. 
+ """ + + @staticmethod + def validate_batch( + contracts: List[Dict[str, Any]], + contract_type: Type[T], + ) -> Tuple[List[T], List[Tuple[int, Exception]]]: + """Validate a batch of contracts and collect errors. + + Args: + contracts: List of dictionaries to validate + contract_type: Type of contract to validate against + + Returns: + Tuple of (valid_contracts, errors) + where errors is a list of (index, exception) tuples + """ + validators = { + AssetContract: validate_asset, + LabelContract: validate_label, + ProjectContract: validate_project, + UserContract: validate_user, + } + + validator = validators.get(contract_type) + if not validator: + raise ValueError(f"Unknown contract type: {contract_type}") + + valid_contracts: List[T] = [] + errors: List[Tuple[int, Exception]] = [] + + for i, contract in enumerate(contracts): + try: + validated = validator(contract) # type: ignore + valid_contracts.append(validated) # type: ignore + except Exception as e: # pylint: disable=broad-except + errors.append((i, e)) + + return valid_contracts, errors + + @staticmethod + def validate_single( + contract: Dict[str, Any], + contract_type: Type[T], + ) -> Union[T, Exception]: + """Validate a single contract.
+ + Args: + contract: Dictionary to validate + contract_type: Type of contract to validate against + + Returns: + Validated contract or Exception if validation fails + """ + validators = { + AssetContract: validate_asset, + LabelContract: validate_label, + ProjectContract: validate_project, + UserContract: validate_user, + } + + validator = validators.get(contract_type) + if not validator: + raise ValueError(f"Unknown contract type: {contract_type}") + + try: + return validator(contract) # type: ignore + except Exception as e: # pylint: disable=broad-except + return e diff --git a/src/kili/domain_v2/asset.py b/src/kili/domain_v2/asset.py new file mode 100644 index 000000000..926db3826 --- /dev/null +++ b/src/kili/domain_v2/asset.py @@ -0,0 +1,305 @@ +"""Asset domain contract using TypedDict. + +This module provides a TypedDict-based contract for Asset entities, +along with validation utilities and helper functions. +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast + +from typeguard import check_type + +# Asset status types from domain/asset/asset.py +AssetStatus = Literal["TODO", "ONGOING", "LABELED", "REVIEWED", "TO_REVIEW"] +StatusInStep = Literal["TO_DO", "DOING", "PARTIALLY_DONE", "REDO", "DONE", "SKIPPED"] + + +class CurrentStepContract(TypedDict, total=False): + """Current workflow step information for an asset.""" + + name: str + status: StatusInStep + + +class LabelAuthorContract(TypedDict, total=False): + """Label author information.""" + + id: str + email: str + + +class AssetLabelContract(TypedDict, total=False): + """Embedded label information within an asset.""" + + id: str + author: LabelAuthorContract + createdAt: str + jsonResponse: Dict[str, Any] + + +class AssetContract(TypedDict, total=False): + """TypedDict contract for Asset entities. + + This contract represents the core structure of an Asset as returned + from the Kili API. All fields are optional to allow partial data. 
@dataclass(frozen=True)
class AssetView:
    """Immutable, attribute-style facade over an ``AssetContract`` dictionary.

    The dictionary handed to the constructor is stored as-is in ``_data``
    (it is not copied). Each property is a key lookup with a fallback for
    absent keys, so partially populated payloads can be read without
    ``KeyError``.

    Example:
        >>> view = AssetView({"id": "123", "externalId": "asset-1"})
        >>> view.id
        '123'
        >>> view.display_name
        'asset-1'
    """

    __slots__ = ("_data",)

    _data: AssetContract

    def _field(self, key: str, fallback: Any = None) -> Any:
        """Return the value stored under ``key``, or ``fallback`` if absent."""
        return self._data.get(key, fallback)

    @property
    def id(self) -> str:
        """Asset ID; empty string when the key is absent."""
        return self._field("id", "")

    @property
    def external_id(self) -> str:
        """External ID; empty string when the key is absent."""
        return self._field("externalId", "")

    @property
    def content(self) -> str:
        """Asset content; empty string when the key is absent."""
        return self._field("content", "")

    @property
    def metadata(self) -> Optional[Dict[str, Any]]:
        """Value of ``jsonMetadata``, or ``None`` when the key is absent."""
        return self._field("jsonMetadata")

    @property
    def labels(self) -> List[AssetLabelContract]:
        """Value of ``labels``; a fresh empty list when the key is absent."""
        return self._field("labels", [])

    @property
    def latest_label(self) -> Optional[AssetLabelContract]:
        """Value of ``latestLabel``, or ``None`` when the key is absent."""
        return self._field("latestLabel")

    @property
    def status(self) -> Optional[AssetStatus]:
        """Asset status (workflow v1), or ``None`` when the key is absent."""
        return self._field("status")

    @property
    def current_step(self) -> Optional[CurrentStepContract]:
        """Current workflow step (workflow v2), or ``None`` when absent."""
        return self._field("currentStep")

    @property
    def is_honeypot(self) -> bool:
        """Value of ``isHoneypot``; ``False`` when the key is absent."""
        return self._field("isHoneypot", False)

    @property
    def skipped(self) -> bool:
        """Value of ``skipped``; ``False`` when the key is absent."""
        return self._field("skipped", False)

    @property
    def created_at(self) -> Optional[str]:
        """Value of ``createdAt``, or ``None`` when the key is absent."""
        return self._field("createdAt")

    @property
    def display_name(self) -> str:
        """Human-readable name: the external ID if truthy, else the ID."""
        external = self.external_id
        return external if external else self.id

    @property
    def has_labels(self) -> bool:
        """Whether at least one label is present."""
        return self.label_count > 0

    @property
    def label_count(self) -> int:
        """Number of entries in ``labels``."""
        attached = self.labels
        return len(attached)

    def to_dict(self) -> AssetContract:
        """Return the wrapped dictionary itself (same object, not a copy)."""
        return self._data
@dataclass(frozen=True)
class AssetCreateResponse:
    """Typed wrapper around the dictionary returned by asset creation.

    The wrapped dictionary is stored as-is in ``_data``; the properties read
    its ``id`` and ``asset_ids`` keys with fallbacks for absent keys.

    Example:
        >>> response = AssetCreateResponse({"id": "project_123", "asset_ids": ["asset_1", "asset_2"]})
        >>> response.id
        'project_123'
        >>> response.asset_ids
        ['asset_1', 'asset_2']
    """

    __slots__ = ("_data",)

    _data: Union[AssetCreateResponseContract, Dict[str, Any]]

    @property
    def id(self) -> str:
        """Project ID coerced with ``str()``; empty string when absent."""
        raw_id = self._data.get("id", "")
        return str(raw_id)

    @property
    def asset_ids(self) -> List[str]:
        """Value of ``asset_ids``; a fresh empty list when the key is absent."""
        payload = self._data
        return payload.get("asset_ids", [])

    def to_dict(self) -> AssetCreateResponseContract:
        """Return the wrapped dictionary itself (same object, not a copy)."""
        payload = self._data
        # cast() is a typing-only no-op: the object is returned unchanged.
        return cast(AssetCreateResponseContract, payload)
@dataclass(frozen=True)
class ConnectionView:
    """Immutable, attribute-style facade over a ``ConnectionContract`` dictionary.

    The dictionary handed to the constructor is stored as-is in ``_data``
    (it is not copied). Each property is a key lookup with a fallback for
    absent keys.

    Example:
        >>> view = ConnectionView({"id": "123", "projectId": "proj_456", "numberOfAssets": 100})
        >>> view.id
        '123'
        >>> view.number_of_assets
        100
    """

    __slots__ = ("_data",)

    _data: ConnectionContract

    def _field(self, key: str, fallback: Any = None) -> Any:
        """Return the value stored under ``key``, or ``fallback`` if absent."""
        return self._data.get(key, fallback)

    @property
    def id(self) -> str:
        """Connection ID; empty string when the key is absent."""
        return self._field("id", "")

    @property
    def project_id(self) -> str:
        """Value of ``projectId``; empty string when the key is absent."""
        return self._field("projectId", "")

    @property
    def last_checked(self) -> Optional[str]:
        """Value of ``lastChecked``, or ``None`` when the key is absent."""
        return self._field("lastChecked")

    @property
    def number_of_assets(self) -> int:
        """Value of ``numberOfAssets``; ``0`` when the key is absent."""
        return self._field("numberOfAssets", 0)

    @property
    def selected_folders(self) -> List[str]:
        """Value of ``selectedFolders``; a fresh empty list when absent."""
        return self._field("selectedFolders", [])

    @property
    def folder_count(self) -> int:
        """Number of entries in ``selected_folders``."""
        folders = self.selected_folders
        return len(folders)

    @property
    def display_name(self) -> str:
        """Human-readable name for the connection: its ID."""
        return self.id

    def to_dict(self) -> ConnectionContract:
        """Return the wrapped dictionary itself (same object, not a copy)."""
        return self._data
@dataclass(frozen=True)
class IntegrationView:
    """Immutable, attribute-style facade over an ``IntegrationContract`` dictionary.

    The dictionary handed to the constructor is stored as-is in ``_data``
    (it is not copied). Plain properties are key lookups with fallbacks; the
    ``is_*`` / ``has_error`` predicates compare ``status`` with a literal.

    Example:
        >>> view = IntegrationView({"id": "123", "name": "My S3 Bucket", "status": "CONNECTED"})
        >>> view.id
        '123'
        >>> view.is_connected
        True
    """

    __slots__ = ("_data",)

    _data: IntegrationContract

    def _field(self, key: str, fallback: Any = None) -> Any:
        """Return the value stored under ``key``, or ``fallback`` if absent."""
        return self._data.get(key, fallback)

    def _status_is(self, expected: str) -> bool:
        """Tell whether ``status`` equals ``expected`` (False when absent)."""
        return self.status == expected

    @property
    def id(self) -> str:
        """Integration ID; empty string when the key is absent."""
        return self._field("id", "")

    @property
    def name(self) -> str:
        """Integration name; empty string when the key is absent."""
        return self._field("name", "")

    @property
    def platform(self) -> Optional[DataIntegrationPlatform]:
        """Value of ``platform``, or ``None`` when the key is absent."""
        return self._field("platform")

    @property
    def status(self) -> Optional[DataIntegrationStatus]:
        """Value of ``status``, or ``None`` when the key is absent."""
        return self._field("status")

    @property
    def organization_id(self) -> str:
        """Value of ``organizationId``; empty string when the key is absent."""
        return self._field("organizationId", "")

    @property
    def is_connected(self) -> bool:
        """Whether ``status`` is ``"CONNECTED"``."""
        return self._status_is("CONNECTED")

    @property
    def is_checking(self) -> bool:
        """Whether ``status`` is ``"CHECKING"``."""
        return self._status_is("CHECKING")

    @property
    def has_error(self) -> bool:
        """Whether ``status`` is ``"ERROR"``."""
        return self._status_is("ERROR")

    @property
    def is_active(self) -> bool:
        """Alias for ``is_connected``."""
        return self.is_connected

    @property
    def display_name(self) -> str:
        """Human-readable name: the integration name if truthy, else the ID."""
        label = self.name
        return label if label else self.id

    def to_dict(self) -> IntegrationContract:
        """Return the wrapped dictionary itself (same object, not a copy)."""
        return self._data
@dataclass(frozen=True)
class IssueView:
    """Immutable, attribute-style facade over an ``IssueContract`` dictionary.

    The dictionary handed to the constructor is stored as-is in ``_data``
    (it is not copied). Plain properties are key lookups with fallbacks; the
    ``is_*`` predicates compare ``status`` or ``type`` with a literal.

    Example:
        >>> view = IssueView({"id": "123", "type": "ISSUE", "status": "OPEN"})
        >>> view.id
        '123'
        >>> view.is_open
        True
    """

    __slots__ = ("_data",)

    _data: IssueContract

    def _field(self, key: str, fallback: Any = None) -> Any:
        """Return the value stored under ``key``, or ``fallback`` if absent."""
        return self._data.get(key, fallback)

    def _status_is(self, expected: str) -> bool:
        """Tell whether ``status`` equals ``expected`` (False when absent)."""
        return self.status == expected

    @property
    def id(self) -> str:
        """Issue ID; empty string when the key is absent."""
        return self._field("id", "")

    @property
    def created_at(self) -> Optional[str]:
        """Value of ``createdAt``, or ``None`` when the key is absent."""
        return self._field("createdAt")

    @property
    def status(self) -> Optional[IssueStatus]:
        """Value of ``status``, or ``None`` when the key is absent."""
        return self._field("status")

    @property
    def type(self) -> Optional[IssueType]:
        """Value of ``type``, or ``None`` when the key is absent."""
        return self._field("type")

    @property
    def asset_id(self) -> str:
        """Value of ``assetId``; empty string when the key is absent."""
        return self._field("assetId", "")

    @property
    def has_been_seen(self) -> bool:
        """Value of ``hasBeenSeen``; ``False`` when the key is absent."""
        return self._field("hasBeenSeen", False)

    @property
    def is_open(self) -> bool:
        """Whether ``status`` is ``"OPEN"``."""
        return self._status_is("OPEN")

    @property
    def is_solved(self) -> bool:
        """Whether ``status`` is ``"SOLVED"``."""
        return self._status_is("SOLVED")

    @property
    def is_cancelled(self) -> bool:
        """Whether ``status`` is ``"CANCELLED"``."""
        return self._status_is("CANCELLED")

    @property
    def is_question(self) -> bool:
        """Whether ``type`` is ``"QUESTION"``."""
        kind = self.type
        return kind == "QUESTION"

    @property
    def display_name(self) -> str:
        """Human-readable name for the issue: its ID."""
        return self.id

    def to_dict(self) -> IssueContract:
        """Return the wrapped dictionary itself (same object, not a copy)."""
        return self._data
def validate_label(data: Dict[str, Any]) -> LabelContract:
    """Validate and return a label contract.

    Runs ``check_type`` on ``data`` against ``LabelContract`` and, when the
    check passes, returns the same dictionary object (no copy or conversion).

    Args:
        data: Dictionary to validate as a LabelContract

    Returns:
        The validated data as a LabelContract

    Raises:
        typeguard.TypeCheckError: If the data does not match the LabelContract
            structure. NOTE(review): the two-argument
            ``check_type(value, expected_type)`` call is the typeguard >= 3
            signature, where mismatches raise ``TypeCheckError`` (not a
            ``TypeError`` subclass) -- confirm against the pinned typeguard
            version before relying on ``except TypeError``.
    """
    check_type(data, LabelContract)
    # The object is returned untouched; the ignore only narrows the static type.
    return data  # type: ignore[return-value]
def sort_labels_by_created_at(
    labels: List[LabelContract], reverse: bool = False
) -> List[LabelContract]:
    """Return a new list of ``labels`` ordered by their ``createdAt`` value.

    The input list is left untouched. Labels without a ``createdAt`` key are
    compared as if their timestamp were the empty string.

    Args:
        labels: Label contracts to order.
        reverse: When True, order descending (newest first).

    Returns:
        A new list holding the same label objects in sorted order.
    """

    def _created_at(label: LabelContract) -> str:
        """Sort key: the ``createdAt`` value, or ``""`` when absent."""
        return label.get("createdAt", "")

    ordered = list(labels)
    ordered.sort(key=_created_at, reverse=reverse)
    return ordered
+ + This wraps the export metadata returned when exporting labels without + saving to a file. The structure contains export information about + the processed data. + + Example: + >>> export_result = kili.labels.export(project_id="proj_123", filename=None, fmt="kili") + >>> if export_result: + ... for item in export_result.export_info: + ... print(item) + """ + + __slots__ = ("_data",) + + _data: List[Dict[str, Union[List[str], str]]] + + @property + def export_info(self) -> List[Dict[str, Union[List[str], str]]]: + """Get export information.""" + return self._data + + def to_list(self) -> List[Dict[str, Union[List[str], str]]]: + """Get the underlying list representation.""" + return self._data diff --git a/src/kili/domain_v2/notification.py b/src/kili/domain_v2/notification.py new file mode 100644 index 000000000..0ae4d1a8f --- /dev/null +++ b/src/kili/domain_v2/notification.py @@ -0,0 +1,126 @@ +"""Notification domain contract using TypedDict. + +This module provides a TypedDict-based contract for Notification entities, +along with validation utilities and helper functions. +""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional, TypedDict + +from typeguard import check_type + + +class NotificationContract(TypedDict, total=False): + """TypedDict contract for Notification entities. + + This contract represents the core structure of a Notification as returned + from the Kili API. All fields are optional to allow partial data. 
@dataclass(frozen=True)
class NotificationView:
    """Immutable, attribute-style facade over a ``NotificationContract`` dictionary.

    The dictionary handed to the constructor is stored as-is in ``_data``
    (it is not copied). Each property is a key lookup with a fallback for
    absent keys.

    Example:
        >>> view = NotificationView({"id": "123", "message": "Task completed"})
        >>> view.id
        '123'
        >>> view.has_been_seen
        False
    """

    __slots__ = ("_data",)

    _data: NotificationContract

    def _field(self, key: str, fallback: Any = None) -> Any:
        """Return the value stored under ``key``, or ``fallback`` if absent."""
        return self._data.get(key, fallback)

    @property
    def id(self) -> str:
        """Notification ID; empty string when the key is absent."""
        return self._field("id", "")

    @property
    def created_at(self) -> Optional[str]:
        """Value of ``createdAt``, or ``None`` when the key is absent."""
        return self._field("createdAt")

    @property
    def message(self) -> str:
        """Notification message; empty string when the key is absent."""
        return self._field("message", "")

    @property
    def status(self) -> str:
        """Value of ``status``; empty string when the key is absent."""
        return self._field("status", "")

    @property
    def user_id(self) -> str:
        """Value of ``userID``; empty string when the key is absent."""
        return self._field("userID", "")

    @property
    def url(self) -> Optional[str]:
        """Value of ``url``, or ``None`` when the key is absent."""
        return self._field("url")

    @property
    def has_been_seen(self) -> bool:
        """Value of ``hasBeenSeen``; ``False`` when the key is absent."""
        return self._field("hasBeenSeen", False)

    @property
    def is_unread(self) -> bool:
        """Negation of ``has_been_seen``."""
        seen = self.has_been_seen
        return not seen

    @property
    def display_name(self) -> str:
        """Human-readable name derived from the message.

        Messages of at most 50 characters are returned whole (falling back
        to the ID when the message is empty); longer ones are cut to their
        first 47 characters followed by ``"..."``.
        """
        text = self.message
        if len(text) <= 50:
            return text if text else self.id
        return text[:47] + "..."

    def to_dict(self) -> NotificationContract:
        """Return the wrapped dictionary itself (same object, not a copy)."""
        return self._data
@dataclass(frozen=True)
class OrganizationView:
    """Immutable, attribute-style facade over an ``OrganizationContract`` dictionary.

    The dictionary handed to the constructor is stored as-is in ``_data``
    (it is not copied). Each property is a key lookup with a fallback for
    absent keys.

    Example:
        >>> view = OrganizationView({"id": "123", "name": "Acme Corp"})
        >>> view.id
        '123'
        >>> view.display_name
        'Acme Corp'
    """

    __slots__ = ("_data",)

    _data: OrganizationContract

    def _field(self, key: str, fallback: Any = None) -> Any:
        """Return the value stored under ``key``, or ``fallback`` if absent."""
        return self._data.get(key, fallback)

    @property
    def id(self) -> str:
        """Organization ID; empty string when the key is absent."""
        return self._field("id", "")

    @property
    def name(self) -> str:
        """Organization name; empty string when the key is absent."""
        return self._field("name", "")

    @property
    def address(self) -> Optional[str]:
        """Value of ``address``, or ``None`` when the key is absent."""
        return self._field("address")

    @property
    def city(self) -> Optional[str]:
        """Value of ``city``, or ``None`` when the key is absent."""
        return self._field("city")

    @property
    def country(self) -> Optional[str]:
        """Value of ``country``, or ``None`` when the key is absent."""
        return self._field("country")

    @property
    def zip_code(self) -> Optional[str]:
        """Value of ``zipCode``, or ``None`` when the key is absent."""
        return self._field("zipCode")

    @property
    def number_of_annotations(self) -> int:
        """Value of ``numberOfAnnotations``; ``0`` when the key is absent."""
        return self._field("numberOfAnnotations", 0)

    @property
    def number_of_labeled_assets(self) -> int:
        """Value of ``numberOfLabeledAssets``; ``0`` when the key is absent."""
        return self._field("numberOfLabeledAssets", 0)

    @property
    def number_of_hours(self) -> float:
        """Value of ``numberOfHours``; ``0.0`` when the key is absent."""
        return self._field("numberOfHours", 0.0)

    @property
    def display_name(self) -> str:
        """Human-readable name: the organization name if truthy, else the ID."""
        label = self.name
        return label if label else self.id

    @property
    def full_address(self) -> str:
        """Address, city and country joined with ``", "``.

        Returns:
            The truthy parts in that order, comma-separated; an empty string
            when none of them is set. The ZIP code is not included.
        """
        candidates = (self.address, self.city, self.country)
        present = [piece for piece in candidates if piece]
        return ", ".join(present)

    def to_dict(self) -> OrganizationContract:
        """Return the wrapped dictionary itself (same object, not a copy)."""
        return self._data
+ + Example: + >>> metrics_data = {"numberOfAnnotations": 1000, "numberOfHours": 42.5, ...} + >>> view = OrganizationMetricsView(metrics_data) + >>> print(view.number_of_annotations) + 1000 + >>> print(view.number_of_hours) + 42.5 + """ + + __slots__ = ("_data",) + + _data: OrganizationMetricsContract + + @property + def number_of_annotations(self) -> int: + """Get total number of annotations.""" + return self._data.get("numberOfAnnotations", 0) + + @property + def number_of_hours(self) -> float: + """Get total hours spent on labeling.""" + return self._data.get("numberOfHours", 0.0) + + @property + def number_of_labeled_assets(self) -> int: + """Get total number of labeled assets.""" + return self._data.get("numberOfLabeledAssets", 0) + + def to_dict(self) -> OrganizationMetricsContract: + """Get the underlying dictionary representation.""" + return self._data diff --git a/src/kili/domain_v2/plugin.py b/src/kili/domain_v2/plugin.py new file mode 100644 index 000000000..157389782 --- /dev/null +++ b/src/kili/domain_v2/plugin.py @@ -0,0 +1,117 @@ +"""Plugin domain contract using TypedDict. + +This module provides a TypedDict-based contract for Plugin entities, +along with validation utilities and helper functions. +""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional, TypedDict + +from typeguard import check_type + + +class PluginContract(TypedDict, total=False): + """TypedDict contract for Plugin entities. + + This contract represents the core structure of a Plugin as returned + from the Kili API. All fields are optional to allow partial data. 
@dataclass(frozen=True)
class PluginView:
    """Immutable accessor over a plugin dictionary.

    Wraps a ``PluginContract`` and maps its camelCase keys onto snake_case
    properties. Missing string keys read as ``""``, missing flags as
    ``False`` and a missing creation timestamp as ``None``.

    Example:
        >>> view = PluginView({"id": "123", "name": "My Webhook", "isActivated": True})
        >>> view.id
        '123'
        >>> view.is_activated
        True
    """

    __slots__ = ("_data",)

    _data: PluginContract

    def to_dict(self) -> PluginContract:
        """Return the wrapped dictionary itself (no copy is made)."""
        return self._data

    @property
    def id(self) -> str:
        """Plugin ID; ``""`` when absent."""
        raw = self._data
        return raw.get("id", "")

    @property
    def name(self) -> str:
        """Plugin name; ``""`` when absent."""
        raw = self._data
        return raw.get("name", "")

    @property
    def display_name(self) -> str:
        """Human-readable label: the name, or the ID when the name is empty."""
        label = self.name
        return label if label else self.id

    @property
    def created_at(self) -> Optional[str]:
        """Creation timestamp, or ``None`` when absent."""
        raw = self._data
        return raw.get("createdAt")

    @property
    def plugin_type(self) -> str:
        """Plugin type; ``""`` when absent."""
        raw = self._data
        return raw.get("pluginType", "")

    @property
    def verbose(self) -> bool:
        """Whether verbose logging is enabled; ``False`` when absent."""
        raw = self._data
        return raw.get("verbose", False)

    @property
    def is_activated(self) -> bool:
        """Whether the plugin is activated; ``False`` when absent."""
        raw = self._data
        return raw.get("isActivated", False)

    @property
    def is_active(self) -> bool:
        """Alias for :attr:`is_activated`."""
        activated = self.is_activated
        return activated
@dataclass(frozen=True)
class ProjectView:
    """Read-only view wrapper for ProjectContract.

    Provides snake_case property access to a project dictionary while
    keeping the dictionary itself as the single source of truth. Missing
    keys fall back to neutral defaults (``""``, ``0``, ``False``, empty
    containers) or ``None`` for the optional fields.

    Example:
        >>> view = ProjectView({"id": "789", "title": "My Project"})
        >>> view.id
        '789'
        >>> view.display_name
        'My Project'
    """

    __slots__ = ("_data",)

    _data: ProjectContract

    @property
    def id(self) -> str:
        """Get project ID (``""`` when absent)."""
        return self._data.get("id", "")

    @property
    def title(self) -> str:
        """Get project title (``""`` when absent)."""
        return self._data.get("title", "")

    @property
    def description(self) -> str:
        """Get project description (``""`` when absent)."""
        return self._data.get("description", "")

    @property
    def input_type(self) -> Optional[InputType]:
        """Get input type, or ``None`` when absent."""
        return self._data.get("inputType")

    @property
    def json_interface(self) -> Dict[str, Any]:
        """Get JSON interface (ontology); empty dict when absent."""
        return self._data.get("jsonInterface", {})

    @property
    def workflow_version(self) -> Optional[WorkflowVersion]:
        """Get workflow version, or ``None`` when absent."""
        return self._data.get("workflowVersion")

    @property
    def steps(self) -> List[ProjectStepContract]:
        """Get workflow steps; empty list when absent."""
        return self._data.get("steps", [])

    @property
    def roles(self) -> List[ProjectRoleContract]:
        """Get project roles; empty list when absent."""
        return self._data.get("roles", [])

    @property
    def number_of_assets(self) -> int:
        """Get total number of assets (``0`` when absent)."""
        return self._data.get("numberOfAssets", 0)

    @property
    def number_of_remaining_assets(self) -> int:
        """Get number of remaining assets (``0`` when absent)."""
        return self._data.get("numberOfRemainingAssets", 0)

    @property
    def number_of_reviewed_assets(self) -> int:
        """Get number of reviewed assets (``0`` when absent)."""
        return self._data.get("numberOfReviewedAssets", 0)

    @property
    def created_at(self) -> Optional[str]:
        """Get creation timestamp, or ``None`` when absent."""
        return self._data.get("createdAt")

    @property
    def updated_at(self) -> Optional[str]:
        """Get update timestamp, or ``None`` when absent."""
        return self._data.get("updatedAt")

    @property
    def archived(self) -> bool:
        """Check if project is archived (``False`` when absent)."""
        return self._data.get("archived", False)

    @property
    def starred(self) -> bool:
        """Check if project is starred (``False`` when absent)."""
        return self._data.get("starred", False)

    @property
    def is_v2_workflow(self) -> bool:
        """Check if project uses workflow V2."""
        return self.workflow_version == "V2"

    @property
    def has_honeypot(self) -> bool:
        """Check if project uses honeypot assets (``False`` when absent)."""
        return self._data.get("useHoneypot", False)

    @property
    def display_name(self) -> str:
        """Get a human-readable display name for the project.

        Returns:
            The title, or the project ID when the title is empty or missing.
        """
        return self.title or self.id

    @property
    def progress_percentage(self) -> float:
        """Calculate project completion percentage.

        Completion is ``(total - remaining) / total``, using
        ``number_of_assets`` and ``number_of_remaining_assets``.

        Returns:
            Percentage of completed assets, always within 0-100. ``0.0`` is
            returned when the project has no assets.
        """
        total = self.number_of_assets
        # A zero (or nonsensical negative) total has no meaningful ratio.
        if total <= 0:
            return 0.0
        completed = total - self.number_of_remaining_assets
        # Clamp so inconsistent counters (remaining > total, or a negative
        # remaining count) cannot leak out of the documented 0-100 range.
        return min(100.0, max(0.0, (completed / total) * 100))

    def to_dict(self) -> ProjectContract:
        """Get the underlying dictionary representation (not a copy)."""
        return self._data
def validate_project_version(data: Dict[str, Any]) -> ProjectVersionContract:
    """Validate and return a project version contract.

    Runs typeguard's ``check_type`` against ``ProjectVersionContract`` and,
    when the check passes, returns the very same dictionary object (it is
    neither copied nor converted).

    Args:
        data: Dictionary to validate as a ProjectVersionContract

    Returns:
        The validated data as a ProjectVersionContract

    Raises:
        TypeError: If the data does not match the ProjectVersionContract structure.
            NOTE(review): the two-argument ``check_type(value, type)`` form
            used here is the typeguard >= 3 signature, where a mismatch is
            reported as ``typeguard.TypeCheckError`` rather than
            ``TypeError`` - confirm against the pinned typeguard version.
    """
    check_type(data, ProjectVersionContract)
    # The checker cannot narrow Dict[str, Any] to the TypedDict, hence the ignore.
    return data  # type: ignore[return-value]
@dataclass(frozen=True)
class IdListResponse:
    """Immutable wrapper around a list of ``{"id": ...}`` dictionaries.

    Intended for mutation results that hand back several IDs; it offers a
    typed alternative to passing ``List[Dict[Literal["id"], str]]`` around.

    Example:
        >>> response = IdListResponse([{"id": "item_1"}, {"id": "item_2"}])
        >>> response.ids
        ['item_1', 'item_2']
    """

    __slots__ = ("_data",)

    _data: Union[List[Dict[Literal["id"], str]], List[Dict[str, Any]]]

    def to_list(self) -> List[Dict[str, Any]]:
        """Return the wrapped list itself (no copy is made)."""
        wrapped = self._data
        return cast(List[Dict[str, Any]], wrapped)

    @property
    def ids(self) -> List[str]:
        """Collect every entry's ``id`` as a string, in list order.

        Entries without an ``id`` key contribute an empty string.
        """
        collected: List[str] = []
        for entry in self._data:
            identifier = entry.get("id", "")
            collected.append(str(identifier))
        return collected
@dataclass(frozen=True)
class ProjectRoleView:
    """Read-only view wrapper for ProjectRoleContract.

    Provides property access to a project-role dictionary, including
    shortcuts into its nested ``user`` object. Missing keys fall back to
    ``""``, an empty dict, or ``None`` for the role.

    Example:
        >>> role_data = {"id": "role_1", "role": "ADMIN", "user": {"id": "user_1", "email": "admin@example.com"}}
        >>> view = ProjectRoleView(role_data)
        >>> view.role
        'ADMIN'
        >>> view.user_email
        'admin@example.com'
        >>> view.display_name
        'admin@example.com (ADMIN)'
    """

    __slots__ = ("_data",)

    _data: ProjectRoleContract

    @property
    def id(self) -> str:
        """Get role ID (``""`` when absent)."""
        return self._data.get("id", "")

    @property
    def role(self) -> Optional[Literal["ADMIN", "TEAM_MANAGER", "LABELER", "REVIEWER"]]:
        """Get the role type, or ``None`` when absent."""
        return self._data.get("role")

    @property
    def user(self) -> Dict[str, Any]:
        """Get user information; empty dict when absent."""
        return self._data.get("user", {})

    @property
    def user_id(self) -> str:
        """Get user ID from nested user object (``""`` when absent)."""
        return self.user.get("id", "")

    @property
    def user_email(self) -> str:
        """Get user email from nested user object (``""`` when absent)."""
        return self.user.get("email", "")

    @property
    def display_name(self) -> str:
        """Get a human-readable display name.

        Returns:
            ``"<email> (<role>)"`` when both are known, just the email when
            the role is missing, and the role ID when there is no email.
        """
        email = self.user_email
        if not email:
            return self.id
        role = self.role
        # Every contract field is optional; avoid rendering "email (None)".
        return f"{email} ({role})" if role else email

    def to_dict(self) -> ProjectRoleContract:
        """Get the underlying dictionary representation (not a copy)."""
        return self._data
@dataclass(frozen=True)
class ProjectVersionView:
    """Immutable accessor over a project-version dictionary.

    Wraps a ``ProjectVersionContract`` and exposes its keys as snake_case
    properties. Missing string keys read as ``""``; ``content`` and the
    creation timestamp read as ``None``.

    Example:
        >>> view = ProjectVersionView({"id": "v1", "name": "Version 1.0", "projectId": "proj_1"})
        >>> view.name
        'Version 1.0'
        >>> view.project_id
        'proj_1'
    """

    __slots__ = ("_data",)

    _data: ProjectVersionContract

    def to_dict(self) -> ProjectVersionContract:
        """Return the wrapped dictionary itself (no copy is made)."""
        return self._data

    @property
    def id(self) -> str:
        """Version ID; ``""`` when absent."""
        record = self._data
        return record.get("id", "")

    @property
    def name(self) -> str:
        """Version name; ``""`` when absent."""
        record = self._data
        return record.get("name", "")

    @property
    def display_name(self) -> str:
        """Human-readable label: the name, or the ID when the name is empty."""
        label = self.name
        return label if label else self.id

    @property
    def project_id(self) -> str:
        """ID of the associated project; ``""`` when absent."""
        record = self._data
        return record.get("projectId", "")

    @property
    def content(self) -> Optional[str]:
        """Version content / download link, or ``None`` when absent."""
        record = self._data
        return record.get("content")

    @property
    def created_at(self) -> Optional[str]:
        """Creation timestamp, or ``None`` when absent."""
        record = self._data
        return record.get("createdAt")
@dataclass(frozen=True)
class TagView:
    """Immutable accessor over a tag dictionary.

    Wraps a ``TagContract`` and exposes its camelCase keys as snake_case
    properties. Missing string keys read as ``""`` and a missing project
    list reads as an empty list.

    Example:
        >>> view = TagView({"id": "123", "label": "important", "color": "#ff0000"})
        >>> view.id
        '123'
        >>> view.display_name
        'important'
    """

    __slots__ = ("_data",)

    _data: TagContract

    def to_dict(self) -> TagContract:
        """Return the wrapped dictionary itself (no copy is made)."""
        return self._data

    @property
    def id(self) -> str:
        """Tag ID; ``""`` when absent."""
        tag = self._data
        return tag.get("id", "")

    @property
    def label(self) -> str:
        """Tag label; ``""`` when absent."""
        tag = self._data
        return tag.get("label", "")

    @property
    def display_name(self) -> str:
        """Human-readable label: the tag label, or the ID when it is empty."""
        text = self.label
        return text if text else self.id

    @property
    def color(self) -> str:
        """Tag color; ``""`` when absent."""
        tag = self._data
        return tag.get("color", "")

    @property
    def organization_id(self) -> str:
        """ID of the owning organization; ``""`` when absent."""
        tag = self._data
        return tag.get("organizationId", "")

    @property
    def checked_for_projects(self) -> List[str]:
        """Project IDs this tag is assigned to; empty list when absent."""
        tag = self._data
        return tag.get("checkedForProjects", [])

    @property
    def project_count(self) -> int:
        """Number of projects this tag is assigned to."""
        assigned = self.checked_for_projects
        return len(assigned)
@dataclass(frozen=True)
class UserView:
    """Immutable accessor over a user dictionary.

    Wraps a ``UserContract`` and exposes its keys as snake_case properties,
    plus a few derived helpers (``display_name``, ``full_name``,
    ``is_admin``). Missing string keys read as ``""``, the activation flag
    as ``False`` and the optional fields as ``None``.

    Example:
        >>> view = UserView({"id": "101", "email": "user@example.com"})
        >>> view.id
        '101'
        >>> view.display_name
        'user@example.com'
    """

    __slots__ = ("_data",)

    _data: UserContract

    def to_dict(self) -> UserContract:
        """Return the wrapped dictionary itself (no copy is made)."""
        return self._data

    # --- plain field accessors -------------------------------------------

    @property
    def id(self) -> str:
        """User ID; ``""`` when absent."""
        record = self._data
        return record.get("id", "")

    @property
    def email(self) -> str:
        """Email address; ``""`` when absent."""
        record = self._data
        return record.get("email", "")

    @property
    def name(self) -> str:
        """User name; ``""`` when absent."""
        record = self._data
        return record.get("name", "")

    @property
    def firstname(self) -> str:
        """First name; ``""`` when absent."""
        record = self._data
        return record.get("firstname", "")

    @property
    def lastname(self) -> str:
        """Last name; ``""`` when absent."""
        record = self._data
        return record.get("lastname", "")

    @property
    def activated(self) -> bool:
        """Whether the account is activated; ``False`` when absent."""
        record = self._data
        return record.get("activated", False)

    @property
    def created_at(self) -> Optional[str]:
        """Account creation timestamp, or ``None`` when absent."""
        record = self._data
        return record.get("createdAt")

    @property
    def updated_at(self) -> Optional[str]:
        """Account update timestamp, or ``None`` when absent."""
        record = self._data
        return record.get("updatedAt")

    @property
    def organization_id(self) -> str:
        """Organization ID; ``""`` when absent."""
        record = self._data
        return record.get("organizationId", "")

    @property
    def organization_role(self) -> Optional[Union[str, OrganizationRoleContract]]:
        """Organization role as stored: a string, a dict, or ``None`` when absent."""
        record = self._data
        return record.get("organizationRole")

    @property
    def phone(self) -> Optional[str]:
        """Phone number, or ``None`` when absent."""
        record = self._data
        return record.get("phone")

    @property
    def last_seen_at(self) -> Optional[str]:
        """Last-seen timestamp, or ``None`` when absent."""
        record = self._data
        return record.get("lastSeenAt")

    # --- derived helpers --------------------------------------------------

    @property
    def display_name(self) -> str:
        """Human-readable label: the first non-empty of name, email, ID."""
        for candidate in (self.name, self.email):
            if candidate:
                return candidate
        return self.id

    @property
    def full_name(self) -> str:
        """Full name built from first and last name.

        Returns:
            ``"<firstname> <lastname>"`` with surrounding whitespace
            stripped when either part is set; otherwise the name, or the
            email when the name is empty.
        """
        first, last = self.firstname, self.lastname
        if not (first or last):
            return self.name or self.email
        return f"{first} {last}".strip()

    @property
    def is_admin(self) -> bool:
        """Whether the organization role is ``"ADMIN"``.

        The role may be stored either as a bare string (``"ADMIN"``) or as
        a dict carrying a ``"role"`` key; any other shape, including a
        missing or empty role, counts as not admin.
        """
        stored = self.organization_role
        if isinstance(stored, str):
            return stored == "ADMIN"
        if isinstance(stored, dict):
            return stored.get("role") == "ADMIN"
        return False
class PaginationParams:
    """Plain holder for the paging options passed to repository queries.

    Attributes:
        skip: How many leading items to skip.
        first: Upper bound on the number of items to return; ``None``
            means no bound.
        batch_size: Number of items requested per page of a paginated query.
    """

    def __init__(
        self,
        skip: int = 0,
        first: Optional[int] = None,
        batch_size: int = 100,
    ) -> None:
        """Store the three paging options unchanged on the instance."""
        self.skip, self.first, self.batch_size = skip, first, batch_size
class ILabelRepository(Protocol):
    """Protocol defining the interface for Label repository operations.

    This is a structural interface: every method body is ``...``, so the
    class only declares signatures for querying and manipulating labels.
    Read operations are declared to return ``LabelContract`` objects (or a
    generator of them); ``count`` and ``delete`` return plain integers.

    Note that one method is named ``list``. Inside this class body the
    annotations use ``typing.List`` (capital ``L``), so they are unaffected
    by that name.
    """

    def get_by_id(
        self,
        label_id: str,
        fields: Optional[List[str]] = None,
    ) -> Optional[LabelContract]:
        """Get a single label by ID.

        Args:
            label_id: The label ID
            fields: Optional list of fields to retrieve

        Returns:
            LabelContract if found, None otherwise
        """
        ...

    def list(
        self,
        asset_id: Optional[str] = None,
        project_id: Optional[str] = None,
        fields: Optional[List[str]] = None,
        label_type_in: Optional[List[str]] = None,
        author_in: Optional[List[str]] = None,
        created_at_gte: Optional[str] = None,
        created_at_lte: Optional[str] = None,
        pagination: Optional[PaginationParams] = None,
    ) -> Generator[LabelContract, None, None]:
        """List labels matching the given filters.

        Every filter is optional and defaults to ``None``.

        Args:
            asset_id: Filter by asset ID
            project_id: Filter by project ID
            fields: Optional list of fields to retrieve
            label_type_in: Filter by label type
            author_in: Filter by author IDs
            created_at_gte: Filter by creation date (greater than or equal).
                Presumably an ISO timestamp string - confirm with implementers.
            created_at_lte: Filter by creation date (less than or equal).
                Presumably an ISO timestamp string - confirm with implementers.
            pagination: Pagination parameters

        Yields:
            LabelContract objects matching the filters
        """
        ...

    def count(
        self,
        asset_id: Optional[str] = None,
        project_id: Optional[str] = None,
        label_type_in: Optional[List[str]] = None,
        author_in: Optional[List[str]] = None,
    ) -> int:
        """Count labels matching the given filters.

        Accepts the same asset, project, label-type and author filters as
        ``list``; it takes no date, field-selection or pagination arguments.

        Args:
            asset_id: Filter by asset ID
            project_id: Filter by project ID
            label_type_in: Filter by label type
            author_in: Filter by author IDs

        Returns:
            Number of labels matching the filters
        """
        ...

    def create(
        self,
        asset_id: str,
        json_response: dict,
        label_type: str = "DEFAULT",
        seconds_to_label: Optional[int] = None,
    ) -> LabelContract:
        """Create a new label.

        Args:
            asset_id: The asset ID
            json_response: Label annotation data
            label_type: Type of label (DEFAULT, REVIEW, etc.); defaults to
                ``"DEFAULT"``
            seconds_to_label: Time spent labeling in seconds

        Returns:
            The created LabelContract
        """
        ...

    def update(
        self,
        label_id: str,
        json_response: dict,
    ) -> LabelContract:
        """Update an existing label.

        Args:
            label_id: The label ID
            json_response: Updated annotation data

        Returns:
            The updated LabelContract
        """
        ...

    def delete(
        self,
        label_ids: List[str],
    ) -> int:
        """Delete labels by IDs.

        Args:
            label_ids: List of label IDs to delete

        Returns:
            Number of labels deleted
        """
        ...
+ created_at_gte: Filter by creation date (greater than or equal) + created_at_lte: Filter by creation date (less than or equal) + pagination: Pagination parameters + + Yields: + ProjectContract objects matching the filters + """ + ... + + def count( + self, + archived: Optional[bool] = None, + starred: Optional[bool] = None, + input_type_in: Optional[List[str]] = None, + ) -> int: + """Count projects matching the given filters. + + Args: + archived: Filter by archived status + starred: Filter by starred status + input_type_in: Filter by input types + + Returns: + Number of projects matching the filters + """ + ... + + def create( + self, + title: str, + description: str, + input_type: str, + json_interface: dict, + ) -> ProjectContract: + """Create a new project. + + Args: + title: Project title + description: Project description + input_type: Type of input data (IMAGE, TEXT, etc.) + json_interface: Ontology/interface definition + + Returns: + The created ProjectContract + """ + ... + + def update( + self, + project_id: str, + title: Optional[str] = None, + description: Optional[str] = None, + json_interface: Optional[dict] = None, + ) -> ProjectContract: + """Update an existing project. + + Args: + project_id: The project ID + title: Optional new title + description: Optional new description + json_interface: Optional new interface definition + + Returns: + The updated ProjectContract + """ + ... + + def archive( + self, + project_id: str, + ) -> ProjectContract: + """Archive a project. + + Args: + project_id: The project ID + + Returns: + The archived ProjectContract + """ + ... + + def delete( + self, + project_ids: List[str], + ) -> int: + """Delete projects by IDs. + + Args: + project_ids: List of project IDs to delete + + Returns: + Number of projects deleted + """ + ... + + +# User Repository Interface + + +class IUserRepository(Protocol): + """Protocol defining the interface for User repository operations. 
+ + This interface provides methods for querying and manipulating users, + returning validated UserContract objects. + """ + + def get_by_id( + self, + user_id: str, + fields: Optional[List[str]] = None, + ) -> Optional[UserContract]: + """Get a single user by ID. + + Args: + user_id: The user ID + fields: Optional list of fields to retrieve + + Returns: + UserContract if found, None otherwise + """ + ... + + def get_by_email( + self, + email: str, + fields: Optional[List[str]] = None, + ) -> Optional[UserContract]: + """Get a single user by email. + + Args: + email: The user email + fields: Optional list of fields to retrieve + + Returns: + UserContract if found, None otherwise + """ + ... + + def list( + self, + organization_id: str, + fields: Optional[List[str]] = None, + activated: Optional[bool] = None, + email_contains: Optional[str] = None, + pagination: Optional[PaginationParams] = None, + ) -> Generator[UserContract, None, None]: + """List users matching the given filters. + + Args: + organization_id: The organization ID + fields: Optional list of fields to retrieve + activated: Filter by activation status + email_contains: Filter by email substring + pagination: Pagination parameters + + Yields: + UserContract objects matching the filters + """ + ... + + def count( + self, + organization_id: str, + activated: Optional[bool] = None, + ) -> int: + """Count users matching the given filters. + + Args: + organization_id: The organization ID + activated: Filter by activation status + + Returns: + Number of users matching the filters + """ + ... + + def create( + self, + organization_id: str, + email: str, + firstname: str, + lastname: str, + role: str = "USER", + ) -> UserContract: + """Create a new user. + + Args: + organization_id: The organization ID + email: User email address + firstname: User first name + lastname: User last name + role: Organization role (ADMIN, USER, REVIEWER) + + Returns: + The created UserContract + """ + ... 
+ + def update( + self, + user_id: str, + firstname: Optional[str] = None, + lastname: Optional[str] = None, + activated: Optional[bool] = None, + ) -> UserContract: + """Update an existing user. + + Args: + user_id: The user ID + firstname: Optional new first name + lastname: Optional new last name + activated: Optional new activation status + + Returns: + The updated UserContract + """ + ... diff --git a/tests/equivalence/__init__.py b/tests/equivalence/__init__.py new file mode 100644 index 000000000..e7bd1bee4 --- /dev/null +++ b/tests/equivalence/__init__.py @@ -0,0 +1,5 @@ +"""Equivalence testing framework for validating legacy vs v2 API compatibility. + +This package provides tools for automated testing to ensure semantic equivalence +between legacy (mixin-based) client methods and v2 (domain-based) implementations. +""" diff --git a/tests/equivalence/comparator.py b/tests/equivalence/comparator.py new file mode 100644 index 000000000..9469f6b86 --- /dev/null +++ b/tests/equivalence/comparator.py @@ -0,0 +1,305 @@ +"""Response comparison utilities for equivalence testing. + +This module provides sophisticated comparison capabilities to verify +semantic equivalence between legacy and v2 API responses. 
class ComparisonStatus(Enum):
    """Possible outcomes of comparing a legacy response with a v2 response."""

    EQUIVALENT = "equivalent"
    DIFFERENT = "different"
    ERROR = "error"
    SKIPPED = "skipped"


@dataclass
class ComparisonResult:
    """Outcome of comparing one legacy response with its v2 counterpart.

    Attributes:
        status: Overall outcome of the comparison.
        legacy_response: Response that came from the legacy implementation.
        v2_response: Response that came from the v2 implementation.
        differences: Human-readable difference descriptions (empty if none).
        error_message: Failure text when the comparison itself failed.
        metadata: Extra details recorded by the comparator.
    """

    status: ComparisonStatus
    legacy_response: Any
    v2_response: Any
    differences: List[str] = field(default_factory=list)
    error_message: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def is_equivalent(self) -> bool:
        """Return True only when the status is ``EQUIVALENT``."""
        return self.status == ComparisonStatus.EQUIVALENT

    @property
    def has_differences(self) -> bool:
        """Return True when at least one difference was recorded."""
        return bool(self.differences)

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dictionary (status as its string value)."""
        return {
            "status": self.status.value,
            "legacy_response": self.legacy_response,
            "v2_response": self.v2_response,
            "differences": self.differences,
            "error_message": self.error_message,
            "metadata": self.metadata,
        }


class ResponseComparator:
    """Compare responses from the legacy and v2 implementations.

    Unless ``strict`` is requested, both responses go through the normalizer
    before the equality check, so the comparison is semantic rather than a
    plain ``==`` on the raw payloads.
    """

    def __init__(
        self,
        normalizer: Optional["PayloadNormalizer"] = None,
        diff_generator: Optional["DiffGenerator"] = None,
        custom_comparators: Optional[Dict[str, Callable]] = None,
    ):
        """Initialize the comparator.

        Args:
            normalizer: Payload normalizer to use (default: ``PayloadNormalizer()``).
            diff_generator: Diff generator to use (default: ``DiffGenerator()``).
            custom_comparators: Comparison functions keyed by method name.
        """
        self.normalizer = normalizer or PayloadNormalizer()
        self.diff_generator = diff_generator or DiffGenerator()
        self.custom_comparators = custom_comparators or {}

    def compare(
        self,
        legacy_response: Any,
        v2_response: Any,
        method_name: Optional[str] = None,
        strict: bool = False,
    ) -> ComparisonResult:
        """Compare a legacy response with a v2 response.

        Any exception raised while comparing - including one raised by a
        registered custom comparator - is reported as an ``ERROR`` result
        instead of propagating, so every failure mode reaches the caller
        through the same channel.

        Args:
            legacy_response: Response from the legacy implementation.
            v2_response: Response from the v2 implementation.
            method_name: Method name used to look up a custom comparator.
            strict: When True, skip normalization and compare the raw values.

        Returns:
            ComparisonResult describing the outcome.
        """
        try:
            # A comparator registered for this method replaces the default logic.
            if method_name and method_name in self.custom_comparators:
                return self.custom_comparators[method_name](legacy_response, v2_response)

            if strict:
                legacy_norm, v2_norm = legacy_response, v2_response
            else:
                legacy_norm, v2_norm = self.normalizer.normalize_for_comparison(
                    legacy_response, v2_response
                )

            if legacy_norm == v2_norm:
                return ComparisonResult(
                    status=ComparisonStatus.EQUIVALENT,
                    legacy_response=legacy_response,
                    v2_response=v2_response,
                    metadata={"normalized": not strict},
                )

            return ComparisonResult(
                status=ComparisonStatus.DIFFERENT,
                legacy_response=legacy_response,
                v2_response=v2_response,
                differences=self.diff_generator.generate_diff(legacy_norm, v2_norm),
                metadata={"normalized": not strict},
            )

        except Exception as e:  # pylint: disable=broad-except
            return ComparisonResult(
                status=ComparisonStatus.ERROR,
                legacy_response=legacy_response,
                v2_response=v2_response,
                error_message=str(e),
            )

    def compare_batch(
        self,
        pairs: List[tuple[Any, Any]],
        method_name: Optional[str] = None,
        strict: bool = False,
    ) -> List[ComparisonResult]:
        """Compare several (legacy, v2) pairs with the same settings.

        Args:
            pairs: List of ``(legacy_response, v2_response)`` tuples.
            method_name: Method name used to look up a custom comparator.
            strict: When True, skip normalization.

        Returns:
            One ComparisonResult per pair, in input order.
        """
        return [self.compare(legacy, v2, method_name, strict) for legacy, v2 in pairs]

    def register_custom_comparator(
        self,
        method_name: str,
        comparator: Callable[[Any, Any], ComparisonResult],
    ) -> None:
        """Register (or replace) the comparison function used for a method.

        Args:
            method_name: Name of the method.
            comparator: Callable taking ``(legacy, v2)`` and returning a
                ComparisonResult.
        """
        self.custom_comparators[method_name] = comparator


class EquivalenceAssertion:
    """Assertion helpers that turn comparison results into test failures."""

    @staticmethod
    def assert_equivalent(
        result: ComparisonResult,
        message: Optional[str] = None,
    ) -> None:
        """Assert that a comparison found the responses equivalent.

        Args:
            result: Comparison result to check.
            message: Optional prefix for the failure message.

        Raises:
            AssertionError: If the comparison errored or the responses differ.
        """
        if result.status == ComparisonStatus.ERROR:
            raise AssertionError(f"{message or 'Comparison failed'}: {result.error_message}")

        if not result.is_equivalent:
            diff_str = "\n".join(result.differences)
            raise AssertionError(f"{message or 'Responses not equivalent'}:\n{diff_str}")

    @staticmethod
    def assert_batch_equivalent(
        results: List[ComparisonResult],
        message: Optional[str] = None,
    ) -> None:
        """Assert that every comparison in a batch is equivalent.

        Each failure is reported with its position in ``results`` so it can
        be matched to the input pair; errored comparisons show their error
        message rather than an empty difference list.

        Args:
            results: Comparison results to check.
            message: Optional prefix for the failure message.

        Raises:
            AssertionError: If any result is not equivalent.
        """
        # Enumerate before filtering so the index refers to the original batch.
        failures = [
            (index, result) for index, result in enumerate(results) if not result.is_equivalent
        ]
        if not failures:
            return

        details = []
        for index, result in failures:
            if result.status == ComparisonStatus.ERROR:
                details.append(f"Result {index}: error: {result.error_message}")
            else:
                details.append(f"Result {index}: {result.differences}")

        raise AssertionError(
            f"{message or 'Batch comparison failed'} "
            f"({len(failures)}/{len(results)} failed):\n" + "\n".join(details)
        )


# Example custom comparators


def create_pagination_comparator(
    comparator: Optional[ResponseComparator] = None,
) -> Callable[[Any, Any], ComparisonResult]:
    """Create a custom comparator for paginated responses.

    Paginated responses may arrive in a different order; the returned
    function materializes both sides, checks the item counts, sorts the
    items by ID and then delegates to a ResponseComparator.

    Args:
        comparator: Comparator used for the final comparison. When omitted a
            default ``ResponseComparator()`` is created for each call.

    Returns:
        A function taking ``(legacy, v2)`` and returning a ComparisonResult.
    """

    def _as_items(response: Any) -> List[Any]:
        """Materialize a response as a list of items."""
        # dict/str/bytes are iterable, but iterating them would yield keys or
        # characters and silently drop the data being compared.
        if isinstance(response, (dict, str, bytes)) or not hasattr(response, "__iter__"):
            return [response]
        return list(response)

    def _sort_key(item: Any) -> str:
        """Return a string key so items with mixed ID types stay sortable."""
        if isinstance(item, dict):
            return str(item.get("id", ""))
        return str(item)

    def compare_paginated(legacy: Any, v2: Any) -> ComparisonResult:
        """Compare paginated responses."""
        # Simplified example: a single level of iteration is materialized;
        # nested pages are not aggregated.
        try:
            legacy_items = _as_items(legacy)
            v2_items = _as_items(v2)

            if len(legacy_items) != len(v2_items):
                return ComparisonResult(
                    status=ComparisonStatus.DIFFERENT,
                    legacy_response=legacy,
                    v2_response=v2,
                    differences=[
                        f"Item count mismatch: legacy={len(legacy_items)}, v2={len(v2_items)}"
                    ],
                )

            # Order must not matter, so both sides are sorted with the same key.
            legacy_sorted = sorted(legacy_items, key=_sort_key)
            v2_sorted = sorted(v2_items, key=_sort_key)

            active = comparator if comparator is not None else ResponseComparator()
            return active.compare(legacy_sorted, v2_sorted)

        except Exception as e:  # pylint: disable=broad-except
            return ComparisonResult(
                status=ComparisonStatus.ERROR,
                legacy_response=legacy,
                v2_response=v2,
                error_message=f"Pagination comparison error: {e}",
            )

    return compare_paginated


def create_count_comparator() -> Callable[[Any, Any], ComparisonResult]:
    """Create a custom comparator for count methods.

    Count methods should return the same integer value. Values that cannot
    be read as a whole number (None, non-numeric text, fractional or
    non-finite floats) produce an ``ERROR`` result.

    Returns:
        A function taking ``(legacy, v2)`` and returning a ComparisonResult.
    """

    def _to_count(value: Any) -> int:
        """Convert a response to an int, refusing lossy float conversions."""
        # int() would silently truncate 42.9 to 42 and raise OverflowError on
        # infinity; both cases are reported as invalid counts instead.
        if isinstance(value, float) and not value.is_integer():
            raise ValueError(f"count is not a whole number: {value!r}")
        return int(value)

    def compare_counts(legacy: Any, v2: Any) -> ComparisonResult:
        """Compare count responses."""
        try:
            legacy_count = _to_count(legacy)
            v2_count = _to_count(v2)
        except (TypeError, ValueError) as e:
            return ComparisonResult(
                status=ComparisonStatus.ERROR,
                legacy_response=legacy,
                v2_response=v2,
                error_message=f"Invalid count value: {e}",
            )

        if legacy_count == v2_count:
            return ComparisonResult(
                status=ComparisonStatus.EQUIVALENT,
                legacy_response=legacy,
                v2_response=v2,
            )

        return ComparisonResult(
            status=ComparisonStatus.DIFFERENT,
            legacy_response=legacy,
            v2_response=v2,
            differences=[f"Count mismatch: legacy={legacy_count}, v2={v2_count}"],
        )

    return compare_counts
def pytest_configure(config):
    """Register the custom markers used by the equivalence tests."""
    marker_lines = (
        "equivalence: mark test as an equivalence test",
        "slow: mark test as slow running",
        "integration: mark test as integration test requiring API access",
    )
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)


@pytest.fixture(scope="session")
def recordings_dir(tmp_path_factory):
    """Return a session-wide temporary directory for test recordings."""
    return tmp_path_factory.mktemp("recordings")


@pytest.fixture()
def use_real_api():
    """Return True when KILI_USE_REAL_API is set to "true" (any letter case)."""
    flag = os.environ.get("KILI_USE_REAL_API", "false")
    return flag.lower() == "true"


@pytest.fixture()
def api_key():
    """Return the KILI_API_KEY environment variable (None when unset)."""
    return os.environ.get("KILI_API_KEY")


@pytest.fixture()
def api_endpoint():
    """Return KILI_API_ENDPOINT, falling back to the cloud GraphQL endpoint."""
    default_endpoint = "https://cloud.kili-technology.com/api/label/v2/graphql"
    return os.environ.get("KILI_API_ENDPOINT", default_endpoint)


@pytest.fixture()
def test_project_id():
    """Return the KILI_TEST_PROJECT_ID environment variable (None when unset)."""
    return os.environ.get("KILI_TEST_PROJECT_ID")


@pytest.fixture()
def skip_if_no_api(use_real_api, api_key):  # pylint: disable=redefined-outer-name
    """Skip the requesting test unless real-API use is enabled and a key is set."""
    api_configured = use_real_api and api_key
    if not api_configured:
        pytest.skip("Skipping test: API access not configured")
+""" + +import json + +# Add parent directory to path for imports +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from comparator import EquivalenceAssertion, ResponseComparator # noqa: E402 +from normalizer import DiffGenerator, PayloadNormalizer # noqa: E402 + + +def example_basic_comparison(): + """Example: Basic response comparison.""" + print("=" * 60) + print("Example 1: Basic Response Comparison") + print("=" * 60) + + # Legacy response (dict) + legacy_response = { + "id": "asset_123", + "externalId": "image_001", + "content": "https://example.com/image.jpg", + "status": "TODO", + "__typename": "Asset", # GraphQL metadata + } + + # V2 response (TypedDict as dict) + v2_response = { + "id": "asset_123", + "externalId": "image_001", + "content": "https://example.com/image.jpg", + "status": "TODO", + } + + # Compare + comparator = ResponseComparator() + result = comparator.compare(legacy_response, v2_response) + + print(f"\nLegacy: {json.dumps(legacy_response, indent=2)}") + print(f"\nV2: {json.dumps(v2_response, indent=2)}") + print(f"\nEquivalent: {result.is_equivalent}") + + if result.is_equivalent: + print("โœ“ Responses are equivalent!") + else: + print("โœ— Responses differ:") + for diff in result.differences: + print(f" - {diff}") + + +def example_list_comparison(): + """Example: List response comparison with sorting.""" + print("\n" + "=" * 60) + print("Example 2: List Response Comparison") + print("=" * 60) + + # Legacy response (unordered) + legacy_response = [ + {"id": "2", "name": "Asset B"}, + {"id": "1", "name": "Asset A"}, + {"id": "3", "name": "Asset C"}, + ] + + # V2 response (different order) + v2_response = [ + {"id": "1", "name": "Asset A"}, + {"id": "3", "name": "Asset C"}, + {"id": "2", "name": "Asset B"}, + ] + + # Compare + comparator = ResponseComparator() + result = comparator.compare(legacy_response, v2_response) + + print(f"\nLegacy order: {[a['id'] for a in legacy_response]}") + 
print(f"V2 order: {[a['id'] for a in v2_response]}") + print(f"\nEquivalent: {result.is_equivalent}") + + if result.is_equivalent: + print("โœ“ Lists are equivalent (after normalization)!") + + +def example_nested_comparison(): + """Example: Nested object comparison.""" + print("\n" + "=" * 60) + print("Example 3: Nested Object Comparison") + print("=" * 60) + + # Legacy response + legacy_response = { + "id": "asset_123", + "labels": [ + { + "id": "label_1", + "author": { + "id": "user_1", + "email": "user@example.com", + "__typename": "User", + }, + } + ], + } + + # V2 response + v2_response = { + "id": "asset_123", + "labels": [ + { + "id": "label_1", + "author": { + "id": "user_1", + "email": "user@example.com", + }, + } + ], + } + + # Compare + comparator = ResponseComparator() + result = comparator.compare(legacy_response, v2_response) + + print(f"\nEquivalent: {result.is_equivalent}") + + if result.is_equivalent: + print("โœ“ Nested structures are equivalent!") + + +def example_difference_detection(): + """Example: Detecting and displaying differences.""" + print("\n" + "=" * 60) + print("Example 4: Difference Detection") + print("=" * 60) + + # Legacy response + legacy_response = { + "id": "asset_123", + "status": "TODO", + "metadata": {"camera": "drone", "altitude": 100}, + } + + # V2 response (different metadata) + v2_response = { + "id": "asset_123", + "status": "LABELED", # Different status + "metadata": {"camera": "drone", "altitude": 150}, # Different altitude + } + + # Compare + comparator = ResponseComparator() + result = comparator.compare(legacy_response, v2_response) + + print(f"\nEquivalent: {result.is_equivalent}") + + if not result.is_equivalent: + print("\nDifferences found:") + for diff in result.differences: + print(f" - {diff}") + + +def example_count_comparison(): + """Example: Comparing count responses.""" + print("\n" + "=" * 60) + print("Example 5: Count Response Comparison") + print("=" * 60) + + from comparator import 
create_count_comparator + + # Legacy response + legacy_count = 42 + + # V2 response + v2_count = 42 + + # Use count comparator + count_comparator = create_count_comparator() + result = count_comparator(legacy_count, v2_count) + + print(f"\nLegacy count: {legacy_count}") + print(f"V2 count: {v2_count}") + print(f"Equivalent: {result.is_equivalent}") + + if result.is_equivalent: + print("โœ“ Counts match!") + + +def example_assertion(): + """Example: Using assertions in tests.""" + print("\n" + "=" * 60) + print("Example 6: Using Equivalence Assertions") + print("=" * 60) + + legacy_response = {"id": "123", "name": "test"} + v2_response = {"id": "123", "name": "test"} + + comparator = ResponseComparator() + result = comparator.compare(legacy_response, v2_response) + + try: + EquivalenceAssertion.assert_equivalent(result, message="Asset responses must be equivalent") + print("\nโœ“ Assertion passed!") + except AssertionError as e: + print(f"\nโœ— Assertion failed: {e}") + + +def example_manual_normalization(): + """Example: Manual normalization and diff generation.""" + print("\n" + "=" * 60) + print("Example 7: Manual Normalization") + print("=" * 60) + + legacy = { + "id": "123", + "items": [{"id": "2"}, {"id": "1"}], + "__typename": "Response", + } + + v2 = { + "id": "123", + "items": [{"id": "1"}, {"id": "2"}], + } + + # Normalize + normalizer = PayloadNormalizer() + norm_legacy, norm_v2 = normalizer.normalize_for_comparison(legacy, v2) + + print("\nOriginal legacy:") + print(json.dumps(legacy, indent=2)) + print("\nNormalized legacy:") + print(json.dumps(norm_legacy, indent=2)) + + print(f"\nEquivalent after normalization: {norm_legacy == norm_v2}") + + if norm_legacy != norm_v2: + # Generate diff + diff_gen = DiffGenerator() + diffs = diff_gen.generate_diff(norm_legacy, norm_v2) + print("\nDifferences:") + for diff in diffs: + print(f" - {diff}") + + +def main(): + """Run all examples.""" + print("\nEquivalence Testing Examples") + print("=" * 60) + + 
example_basic_comparison() + example_list_comparison() + example_nested_comparison() + example_difference_detection() + example_count_comparison() + example_assertion() + example_manual_normalization() + + print("\n" + "=" * 60) + print("All examples completed!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/tests/equivalence/examples/record_legacy_requests.py b/tests/equivalence/examples/record_legacy_requests.py new file mode 100644 index 000000000..b7e9861f9 --- /dev/null +++ b/tests/equivalence/examples/record_legacy_requests.py @@ -0,0 +1,162 @@ +"""Example script for recording legacy API requests. + +This script demonstrates how to record requests from the legacy Kili client +for later replay against v2 implementations. +""" + +import os + +# Add parent directory to path for imports +import sys +from pathlib import Path + +from kili.client import Kili + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from recorder import RequestRecorder # noqa: E402 + + +def main(): + """Record sample legacy API requests.""" + # Initialize legacy client + api_key = os.getenv("KILI_API_KEY") + if not api_key: + print("Error: KILI_API_KEY environment variable not set") + return + + kili = Kili(api_key=api_key) + + # Initialize recorder + recordings_dir = Path(__file__).parent / "recordings" + recorder = RequestRecorder(storage_dir=recordings_dir) + + # Get a test project (use your actual project ID) + project_id = os.getenv("KILI_TEST_PROJECT_ID") + if not project_id: + # Try to get first project + projects = list(kili.projects(first=1)) + if projects: + project_id = projects[0]["id"] + else: + print("No projects found. 
Please set KILI_TEST_PROJECT_ID") + return + + print(f"Recording requests for project: {project_id}") + + # Record: count_assets + print("\nRecording: count_assets") + try: + response = kili.count_assets(project_id=project_id) + recorder.record( + method_name="count_assets", + kwargs={"project_id": project_id}, + response=response, + context={"project_id": project_id}, + ) + print(f" โœ“ Recorded: {response} assets") + except Exception as e: # pylint: disable=broad-except + recorder.record( + method_name="count_assets", + kwargs={"project_id": project_id}, + exception=e, + context={"project_id": project_id}, + ) + print(f" โœ— Error: {e}") + + # Record: assets (list) + print("\nRecording: assets (first 10)") + try: + response = list(kili.assets(project_id=project_id, first=10)) + recorder.record( + method_name="assets", + kwargs={"project_id": project_id, "first": 10}, + response=response, + context={"project_id": project_id}, + ) + print(f" โœ“ Recorded: {len(response)} assets") + except Exception as e: # pylint: disable=broad-except + recorder.record( + method_name="assets", + kwargs={"project_id": project_id, "first": 10}, + exception=e, + context={"project_id": project_id}, + ) + print(f" โœ— Error: {e}") + + # Record: assets with pagination + print("\nRecording: assets with pagination (skip 5, first 5)") + try: + response = list(kili.assets(project_id=project_id, skip=5, first=5)) + recorder.record( + method_name="assets", + kwargs={"project_id": project_id, "skip": 5, "first": 5}, + response=response, + context={"project_id": project_id}, + ) + print(f" โœ“ Recorded: {len(response)} assets") + except Exception as e: # pylint: disable=broad-except + recorder.record( + method_name="assets", + kwargs={"project_id": project_id, "skip": 5, "first": 5}, + exception=e, + context={"project_id": project_id}, + ) + print(f" โœ— Error: {e}") + + # Record: count_labels + print("\nRecording: count_labels") + try: + response = kili.count_labels(project_id=project_id) + 
recorder.record( + method_name="count_labels", + kwargs={"project_id": project_id}, + response=response, + context={"project_id": project_id}, + ) + print(f" โœ“ Recorded: {response} labels") + except Exception as e: # pylint: disable=broad-except + recorder.record( + method_name="count_labels", + kwargs={"project_id": project_id}, + exception=e, + context={"project_id": project_id}, + ) + print(f" โœ— Error: {e}") + + # Record: projects (list) + print("\nRecording: projects (first 5)") + try: + response = list(kili.projects(first=5)) + recorder.record( + method_name="projects", + kwargs={"first": 5}, + response=response, + ) + print(f" โœ“ Recorded: {len(response)} projects") + except Exception as e: # pylint: disable=broad-except + recorder.record( + method_name="projects", + kwargs={"first": 5}, + exception=e, + ) + print(f" โœ— Error: {e}") + + # Save recordings + print("\nSaving recordings...") + filepath = recorder.save("legacy_api_requests", format="json") + print(f" โœ“ Saved to: {filepath}") + + # Print summary + print("\nRecording Summary:") + summary = recorder.get_summary() + print(f" Total recordings: {summary['total_recordings']}") + print(" Methods recorded:") + for method, stats in summary["methods"].items(): + print( + f" - {method}: {stats['count']} calls ({stats['success']} success, {stats['errors']} errors)" + ) + + +if __name__ == "__main__": + main() diff --git a/tests/equivalence/fixtures.py b/tests/equivalence/fixtures.py new file mode 100644 index 000000000..0bc1409b4 --- /dev/null +++ b/tests/equivalence/fixtures.py @@ -0,0 +1,257 @@ +"""Test fixtures for equivalence testing. + +This module provides common test fixtures and data for equivalence tests. 
+""" + +from typing import List + +from .harness import TestCase + +# Asset-related test cases +ASSET_TEST_CASES: List[TestCase] = [ + TestCase( + name="count_assets_basic", + method_name="count_assets", + legacy_method_path="count_assets", + v2_method_path="assets.count", + kwargs={"project_id": "test_project_id"}, + description="Count all assets in a project", + ), + TestCase( + name="count_assets_with_status_filter", + method_name="count_assets", + legacy_method_path="count_assets", + v2_method_path="assets.count", + kwargs={ + "project_id": "test_project_id", + "status_in": ["TODO", "ONGOING"], + }, + description="Count assets filtered by status", + ), + TestCase( + name="assets_list_basic", + method_name="assets", + legacy_method_path="assets", + v2_method_path="assets.list", + kwargs={ + "project_id": "test_project_id", + "first": 10, + }, + description="List first 10 assets", + ), + TestCase( + name="assets_list_with_pagination", + method_name="assets", + legacy_method_path="assets", + v2_method_path="assets.list", + kwargs={ + "project_id": "test_project_id", + "first": 50, + "skip": 100, + }, + description="List assets with pagination", + ), + TestCase( + name="assets_list_with_external_id_filter", + method_name="assets", + legacy_method_path="assets", + v2_method_path="assets.list", + kwargs={ + "project_id": "test_project_id", + "external_id_contains": ["image", "photo"], + }, + description="List assets filtered by external ID", + ), + TestCase( + name="assets_list_with_metadata_filter", + method_name="assets", + legacy_method_path="assets", + v2_method_path="assets.list", + kwargs={ + "project_id": "test_project_id", + "metadata_where": {"camera": "drone"}, + }, + description="List assets filtered by metadata", + ), +] + + +# Label-related test cases +LABEL_TEST_CASES: List[TestCase] = [ + TestCase( + name="count_labels_basic", + method_name="count_labels", + legacy_method_path="count_labels", + v2_method_path="labels.count", + kwargs={"project_id": 
"test_project_id"}, + description="Count all labels in a project", + ), + TestCase( + name="count_labels_by_asset", + method_name="count_labels", + legacy_method_path="count_labels", + v2_method_path="labels.count", + kwargs={ + "project_id": "test_project_id", + "asset_id": "test_asset_id", + }, + description="Count labels for a specific asset", + ), + TestCase( + name="labels_list_basic", + method_name="labels", + legacy_method_path="labels", + v2_method_path="labels.list", + kwargs={ + "project_id": "test_project_id", + "first": 10, + }, + description="List first 10 labels", + ), + TestCase( + name="labels_list_by_author", + method_name="labels", + legacy_method_path="labels", + v2_method_path="labels.list", + kwargs={ + "project_id": "test_project_id", + "author_in": ["user1@example.com"], + }, + description="List labels filtered by author", + ), +] + + +# Project-related test cases +PROJECT_TEST_CASES: List[TestCase] = [ + TestCase( + name="count_projects_basic", + method_name="count_projects", + legacy_method_path="count_projects", + v2_method_path="projects.count", + kwargs={}, + description="Count all projects", + ), + TestCase( + name="count_projects_archived", + method_name="count_projects", + legacy_method_path="count_projects", + v2_method_path="projects.count", + kwargs={"archived": True}, + description="Count archived projects", + ), + TestCase( + name="projects_list_basic", + method_name="projects", + legacy_method_path="projects", + v2_method_path="projects.list", + kwargs={"first": 10}, + description="List first 10 projects", + ), + TestCase( + name="projects_list_by_input_type", + method_name="projects", + legacy_method_path="projects", + v2_method_path="projects.list", + kwargs={ + "input_type_in": ["IMAGE", "VIDEO"], + }, + description="List projects filtered by input type", + ), +] + + +# User-related test cases +USER_TEST_CASES: List[TestCase] = [ + TestCase( + name="count_users_basic", + method_name="count_users", + 
def get_test_cases_by_category(category: str) -> List[TestCase]:
    """Get test cases by category.

    The category name is matched case-insensitively, mirroring how
    ``get_crud_test_cases`` treats its ``entity`` argument.

    Args:
        category: Category name ("crud", "pagination", "filtering", "error")

    Returns:
        List of test cases in the category. "crud" returns ``ALL_TEST_CASES``
        (which also contains the error scenarios); "pagination" and
        "filtering" select cases by keyword in their description.

    Raises:
        ValueError: If the category is not one of the supported names.
    """
    category_lower = category.lower()

    if category_lower == "crud":
        return ALL_TEST_CASES
    if category_lower == "pagination":
        # Selection is driven by the free-text description, not by kwargs.
        return [tc for tc in ALL_TEST_CASES if "pagination" in tc.description.lower()]
    if category_lower == "filtering":
        # "filter" also matches "filtered", which is what the descriptions use.
        return [tc for tc in ALL_TEST_CASES if "filter" in tc.description.lower()]
    if category_lower == "error":
        return ERROR_TEST_CASES
    raise ValueError(f"Unknown category: {category}")
@dataclass
class TestResult:
    """Outcome of executing one equivalence test case.

    A result is in exactly one of three states: skipped, passed, or failed.
    """

    test_case: TestCase
    comparison_result: Optional[ComparisonResult] = None
    error: Optional[str] = None
    skipped: bool = False

    @property
    def passed(self) -> bool:
        """Whether the case ran without error and the comparison was equivalent."""
        # A skipped case or one carrying an error message can never pass.
        if self.skipped or self.error:
            return False

        outcome = self.comparison_result
        if outcome is None:
            # Nothing was compared, so equivalence was not established.
            return False
        return outcome.is_equivalent

    @property
    def failed(self) -> bool:
        """Whether the case ran (was not skipped) and did not pass."""
        return not (self.passed or self.skipped)
+ + Args: + legacy_client: Instance of legacy Kili client + v2_client: Instance of v2 Kili client + recorder: Request recorder + comparator: Response comparator + """ + self.legacy_client = legacy_client + self.v2_client = v2_client + self.recorder = recorder or RequestRecorder() + self.comparator = comparator or ResponseComparator() + + # Register common custom comparators + self._register_default_comparators() + + def _register_default_comparators(self) -> None: + """Register default custom comparators for common patterns.""" + # Count methods + count_comparator = create_count_comparator() + self.comparator.register_custom_comparator("count_assets", count_comparator) + self.comparator.register_custom_comparator("count_labels", count_comparator) + self.comparator.register_custom_comparator("count_projects", count_comparator) + + # Pagination methods + pagination_comparator = create_pagination_comparator() + self.comparator.register_custom_comparator("assets", pagination_comparator) + self.comparator.register_custom_comparator("labels", pagination_comparator) + self.comparator.register_custom_comparator("projects", pagination_comparator) + + def record_legacy_request( + self, + method_name: str, + args: tuple = (), + kwargs: Optional[Dict[str, Any]] = None, + ) -> RecordedRequest: + """Record a request from legacy client. 
+ + Args: + method_name: Name of the method to call + args: Positional arguments + kwargs: Keyword arguments + + Returns: + Recorded request object + """ + kwargs = kwargs or {} + + # Get the method from legacy client + method = self._get_method(self.legacy_client, method_name) + + # Execute and record + try: + response = method(*args, **kwargs) + exception = None + except Exception as e: # pylint: disable=broad-except + response = None + exception = e + + # Extract context + context = {} + if "project_id" in kwargs: + context["project_id"] = kwargs["project_id"] + if "asset_id" in kwargs: + context["asset_id"] = kwargs["asset_id"] + if "label_id" in kwargs: + context["label_id"] = kwargs["label_id"] + + return self.recorder.record( + method_name=method_name, + args=args, + kwargs=kwargs, + response=response, + exception=exception, + context=context, + ) + + def replay_against_v2( + self, + recording: RecordedRequest, + v2_method_path: str, + ) -> tuple[Any, Optional[Exception]]: + """Replay a recorded request against v2 implementation. + + Args: + recording: The recorded request to replay + v2_method_path: Path to v2 method (e.g., "assets.count") + + Returns: + Tuple of (response, exception) + """ + # Get the v2 method + method = self._get_method(self.v2_client, v2_method_path) + + # Execute + try: + response = method(*recording.args, **recording.kwargs) + return response, None + except Exception as e: # pylint: disable=broad-except + return None, e + + def run_test_case(self, test_case: TestCase) -> TestResult: + """Run a single test case. 
+ + Args: + test_case: The test case to run + + Returns: + Test result + """ + if test_case.skip: + return TestResult( + test_case=test_case, + skipped=True, + ) + + try: + # Execute legacy method + legacy_method = self._get_method(self.legacy_client, test_case.legacy_method_path) + legacy_response = legacy_method(*test_case.args, **test_case.kwargs) + legacy_exception = None + except Exception as e: # pylint: disable=broad-except + legacy_response = None + legacy_exception = e + + try: + # Execute v2 method + v2_method = self._get_method(self.v2_client, test_case.v2_method_path) + v2_response = v2_method(*test_case.args, **test_case.kwargs) + v2_exception = None + except Exception as e: # pylint: disable=broad-except + v2_response = None + v2_exception = e + + # Handle exception cases + if legacy_exception or v2_exception: + if legacy_exception and v2_exception: + # Both raised exceptions - compare exception types + if isinstance(legacy_exception, type(v2_exception)): + comparison_result = ComparisonResult( + status=ComparisonStatus.EQUIVALENT, + legacy_response=str(legacy_exception), + v2_response=str(v2_exception), + metadata={"both_raised_exception": True}, + ) + else: + comparison_result = ComparisonResult( + status=ComparisonStatus.DIFFERENT, + legacy_response=str(legacy_exception), + v2_response=str(v2_exception), + differences=[ + f"Exception type mismatch: " + f"legacy={type(legacy_exception).__name__}, " + f"v2={type(v2_exception).__name__}" + ], + ) + else: + # Only one raised exception + comparison_result = ComparisonResult( + status=ComparisonStatus.DIFFERENT, + legacy_response=str(legacy_exception) if legacy_exception else legacy_response, + v2_response=str(v2_exception) if v2_exception else v2_response, + differences=[ + f"Exception mismatch: legacy_exception={legacy_exception is not None}, " + f"v2_exception={v2_exception is not None}" + ], + ) + else: + # Compare responses + comparison_result = self.comparator.compare( + legacy_response, 
@dataclass
class TestSuiteReport:
    """Report of test suite execution.

    Holds the aggregate counters together with the individual results they
    were computed from.
    """

    total: int
    passed: int
    failed: int
    skipped: int
    results: List["TestResult"]

    @property
    def success_rate(self) -> float:
        """Percentage of executed (non-skipped) tests that passed.

        Returns:
            A value in percent. ``0.0`` is returned when nothing was executed,
            i.e. when the suite is empty or every test was skipped.
        """
        executed = self.total - self.skipped
        # Guard the division: a suite where every case is skipped has
        # total > 0 but no executed cases.
        if executed <= 0:
            return 0.0
        return (self.passed / executed) * 100

    def to_dict(self) -> Dict[str, Any]:
        """Convert the report to a plain dictionary.

        Returns:
            Dictionary with the counters, the success rate formatted as a
            percentage string with two decimals, and one entry per result
            (name, pass/fail/skip flags and the recorded differences, which
            is an empty list when the result has no comparison).
        """
        return {
            "total": self.total,
            "passed": self.passed,
            "failed": self.failed,
            "skipped": self.skipped,
            "success_rate": f"{self.success_rate:.2f}%",
            "results": [
                {
                    "name": r.test_case.name,
                    "passed": r.passed,
                    "failed": r.failed,
                    "skipped": r.skipped,
                    "differences": (r.comparison_result.differences if r.comparison_result else []),
                }
                for r in self.results
            ],
        }
class PayloadNormalizer:
    """Normalize payloads for semantic equivalence comparison.

    Normalization removes the keys listed in ``IGNORE_FIELDS`` from every
    dict, converts tuples to lists and, when requested, sorts lists so that
    two responses differing only in element order compare equal.
    """

    # Keys dropped from every dict, at any nesting level.
    IGNORE_FIELDS: Set[str] = {
        "__typename",  # GraphQL metadata
        "_internal_id",  # Internal tracking fields
    }

    # Fields that should be sorted for consistent comparison.
    # NOTE(review): not referenced inside this class - every list accepted by
    # ``_is_sortable_list`` is sorted regardless of the field holding it.
    SORTABLE_LIST_FIELDS: Set[str] = {
        "labels",
        "assets",
        "projects",
        "users",
    }

    @classmethod
    def normalize(
        cls,
        payload: Any,
        sort_lists: bool = True,
        strip_none: bool = False,
    ) -> Any:
        """Normalize a payload for comparison.

        Args:
            payload: The payload to normalize (dict, list, tuple or primitive)
            sort_lists: Whether to sort lists for consistent ordering
            strip_none: Whether to remove None values from dicts

        Returns:
            Normalized payload suitable for comparison. Dicts and lists are
            rebuilt (the input is not mutated); ``None`` and primitives are
            returned unchanged.
        """
        if isinstance(payload, dict):
            return cls._normalize_dict(payload, sort_lists, strip_none)

        if isinstance(payload, (list, tuple)):
            return cls._normalize_list(list(payload), sort_lists, strip_none)

        # None and primitives (str, int, float, bool) are returned as-is
        return payload

    @classmethod
    def _normalize_dict(
        cls,
        data: Dict[str, Any],
        sort_lists: bool,
        strip_none: bool,
    ) -> Dict[str, Any]:
        """Normalize a dictionary, dropping ignored keys and optionally None values."""
        return {
            key: cls.normalize(value, sort_lists, strip_none)
            for key, value in data.items()
            if key not in cls.IGNORE_FIELDS and not (strip_none and value is None)
        }

    @classmethod
    def _normalize_list(
        cls,
        data: List[Any],
        sort_lists: bool,
        strip_none: bool,
    ) -> List[Any]:
        """Normalize every element of a list and sort it when possible."""
        normalized = [cls.normalize(item, sort_lists, strip_none) for item in data]

        if sort_lists and cls._is_sortable_list(normalized):
            try:
                normalized = sorted(normalized, key=cls._sort_key)
            except TypeError:
                # ``_is_sortable_list`` only looks at the first element, so the
                # keys can still be mutually incomparable (e.g. int and str
                # ids, or a dict next to a primitive). Fall back to an ordering
                # that is always defined so both sides of a comparison still
                # end up in the same order.
                normalized = sorted(normalized, key=cls._fallback_sort_key)

        return normalized

    @classmethod
    def _is_sortable_list(cls, items: List[Any]) -> bool:
        """Check if a list can be sorted, judging by its first element only."""
        if not items:
            return False

        first = items[0]

        # Lists of dicts with 'id' field can be sorted by id
        if isinstance(first, dict) and "id" in first:
            return True

        # Lists of primitives can be sorted
        return isinstance(first, (str, int, float, bool))

    @classmethod
    def _sort_key(cls, item: Any) -> Any:
        """Generate a sort key for an item."""
        if isinstance(item, dict):
            # Sort by id if available, otherwise by string representation
            return item.get("id", str(item))
        return item

    @classmethod
    def _fallback_sort_key(cls, item: Any) -> tuple:
        """Generate a sort key that is comparable for any mix of item types."""
        key = cls._sort_key(item)
        # Group by type name first, then order by repr within a type.
        return (type(key).__name__, repr(key))

    @classmethod
    def normalize_for_comparison(
        cls,
        legacy_response: Any,
        v2_response: Any,
    ) -> tuple[Any, Any]:
        """Normalize both legacy and v2 responses for comparison.

        Both responses go through ``normalize`` with the same settings
        (lists sorted, None values kept).

        Args:
            legacy_response: Response from legacy implementation
            v2_response: Response from v2 implementation

        Returns:
            Tuple of (normalized_legacy, normalized_v2)
        """
        norm_legacy = cls.normalize(legacy_response, sort_lists=True, strip_none=False)
        norm_v2 = cls.normalize(v2_response, sort_lists=True, strip_none=False)

        return norm_legacy, norm_v2
+ + Args: + legacy: Legacy response + v2: V2 response + path: Current path in the data structure + max_depth: Maximum recursion depth + + Returns: + List of difference descriptions + """ + if max_depth <= 0: + return [f"{path}: Maximum recursion depth reached"] + + diffs: List[str] = [] + + # Type mismatch + if type(legacy) is not type(v2): + diffs.append( + f"{path}: Type mismatch - legacy={type(legacy).__name__}, " + f"v2={type(v2).__name__}" + ) + return diffs + + # Compare dicts + if isinstance(legacy, dict): + diffs.extend(cls._diff_dicts(legacy, v2, path, max_depth)) + + # Compare lists + elif isinstance(legacy, (list, tuple)): + diffs.extend(cls._diff_lists(legacy, v2, path, max_depth)) + + # Compare primitives + elif legacy != v2: + diffs.append(f"{path}: Value mismatch - legacy={legacy!r}, v2={v2!r}") + + return diffs + + @classmethod + def _diff_dicts( + cls, + legacy: Dict[str, Any], + v2: Dict[str, Any], + path: str, + max_depth: int, + ) -> List[str]: + """Generate diffs for dictionaries.""" + diffs: List[str] = [] + + # Keys only in legacy + legacy_only = set(legacy.keys()) - set(v2.keys()) + if legacy_only: + diffs.append(f"{path}: Keys only in legacy: {sorted(legacy_only)}") + + # Keys only in v2 + v2_only = set(v2.keys()) - set(legacy.keys()) + if v2_only: + diffs.append(f"{path}: Keys only in v2: {sorted(v2_only)}") + + # Compare common keys + common_keys = set(legacy.keys()) & set(v2.keys()) + for key in sorted(common_keys): + key_path = f"{path}.{key}" + diffs.extend(cls.generate_diff(legacy[key], v2[key], key_path, max_depth - 1)) + + return diffs + + @classmethod + def _diff_lists( + cls, + legacy: Union[List[Any], tuple], + v2: Union[List[Any], tuple], + path: str, + max_depth: int, + ) -> List[str]: + """Generate diffs for lists.""" + diffs: List[str] = [] + + # Length mismatch + if len(legacy) != len(v2): + diffs.append(f"{path}: Length mismatch - legacy={len(legacy)}, v2={len(v2)}") + # Still compare up to the shorter length + min_len = 
@dataclass
class RecordedRequest:
    """One recorded API call together with what it produced.

    Stores the call (method name, positional and keyword arguments), its
    outcome (a response, or the text and type name of the exception) and
    the identifiers that give the call its context.
    """

    # Request metadata
    timestamp: str
    method_name: str
    args: tuple = field(default_factory=tuple)
    kwargs: Dict[str, Any] = field(default_factory=dict)

    # Response data
    response: Any = None
    exception: Optional[str] = None
    exception_type: Optional[str] = None

    # Context
    project_id: Optional[str] = None
    asset_id: Optional[str] = None
    label_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the recording as a plain dictionary, one key per attribute."""
        exported = (
            "timestamp",
            "method_name",
            "args",
            "kwargs",
            "response",
            "exception",
            "exception_type",
            "project_id",
            "asset_id",
            "label_id",
        )
        return {name: getattr(self, name) for name in exported}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "RecordedRequest":
        """Rebuild a recording from a dictionary.

        ``timestamp`` and ``method_name`` are required; ``args`` is coerced
        to a tuple; every other key falls back to its empty value.
        """
        optional_names = (
            "response",
            "exception",
            "exception_type",
            "project_id",
            "asset_id",
            "label_id",
        )
        optional_values = {name: data.get(name) for name in optional_names}
        return cls(
            timestamp=data["timestamp"],
            method_name=data["method_name"],
            args=tuple(data.get("args", ())),
            kwargs=data.get("kwargs", {}),
            **optional_values,
        )
+ + Args: + storage_dir: Directory to store recordings (default: ./recordings) + """ + self.storage_dir = Path(storage_dir or "./recordings") + self.storage_dir.mkdir(parents=True, exist_ok=True) + self.recordings: List[RecordedRequest] = [] + + def record( + self, + method_name: str, + args: tuple = (), + kwargs: Optional[Dict[str, Any]] = None, + response: Any = None, + exception: Optional[Exception] = None, + context: Optional[Dict[str, str]] = None, + ) -> RecordedRequest: + """Record a method call and its response. + + Args: + method_name: Name of the method called + args: Positional arguments + kwargs: Keyword arguments + response: Response from the method + exception: Exception raised (if any) + context: Additional context (project_id, asset_id, etc.) + + Returns: + The recorded request object + """ + kwargs = kwargs or {} + context = context or {} + + # Serialize exception if present + exception_str = None + exception_type = None + if exception: + exception_str = str(exception) + exception_type = type(exception).__name__ + + recording = RecordedRequest( + timestamp=datetime.utcnow().isoformat(), + method_name=method_name, + args=args, + kwargs=kwargs, + response=response, + exception=exception_str, + exception_type=exception_type, + project_id=context.get("project_id"), + asset_id=context.get("asset_id"), + label_id=context.get("label_id"), + ) + + self.recordings.append(recording) + return recording + + def save(self, filename: str, format: str = "json") -> Path: + """Save recordings to file. 
+ + Args: + filename: Name of the file (without extension) + format: Format to save ('json' or 'pickle') + + Returns: + Path to the saved file + """ + if format == "json": + filepath = self.storage_dir / f"{filename}.json" + with open(filepath, "w", encoding="utf-8") as f: + json.dump( + [r.to_dict() for r in self.recordings], + f, + indent=2, + default=str, # Handle non-serializable types + ) + elif format == "pickle": + filepath = self.storage_dir / f"{filename}.pkl" + with open(filepath, "wb") as f: + pickle.dump(self.recordings, f) + else: + raise ValueError(f"Unsupported format: {format}") + + return filepath + + def load(self, filepath: Union[str, Path]) -> List[RecordedRequest]: + """Load recordings from file. + + Args: + filepath: Path to the recordings file + + Returns: + List of recorded requests + """ + filepath = Path(filepath) + + if filepath.suffix == ".json": + with open(filepath, encoding="utf-8") as f: + data = json.load(f) + self.recordings = [RecordedRequest.from_dict(r) for r in data] + elif filepath.suffix == ".pkl": + with open(filepath, "rb") as f: + self.recordings = pickle.load(f) + else: + raise ValueError(f"Unsupported file format: {filepath.suffix}") + + return self.recordings + + def clear(self) -> None: + """Clear all recordings.""" + self.recordings.clear() + + def filter_by_method(self, method_name: str) -> List[RecordedRequest]: + """Filter recordings by method name. + + Args: + method_name: Name of the method to filter by + + Returns: + List of recordings matching the method name + """ + return [r for r in self.recordings if r.method_name == method_name] + + def filter_by_context( + self, + project_id: Optional[str] = None, + asset_id: Optional[str] = None, + label_id: Optional[str] = None, + ) -> List[RecordedRequest]: + """Filter recordings by context. 
+ + Args: + project_id: Filter by project ID + asset_id: Filter by asset ID + label_id: Filter by label ID + + Returns: + List of recordings matching the context + """ + filtered = self.recordings + + if project_id is not None: + filtered = [r for r in filtered if r.project_id == project_id] + + if asset_id is not None: + filtered = [r for r in filtered if r.asset_id == asset_id] + + if label_id is not None: + filtered = [r for r in filtered if r.label_id == label_id] + + return filtered + + def get_summary(self) -> Dict[str, Any]: + """Get a summary of recorded requests. + + Returns: + Dictionary with summary statistics + """ + methods = {} + for recording in self.recordings: + method = recording.method_name + if method not in methods: + methods[method] = { + "count": 0, + "success": 0, + "errors": 0, + } + methods[method]["count"] += 1 + if recording.exception: + methods[method]["errors"] += 1 + else: + methods[method]["success"] += 1 + + return { + "total_recordings": len(self.recordings), + "methods": methods, + "time_range": { + "start": min((r.timestamp for r in self.recordings), default=None), + "end": max((r.timestamp for r in self.recordings), default=None), + }, + } diff --git a/tests/equivalence/test_asset_equivalence.py b/tests/equivalence/test_asset_equivalence.py new file mode 100644 index 000000000..c0393be6d --- /dev/null +++ b/tests/equivalence/test_asset_equivalence.py @@ -0,0 +1,131 @@ +"""Equivalence tests for Asset methods. + +This module tests semantic equivalence between legacy and v2 asset methods. 
+""" + +import pytest + +from kili.client import Kili +from tests.equivalence.comparator import EquivalenceAssertion +from tests.equivalence.fixtures import ASSET_TEST_CASES +from tests.equivalence.harness import EquivalenceTestHarness, generate_report + + +@pytest.fixture() +def legacy_client(mocker): + """Create a mocked legacy Kili client.""" + # In real tests, this would be a real client or a comprehensive mock + mock_client = mocker.MagicMock(spec=Kili) + + # Mock count_assets + mock_client.count_assets.return_value = 42 + + # Mock assets (returns list) + mock_client.assets.return_value = [ + { + "id": "asset1", + "externalId": "ext1", + "content": "http://example.com/image1.jpg", + "status": "TODO", + }, + { + "id": "asset2", + "externalId": "ext2", + "content": "http://example.com/image2.jpg", + "status": "LABELED", + }, + ] + + return mock_client + + +@pytest.fixture() +def v2_client(mocker): + """Create a mocked v2 Kili client.""" + # In real tests, this would use the actual v2 implementation + mock_client = mocker.MagicMock() + + # Mock assets namespace + mock_assets = mocker.MagicMock() + mock_assets.count.return_value = 42 + mock_assets.list.return_value = [ + { + "id": "asset1", + "externalId": "ext1", + "content": "http://example.com/image1.jpg", + "status": "TODO", + }, + { + "id": "asset2", + "externalId": "ext2", + "content": "http://example.com/image2.jpg", + "status": "LABELED", + }, + ] + mock_client.assets = mock_assets + + return mock_client + + +@pytest.fixture() +def test_harness(legacy_client, v2_client): # pylint: disable=redefined-outer-name + """Create test harness with clients.""" + return EquivalenceTestHarness( + legacy_client=legacy_client, + v2_client=v2_client, + ) + + +class TestAssetEquivalence: + """Test equivalence of asset methods.""" + + def test_count_assets_basic(self, test_harness): # pylint: disable=redefined-outer-name + """Test count_assets returns same value as assets.count.""" + test_case = next(tc for tc in 
ASSET_TEST_CASES if tc.name == "count_assets_basic") + result = test_harness.run_test_case(test_case) + + diffs = ( + result.comparison_result.differences if result.comparison_result else "No comparison" + ) + assert result.passed, f"Test failed: {diffs}" + + def test_assets_list_basic(self, test_harness): # pylint: disable=redefined-outer-name + """Test assets returns same data as assets.list.""" + test_case = next(tc for tc in ASSET_TEST_CASES if tc.name == "assets_list_basic") + result = test_harness.run_test_case(test_case) + + diffs = ( + result.comparison_result.differences if result.comparison_result else "No comparison" + ) + assert result.passed, f"Test failed: {diffs}" + + @pytest.mark.parametrize("test_case", ASSET_TEST_CASES, ids=lambda tc: tc.name) + def test_all_asset_methods(self, test_harness, test_case): # pylint: disable=redefined-outer-name + """Parametrized test for all asset methods.""" + if test_case.skip: + pytest.skip(test_case.skip_reason) + + result = test_harness.run_test_case(test_case) + + if result.comparison_result: + EquivalenceAssertion.assert_equivalent( + result.comparison_result, + message=f"Asset equivalence test failed: {test_case.name}", + ) + + +def test_asset_suite_report(test_harness): # pylint: disable=redefined-outer-name + """Test generating a report for the full asset test suite.""" + results = test_harness.run_test_suite(ASSET_TEST_CASES) + report = generate_report(results) + + # Print report for debugging + print("\nAsset Equivalence Test Report:") + print(f" Total: {report.total}") + print(f" Passed: {report.passed}") + print(f" Failed: {report.failed}") + print(f" Skipped: {report.skipped}") + print(f" Success Rate: {report.success_rate:.2f}%") + + # Assert overall success + assert report.failed == 0, f"{report.failed} tests failed" diff --git a/tests/equivalence/test_normalizer.py b/tests/equivalence/test_normalizer.py new file mode 100644 index 000000000..f5513f4ce --- /dev/null +++ 
b/tests/equivalence/test_normalizer.py @@ -0,0 +1,140 @@ +"""Unit tests for payload normalizer.""" + +from tests.equivalence.normalizer import DiffGenerator, PayloadNormalizer + + +class TestPayloadNormalizer: + """Test PayloadNormalizer functionality.""" + + def test_normalize_dict_basic(self): + """Test basic dictionary normalization.""" + data = {"id": "123", "name": "test", "__typename": "Asset"} + normalized = PayloadNormalizer.normalize(data) + + # __typename should be removed + assert "__typename" not in normalized + assert normalized["id"] == "123" + assert normalized["name"] == "test" + + def test_normalize_nested_dict(self): + """Test nested dictionary normalization.""" + data = { + "id": "123", + "user": {"id": "456", "email": "test@example.com", "__typename": "User"}, + } + normalized = PayloadNormalizer.normalize(data) + + assert "__typename" not in normalized["user"] + assert normalized["user"]["id"] == "456" + + def test_normalize_list(self): + """Test list normalization and sorting.""" + data = [ + {"id": "2", "name": "b"}, + {"id": "1", "name": "a"}, + {"id": "3", "name": "c"}, + ] + normalized = PayloadNormalizer.normalize(data, sort_lists=True) + + # Should be sorted by id + assert normalized[0]["id"] == "1" + assert normalized[1]["id"] == "2" + assert normalized[2]["id"] == "3" + + def test_normalize_strip_none(self): + """Test stripping None values.""" + data = {"id": "123", "name": None, "value": "test"} + normalized = PayloadNormalizer.normalize(data, strip_none=True) + + assert "name" not in normalized + assert normalized["id"] == "123" + assert normalized["value"] == "test" + + def test_normalize_for_comparison(self): + """Test normalize_for_comparison with legacy and v2 responses.""" + legacy = { + "id": "123", + "labels": [{"id": "2"}, {"id": "1"}], + "__typename": "Asset", + } + v2 = { + "id": "123", + "labels": [{"id": "1"}, {"id": "2"}], + } + + norm_legacy, norm_v2 = PayloadNormalizer.normalize_for_comparison(legacy, v2) + + # Should 
be equal after normalization + assert norm_legacy == norm_v2 + + +class TestDiffGenerator: + """Test DiffGenerator functionality.""" + + def test_generate_diff_equal(self): + """Test diff generation for equal objects.""" + legacy = {"id": "123", "name": "test"} + v2 = {"id": "123", "name": "test"} + + diffs = DiffGenerator.generate_diff(legacy, v2) + + assert len(diffs) == 0 + + def test_generate_diff_value_mismatch(self): + """Test diff generation for value mismatch.""" + legacy = {"id": "123", "name": "test"} + v2 = {"id": "123", "name": "different"} + + diffs = DiffGenerator.generate_diff(legacy, v2) + + assert len(diffs) == 1 + assert "name" in diffs[0] + assert "Value mismatch" in diffs[0] + + def test_generate_diff_missing_keys(self): + """Test diff generation for missing keys.""" + legacy = {"id": "123", "name": "test", "extra": "value"} + v2 = {"id": "123", "name": "test"} + + diffs = DiffGenerator.generate_diff(legacy, v2) + + assert len(diffs) == 1 + assert "only in legacy" in diffs[0].lower() + + def test_generate_diff_type_mismatch(self): + """Test diff generation for type mismatch.""" + legacy = {"id": "123"} + v2 = ["123"] + + diffs = DiffGenerator.generate_diff(legacy, v2) + + assert len(diffs) == 1 + assert "Type mismatch" in diffs[0] + + def test_generate_diff_list_length(self): + """Test diff generation for list length mismatch.""" + legacy = [1, 2, 3] + v2 = [1, 2] + + diffs = DiffGenerator.generate_diff(legacy, v2) + + assert len(diffs) == 1 + assert "Length mismatch" in diffs[0] + + def test_generate_diff_nested_objects(self): + """Test diff generation for nested objects.""" + legacy = { + "id": "123", + "user": {"id": "456", "name": "Alice"}, + } + v2 = { + "id": "123", + "user": {"id": "456", "name": "Bob"}, + } + + diffs = DiffGenerator.generate_diff(legacy, v2) + + assert len(diffs) == 1 + assert "user.name" in diffs[0] + assert "Alice" in diffs[0] + assert "Bob" in diffs[0] diff --git a/tests/unit/domain_api/test_assets.py 
b/tests/unit/domain_api/test_assets.py index adda280e1..e7fe3bb07 100644 --- a/tests/unit/domain_api/test_assets.py +++ b/tests/unit/domain_api/test_assets.py @@ -123,7 +123,7 @@ def test_list_assets_generator(self, assets_namespace): assert hasattr(result, "__iter__") assets_list = list(result) assert len(assets_list) == 2 - assert assets_list[0]["id"] == "asset1" + assert assets_list[0].id == "asset1" mock_asset_use_cases.assert_called_once_with(assets_namespace.gateway) mock_project_use_cases.assert_called_once_with(assets_namespace.gateway) @@ -151,7 +151,7 @@ def test_list_assets_as_list(self, assets_namespace): assert isinstance(result, list) assert len(result) == 2 - assert result[0]["id"] == "asset1" + assert result[0].id == "asset1" def test_count_assets(self, assets_namespace): """Test count method.""" @@ -283,7 +283,8 @@ def test_create_assets(self, assets_namespace, mock_client): external_id_array=["ext1"], ) - assert result == expected_result + assert result.id == "project_123" + assert result.asset_ids == ["asset1", "asset2"] mock_client.append_many_to_dataset.assert_called_once_with( project_id="project_123", content_array=["https://example.com/image.png"], @@ -305,7 +306,7 @@ def test_delete_assets(self, assets_namespace, mock_client): result = assets_namespace.delete(asset_ids=["asset1", "asset2"]) - assert result == expected_result + assert result.id == "project_123" mock_client.delete_many_from_dataset.assert_called_once_with( asset_ids=["asset1", "asset2"], external_ids=None, project_id=None ) @@ -321,7 +322,7 @@ def test_update_assets(self, assets_namespace, mock_client): json_metadatas=[{"key": "value1"}, {"key": "value2"}], ) - assert result == expected_result + assert result.ids == ["asset1", "asset2"] mock_client.update_properties_in_assets.assert_called_once_with( asset_ids=["asset1", "asset2"], external_ids=None, @@ -383,7 +384,7 @@ def test_assign_delegates_to_client(self, workflow_namespace, mock_client): asset_ids=["asset1", "asset2"], 
to_be_labeled_by_array=[["user1"], ["user2"]] ) - assert result == expected_result + assert result.ids == ["asset1", "asset2"] mock_client.assign_assets_to_labelers.assert_called_once_with( asset_ids=["asset1", "asset2"], external_ids=None, @@ -430,7 +431,8 @@ def test_invalidate_delegates_to_client(self, workflow_step_namespace, mock_clie result = workflow_step_namespace.invalidate(asset_ids=["asset1", "asset2"]) - assert result == expected_result + assert result.id == "project_123" + assert result.asset_ids == ["asset1", "asset2"] mock_client.send_back_to_queue.assert_called_once_with( asset_ids=["asset1", "asset2"], external_ids=None, project_id=None ) @@ -442,7 +444,8 @@ def test_next_delegates_to_client(self, workflow_step_namespace, mock_client): result = workflow_step_namespace.next(asset_ids=["asset1", "asset2"]) - assert result == expected_result + assert result.id == "project_123" + assert result.asset_ids == ["asset1", "asset2"] mock_client.add_to_review.assert_called_once_with( asset_ids=["asset1", "asset2"], external_ids=None, project_id=None ) @@ -487,7 +490,7 @@ def test_update_delegates_to_client(self, external_ids_namespace, mock_client): new_external_ids=["new_ext1", "new_ext2"], asset_ids=["asset1", "asset2"] ) - assert result == expected_result + assert result.ids == ["asset1", "asset2"] mock_client.change_asset_external_ids.assert_called_once_with( new_external_ids=["new_ext1", "new_ext2"], asset_ids=["asset1", "asset2"], @@ -538,7 +541,7 @@ def test_add_delegates_to_client(self, metadata_namespace, mock_client): asset_ids=["asset1", "asset2"], ) - assert result == expected_result + assert result.ids == ["asset1", "asset2"] mock_client.add_metadata.assert_called_once_with( json_metadata=[{"key1": "value1"}, {"key2": "value2"}], project_id="project_123", @@ -557,7 +560,7 @@ def test_set_delegates_to_client(self, metadata_namespace, mock_client): asset_ids=["asset1", "asset2"], ) - assert result == expected_result + assert result.ids == 
["asset1", "asset2"] mock_client.set_metadata.assert_called_once_with( json_metadata=[{"key1": "value1"}, {"key2": "value2"}], project_id="project_123", diff --git a/tests/unit/domain_api/test_assets_integration.py b/tests/unit/domain_api/test_assets_integration.py index e73572693..bd649d8ea 100644 --- a/tests/unit/domain_api/test_assets_integration.py +++ b/tests/unit/domain_api/test_assets_integration.py @@ -79,17 +79,17 @@ def test_workflow_operations_delegation(self, mock_kili_client): # Test workflow assign result = assets_ns.workflow.assign(asset_ids=["asset1"], to_be_labeled_by_array=[["user1"]]) - assert result[0]["id"] == "asset1" + assert result.ids[0] == "asset1" mock_kili_client.legacy_client.assign_assets_to_labelers.assert_called_once() # Test workflow step invalidate result = assets_ns.workflow.step.invalidate(asset_ids=["asset1"]) - assert result["id"] == "project_123" + assert result.id == "project_123" mock_kili_client.legacy_client.send_back_to_queue.assert_called_once() # Test workflow step next result = assets_ns.workflow.step.next(asset_ids=["asset1"]) - assert result["id"] == "project_123" + assert result.id == "project_123" mock_kili_client.legacy_client.add_to_review.assert_called_once() def test_metadata_operations_delegation(self, mock_kili_client): @@ -104,14 +104,14 @@ def test_metadata_operations_delegation(self, mock_kili_client): result = assets_ns.metadata.add( json_metadata=[{"key": "value"}], project_id="project_123", asset_ids=["asset1"] ) - assert result[0]["id"] == "asset1" + assert result.ids[0] == "asset1" mock_kili_client.legacy_client.add_metadata.assert_called_once() # Test metadata set result = assets_ns.metadata.set( json_metadata=[{"key": "value"}], project_id="project_123", asset_ids=["asset1"] ) - assert result[0]["id"] == "asset1" + assert result.ids[0] == "asset1" mock_kili_client.legacy_client.set_metadata.assert_called_once() def test_external_ids_operations_delegation(self, mock_kili_client): @@ -125,7 +125,7 @@ 
def test_external_ids_operations_delegation(self, mock_kili_client): # Test external IDs update result = assets_ns.external_ids.update(new_external_ids=["new_ext1"], asset_ids=["asset1"]) - assert result[0]["id"] == "asset1" + assert result.ids[0] == "asset1" mock_kili_client.legacy_client.change_asset_external_ids.assert_called_once() @patch("kili.domain_api.assets.AssetUseCases") @@ -144,7 +144,7 @@ def test_list_and_count_use_cases_integration(self, mock_asset_use_cases, mock_k result_gen = assets_ns.list(project_id="project_123") assets_list = list(result_gen) assert len(assets_list) == 1 - assert assets_list[0]["id"] == "asset1" + assert assets_list[0].id == "asset1" # Test count assets count = assets_ns.count(project_id="project_123") diff --git a/tests/unit/domain_api/test_connections.py b/tests/unit/domain_api/test_connections.py index 3e7cf92d7..c7a5349a8 100644 --- a/tests/unit/domain_api/test_connections.py +++ b/tests/unit/domain_api/test_connections.py @@ -58,7 +58,9 @@ def test_list_calls_legacy_method(self, mock_legacy_method, connections_namespac disable_tqdm=None, as_generator=False, ) - assert result == [{"id": "conn_123", "projectId": "proj_456"}] + assert len(result) == 1 + assert result[0].id == "conn_123" + assert result[0].project_id == "proj_456" def test_list_parameter_validation(self, connections_namespace): """Test that list validates required parameters.""" @@ -92,7 +94,8 @@ def test_add_calls_legacy_method(self, mock_legacy_method, connections_namespace include=None, exclude=None, ) - assert result == {"id": "conn_789"} + # Check the result is a ConnectionView with correct ID + assert result.id == "conn_789" def test_add_input_validation(self, connections_namespace): """Test that add() validates input parameters.""" @@ -132,7 +135,7 @@ def test_add_error_handling(self, mock_legacy_method, connections_namespace): ) def test_sync_calls_legacy_method(self, mock_legacy_method, connections_namespace): """Test that sync() calls the legacy 
synchronize_cloud_storage_connection method.""" - mock_legacy_method.return_value = {"numberOfAssets": 42, "projectId": "proj_123"} + mock_legacy_method.return_value = {"id": "conn_789"} result = connections_namespace.sync(connection_id="conn_789", dry_run=True) @@ -142,7 +145,8 @@ def test_sync_calls_legacy_method(self, mock_legacy_method, connections_namespac delete_extraneous_files=False, dry_run=True, ) - assert result == {"numberOfAssets": 42, "projectId": "proj_123"} + # Check the result is an IdResponse with correct ID + assert result.id == "conn_789" def test_sync_input_validation(self, connections_namespace): """Test that sync() validates input parameters.""" diff --git a/tests/unit/domain_api/test_integrations.py b/tests/unit/domain_api/test_integrations.py new file mode 100644 index 000000000..ccfdb177a --- /dev/null +++ b/tests/unit/domain_api/test_integrations.py @@ -0,0 +1,157 @@ +"""Unit tests for IntegrationsNamespace.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from kili.adapters.kili_api_gateway.kili_api_gateway import KiliAPIGateway +from kili.client_domain import Kili +from kili.domain_api.integrations import IntegrationsNamespace +from kili.domain_v2.integration import IntegrationView + + +class TestIntegrationsNamespace: + """Unit tests for IntegrationsNamespace.""" + + @pytest.fixture() + def mock_kili_client(self): + """Create a mock Kili client with proper structure.""" + with patch("kili.client.GraphQLClient"), patch("kili.client.HttpClient"), patch( + "kili.client.KiliAPIGateway" + ) as mock_gateway_class, patch("kili.client.ApiKeyUseCases"), patch( + "kili.client.is_api_key_valid" + ), patch.dict("os.environ", {"KILI_SDK_SKIP_CHECKS": "1"}): + mock_gateway = MagicMock(spec=KiliAPIGateway) + mock_gateway_class.return_value = mock_gateway + + client = Kili(api_key="fake_key") + return client + + def test_integrations_namespace_lazy_loading(self, mock_kili_client): + """Test that integrations namespace is lazily 
loaded and cached.""" + # First access should create the namespace + integrations_ns1 = mock_kili_client.integrations + assert isinstance(integrations_ns1, IntegrationsNamespace) + + # Second access should return the same instance (cached) + integrations_ns2 = mock_kili_client.integrations + assert integrations_ns1 is integrations_ns2 + + @patch("kili.domain_api.integrations.CloudStorageClientMethods") + def test_create_returns_integration_view(self, mock_cloud_storage_methods, mock_kili_client): + """Test that create() returns an IntegrationView object.""" + # Mock the legacy method to return a dict (with camelCase fields as returned by API) + mock_cloud_storage_methods.create_cloud_storage_integration.return_value = { + "id": "integration_123", + "name": "Test Integration", + "platform": "AWS", + "status": "CONNECTED", + "organizationId": "org_456", + } + + integrations_ns = mock_kili_client.integrations + + # Call create + result = integrations_ns.create( + platform="AWS", + name="Test Integration", + s3_bucket_name="test-bucket", + s3_region="us-east-1", + s3_access_key="AKIAIOSFODNN7EXAMPLE", + s3_secret_key="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + ) + + # Verify result is an IntegrationView + assert isinstance(result, IntegrationView) + assert result.id == "integration_123" + assert result.name == "Test Integration" + assert result.platform == "AWS" + assert result.status == "CONNECTED" + assert result.organization_id == "org_456" + + # Verify the legacy method was called + mock_cloud_storage_methods.create_cloud_storage_integration.assert_called_once() + + @patch("kili.domain_api.integrations.CloudStorageClientMethods") + def test_update_returns_integration_view(self, mock_cloud_storage_methods, mock_kili_client): + """Test that update() returns an IntegrationView object.""" + # Mock the legacy method to return a dict (with camelCase fields as returned by API) + mock_cloud_storage_methods.update_cloud_storage_integration.return_value = { + "id": 
"integration_123", + "name": "Updated Integration", + "platform": "AWS", + "status": "CONNECTED", + "organizationId": "org_456", + } + + integrations_ns = mock_kili_client.integrations + + # Call update + result = integrations_ns.update( + integration_id="integration_123", + name="Updated Integration", + allowed_paths=["/data/training", "/data/validation"], + ) + + # Verify result is an IntegrationView + assert isinstance(result, IntegrationView) + assert result.id == "integration_123" + assert result.name == "Updated Integration" + assert result.platform == "AWS" + assert result.status == "CONNECTED" + assert result.organization_id == "org_456" + + # Verify the legacy method was called + mock_cloud_storage_methods.update_cloud_storage_integration.assert_called_once() + + @patch("kili.domain_api.integrations.CloudStorageClientMethods") + def test_list_returns_integration_views(self, mock_cloud_storage_methods, mock_kili_client): + """Test that list() returns IntegrationView objects.""" + # Mock the legacy method to return a list of dicts (with camelCase fields as returned by API) + mock_cloud_storage_methods.cloud_storage_integrations.return_value = [ + { + "id": "integration_1", + "name": "Integration 1", + "platform": "AWS", + "status": "CONNECTED", + "organizationId": "org_123", + }, + { + "id": "integration_2", + "name": "Integration 2", + "platform": "AZURE", + "status": "CONNECTED", + "organizationId": "org_123", + }, + ] + + integrations_ns = mock_kili_client.integrations + + # Call list + result = integrations_ns.list(as_generator=False) + + # Verify result is a list of IntegrationView objects + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(item, IntegrationView) for item in result) + assert result[0].id == "integration_1" + assert result[0].platform == "AWS" + assert result[1].id == "integration_2" + assert result[1].platform == "AZURE" + + # Verify the legacy method was called + 
mock_cloud_storage_methods.cloud_storage_integrations.assert_called_once() + + def test_namespace_inheritance(self, mock_kili_client): + """Test that IntegrationsNamespace properly inherits from DomainNamespace.""" + integrations_ns = mock_kili_client.integrations + + # Test DomainNamespace properties + assert hasattr(integrations_ns, "client") + assert hasattr(integrations_ns, "gateway") + assert hasattr(integrations_ns, "domain_name") + assert integrations_ns.domain_name == "integrations" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/unit/domain_v2/__init__.py b/tests/unit/domain_v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/domain_v2/test_adapters.py b/tests/unit/domain_v2/test_adapters.py new file mode 100644 index 000000000..8dd8183e4 --- /dev/null +++ b/tests/unit/domain_v2/test_adapters.py @@ -0,0 +1,298 @@ +"""Unit tests for DataFrame adapters and validators.""" + +from typing import cast + +import pytest + +try: + import pandas as pd + + PANDAS_AVAILABLE = True +except ImportError: + PANDAS_AVAILABLE = False + pd = None # type: ignore[assignment] + +from kili.domain_v2.adapters import ContractValidator, DataFrameAdapter +from kili.domain_v2.asset import AssetContract, AssetView +from kili.domain_v2.label import LabelContract + + +@pytest.mark.skipif(not PANDAS_AVAILABLE, reason="pandas not installed") +class TestDataFrameAdapter: + """Test suite for DataFrameAdapter.""" + + def test_to_dataframe_with_assets(self): + """Test converting asset contracts to DataFrame.""" + assets = [ + cast(AssetContract, {"id": "asset-1", "externalId": "ext-1", "content": "url1"}), + cast(AssetContract, {"id": "asset-2", "externalId": "ext-2", "content": "url2"}), + ] + + adapter = DataFrameAdapter() + df = adapter.to_dataframe(assets, AssetContract, validate=False) + + assert pd is not None + assert isinstance(df, pd.DataFrame) + assert len(df) == 2 + assert list(df["id"]) == ["asset-1", "asset-2"] + 
assert list(df["externalId"]) == ["ext-1", "ext-2"] + + def test_to_dataframe_empty_list(self): + """Test converting empty list to DataFrame.""" + adapter = DataFrameAdapter() + df = adapter.to_dataframe([], AssetContract, validate=False) + + assert pd is not None + assert isinstance(df, pd.DataFrame) + assert len(df) == 0 + + def test_from_dataframe_with_assets(self): + """Test converting DataFrame to asset contracts.""" + assert pd is not None + df = pd.DataFrame( + [ + {"id": "asset-1", "externalId": "ext-1"}, + {"id": "asset-2", "externalId": "ext-2"}, + ] + ) + + adapter = DataFrameAdapter() + contracts = adapter.from_dataframe(df, AssetContract, validate=False) + + assert len(contracts) == 2 + assert contracts[0].get("id") == "asset-1" + assert contracts[1].get("id") == "asset-2" + + def test_roundtrip_conversion(self): + """Test converting to DataFrame and back.""" + original_assets = [ + cast( + AssetContract, + { + "id": "asset-1", + "externalId": "ext-1", + "content": "url1", + "isHoneypot": False, + }, + ), + cast( + AssetContract, + { + "id": "asset-2", + "externalId": "ext-2", + "content": "url2", + "isHoneypot": True, + }, + ), + ] + + adapter = DataFrameAdapter() + + # To DataFrame + df = adapter.to_dataframe(original_assets, AssetContract, validate=False) + + # Back to contracts + result = adapter.from_dataframe(df, AssetContract, validate=False) + + assert len(result) == 2 + assert result[0].get("id") == original_assets[0].get("id") + assert result[0].get("externalId") == original_assets[0].get("externalId") + assert result[1].get("isHoneypot") == original_assets[1].get("isHoneypot") + + def test_wrap_contracts_with_asset_view(self): + """Test wrapping contracts in AssetView.""" + contracts = [ + cast(AssetContract, {"id": "asset-1", "externalId": "ext-1"}), + cast(AssetContract, {"id": "asset-2", "externalId": "ext-2"}), + ] + + adapter = DataFrameAdapter() + views = adapter.wrap_contracts(contracts, AssetView) + + assert len(views) == 2 + assert 
isinstance(views[0], AssetView) + assert isinstance(views[1], AssetView) + assert views[0].id == "asset-1" + assert views[1].display_name == "ext-2" + + def test_unwrap_views(self): + """Test unwrapping views back to dictionaries.""" + contracts = [ + cast(AssetContract, {"id": "asset-1", "externalId": "ext-1"}), + cast(AssetContract, {"id": "asset-2", "externalId": "ext-2"}), + ] + + adapter = DataFrameAdapter() + views = adapter.wrap_contracts(contracts, AssetView) + unwrapped = adapter.unwrap_views(views) + + assert len(unwrapped) == 2 + assert unwrapped[0].get("id") == "asset-1" + assert unwrapped[1].get("externalId") == "ext-2" + + def test_to_dataframe_does_not_mutate_original(self): + """Test that DataFrame conversion doesn't mutate original data.""" + original_assets = [ + cast(AssetContract, {"id": "asset-1", "externalId": "ext-1"}), + ] + + adapter = DataFrameAdapter() + df = adapter.to_dataframe(original_assets, AssetContract, validate=False) + + # Modify DataFrame + df.loc[0, "id"] = "modified-id" + + # Original should be unchanged + assert original_assets[0].get("id") == "asset-1" + + def test_to_dataframe_with_nested_data(self): + """Test converting contracts with nested structures.""" + labels = [ + cast( + LabelContract, + { + "id": "label-1", + "author": {"id": "user-1", "email": "user@example.com"}, + "jsonResponse": {"job": "value"}, + }, + ), + ] + + adapter = DataFrameAdapter() + df = adapter.to_dataframe(labels, LabelContract, validate=False) + + assert pd is not None + assert isinstance(df, pd.DataFrame) + assert len(df) == 1 + assert df.loc[0, "id"] == "label-1" + assert isinstance(df.loc[0, "author"], dict) + assert df.loc[0, "author"]["email"] == "user@example.com" + + +class TestContractValidator: + """Test suite for ContractValidator.""" + + def test_validate_single_valid_asset(self): + """Test validating a single valid asset.""" + asset_data = { + "id": "asset-1", + "externalId": "ext-1", + } + + validator = ContractValidator() + result 
= validator.validate_single(asset_data, AssetContract) + + assert isinstance(result, dict) + assert result.get("id") == "asset-1" + + def test_validate_batch_all_valid(self): + """Test batch validation with all valid contracts.""" + assets = [ + {"id": "asset-1", "externalId": "ext-1"}, + {"id": "asset-2", "externalId": "ext-2"}, + ] + + validator = ContractValidator() + valid, errors = validator.validate_batch(assets, AssetContract) + + assert len(valid) == 2 + assert len(errors) == 0 + assert valid[0].get("id") == "asset-1" + assert valid[1].get("id") == "asset-2" + + def test_validate_batch_with_errors(self): + """Test batch validation with some invalid contracts.""" + # Mix of valid and potentially invalid data + mixed_data = [ + {"id": "asset-1", "externalId": "ext-1"}, # Valid + {"id": 123}, # Potentially invalid (id should be string) + {"id": "asset-3", "externalId": "ext-3"}, # Valid + ] + + validator = ContractValidator() + valid, errors = validator.validate_batch(mixed_data, AssetContract) + + # With total=False, this might still be valid + # The test shows error handling capability + assert len(valid) + len(errors) == len(mixed_data) + + def test_validate_batch_empty_list(self): + """Test batch validation with empty list.""" + validator = ContractValidator() + valid, errors = validator.validate_batch([], AssetContract) + + assert len(valid) == 0 + assert len(errors) == 0 + + def test_validate_single_with_label_contract(self): + """Test validating a label contract.""" + label_data = { + "id": "label-1", + "author": {"email": "user@example.com"}, + "jsonResponse": {}, + } + + validator = ContractValidator() + result = validator.validate_single(label_data, LabelContract) + + assert isinstance(result, dict) + assert result.get("id") == "label-1" + + def test_validate_batch_labels(self): + """Test batch validation with label contracts.""" + labels = [ + {"id": "label-1", "jsonResponse": {}}, + {"id": "label-2", "jsonResponse": {"job": "value"}}, + ] + + 
validator = ContractValidator() + valid, errors = validator.validate_batch(labels, LabelContract) + + assert len(valid) == 2 + assert len(errors) == 0 + + +@pytest.mark.skipif(not PANDAS_AVAILABLE, reason="pandas not installed") +class TestDataFrameAdapterPerformance: + """Performance-related tests for DataFrame adapter.""" + + def test_large_dataset_conversion(self): + """Test converting a larger dataset to/from DataFrame.""" + # Create 1000 assets + assets = [ + cast( + AssetContract, + { + "id": f"asset-{i}", + "externalId": f"ext-{i}", + "content": f"url-{i}", + }, + ) + for i in range(1000) + ] + + adapter = DataFrameAdapter() + + # Convert to DataFrame + df = adapter.to_dataframe(assets, AssetContract, validate=False) + assert len(df) == 1000 + + # Convert back + result = adapter.from_dataframe(df, AssetContract, validate=False) + assert len(result) == 1000 + assert result[0].get("id") == "asset-0" + assert result[999].get("id") == "asset-999" + + def test_wrap_large_dataset(self): + """Test wrapping a large dataset in views.""" + assets = [ + cast(AssetContract, {"id": f"asset-{i}", "externalId": f"ext-{i}"}) for i in range(1000) + ] + + adapter = DataFrameAdapter() + views = adapter.wrap_contracts(assets, AssetView) + + assert len(views) == 1000 + assert all(isinstance(v, AssetView) for v in views) + assert views[0].id == "asset-0" + assert views[999].id == "asset-999" diff --git a/tests/unit/domain_v2/test_asset.py b/tests/unit/domain_v2/test_asset.py new file mode 100644 index 000000000..13567eca0 --- /dev/null +++ b/tests/unit/domain_v2/test_asset.py @@ -0,0 +1,221 @@ +"""Unit tests for Asset domain contracts.""" + +from typing import cast + +from kili.domain_v2.asset import AssetContract, AssetView, validate_asset + + +class TestAssetContract: + """Test suite for AssetContract.""" + + def test_validate_asset_with_valid_data(self): + """Test validating a valid asset contract.""" + asset_data = { + "id": "asset-123", + "externalId": "ext-123", + 
"content": "https://example.com/image.jpg", + "jsonMetadata": {"key": "value"}, + "labels": [], + "status": "TODO", + "isHoneypot": False, + "skipped": False, + "createdAt": "2024-01-01T00:00:00Z", + } + + # Should not raise + result = validate_asset(asset_data) + assert result == asset_data + + def test_validate_asset_with_partial_data(self): + """Test validating an asset with only some fields.""" + asset_data = { + "id": "asset-123", + "externalId": "ext-123", + } + + # Should not raise since TypedDict has total=False + result = validate_asset(asset_data) + assert result == asset_data + + def test_validate_asset_with_nested_labels(self): + """Test validating an asset with nested label data.""" + asset_data = { + "id": "asset-123", + "externalId": "ext-123", + "labels": [ + { + "id": "label-1", + "author": {"id": "user-1", "email": "user@example.com"}, + "jsonResponse": {"job": "value"}, + "createdAt": "2024-01-01T00:00:00Z", + } + ], + "latestLabel": { + "id": "label-1", + "author": {"id": "user-1", "email": "user@example.com"}, + "jsonResponse": {"job": "value"}, + }, + } + + result = validate_asset(asset_data) + assert result == asset_data + assert len(result.get("labels", [])) == 1 + + def test_validate_asset_with_current_step(self): + """Test validating an asset with workflow v2 current step.""" + asset_data = { + "id": "asset-123", + "currentStep": { + "name": "Labeling", + "status": "TO_DO", + }, + } + + result = validate_asset(asset_data) + assert result == asset_data + current_step = result.get("currentStep") + assert current_step is not None + assert current_step.get("name") == "Labeling" + + +class TestAssetView: + """Test suite for AssetView wrapper.""" + + def test_asset_view_basic_properties(self): + """Test basic property access on AssetView.""" + asset_data = cast( + AssetContract, + { + "id": "asset-123", + "externalId": "ext-123", + "content": "https://example.com/image.jpg", + "isHoneypot": True, + "skipped": False, + }, + ) + + view = 
AssetView(asset_data) + + assert view.id == "asset-123" + assert view.external_id == "ext-123" + assert view.content == "https://example.com/image.jpg" + assert view.is_honeypot is True + assert view.skipped is False + + def test_asset_view_display_name(self): + """Test display name property.""" + # With external ID + asset_data = cast(AssetContract, {"id": "asset-123", "externalId": "ext-123"}) + view = AssetView(asset_data) + assert view.display_name == "ext-123" + + # Without external ID + asset_data = cast(AssetContract, {"id": "asset-123"}) + view = AssetView(asset_data) + assert view.display_name == "asset-123" + + def test_asset_view_labels(self): + """Test label-related properties.""" + asset_data = cast( + AssetContract, + { + "id": "asset-123", + "labels": [ + {"id": "label-1", "jsonResponse": {}}, + {"id": "label-2", "jsonResponse": {}}, + ], + }, + ) + + view = AssetView(asset_data) + + assert view.has_labels is True + assert view.label_count == 2 + assert len(view.labels) == 2 + + def test_asset_view_no_labels(self): + """Test asset view with no labels.""" + asset_data = cast(AssetContract, {"id": "asset-123"}) + view = AssetView(asset_data) + + assert view.has_labels is False + assert view.label_count == 0 + assert view.labels == [] + + def test_asset_view_metadata(self): + """Test metadata property access.""" + asset_data = cast( + AssetContract, + { + "id": "asset-123", + "jsonMetadata": {"custom_field": "custom_value", "priority": 1}, + }, + ) + + view = AssetView(asset_data) + + assert view.metadata is not None + assert view.metadata["custom_field"] == "custom_value" + assert view.metadata["priority"] == 1 + + def test_asset_view_current_step(self): + """Test current step property for workflow v2.""" + asset_data = cast( + AssetContract, + { + "id": "asset-123", + "currentStep": { + "name": "Review", + "status": "DONE", + }, + }, + ) + + view = AssetView(asset_data) + + assert view.current_step is not None + assert view.current_step.get("name") == 
"Review" + assert view.current_step.get("status") == "DONE" + + def test_asset_view_status(self): + """Test status property for workflow v1.""" + asset_data = cast( + AssetContract, + { + "id": "asset-123", + "status": "LABELED", + }, + ) + + view = AssetView(asset_data) + assert view.status == "LABELED" + + def test_asset_view_to_dict(self): + """Test converting view back to dictionary.""" + asset_data = cast( + AssetContract, + { + "id": "asset-123", + "externalId": "ext-123", + "content": "test", + }, + ) + + view = AssetView(asset_data) + result = view.to_dict() + + assert result == asset_data + assert result is asset_data # Should be the same object + + def test_asset_view_missing_fields(self): + """Test accessing missing fields returns appropriate defaults.""" + asset_data = cast(AssetContract, {"id": "asset-123"}) + view = AssetView(asset_data) + + assert view.external_id == "" + assert view.content == "" + assert view.metadata is None + assert view.latest_label is None + assert view.status is None + assert view.current_step is None + assert view.created_at is None diff --git a/tests/unit/domain_v2/test_label.py b/tests/unit/domain_v2/test_label.py new file mode 100644 index 000000000..d2c50448e --- /dev/null +++ b/tests/unit/domain_v2/test_label.py @@ -0,0 +1,310 @@ +"""Unit tests for Label domain contracts.""" + +from typing import cast + +from kili.domain_v2.label import ( + LabelContract, + LabelView, + filter_labels_by_type, + sort_labels_by_created_at, + validate_label, +) + + +class TestLabelContract: + """Test suite for LabelContract.""" + + def test_validate_label_with_valid_data(self): + """Test validating a valid label contract.""" + label_data = { + "id": "label-123", + "author": {"id": "user-1", "email": "user@example.com", "name": "John Doe"}, + "jsonResponse": {"job1": {"categories": [{"name": "CAT_A"}]}}, + "createdAt": "2024-01-01T00:00:00Z", + "labelType": "DEFAULT", + "isLatestLabelForUser": True, + } + + result = 
validate_label(label_data) + assert result == label_data + + def test_validate_label_with_partial_data(self): + """Test validating a label with only some fields.""" + label_data = { + "id": "label-123", + "jsonResponse": {}, + } + + result = validate_label(label_data) + assert result == label_data + + def test_validate_label_with_prediction_data(self): + """Test validating a prediction label.""" + label_data = { + "id": "label-123", + "labelType": "PREDICTION", + "modelName": "model-v1", + "jsonResponse": {"predictions": []}, + } + + result = validate_label(label_data) + assert result == label_data + assert result.get("labelType") == "PREDICTION" + + +class TestLabelView: + """Test suite for LabelView wrapper.""" + + def test_label_view_basic_properties(self): + """Test basic property access on LabelView.""" + label_data = cast( + LabelContract, + { + "id": "label-123", + "author": {"id": "user-1", "email": "user@example.com"}, + "jsonResponse": {"job": "value"}, + "createdAt": "2024-01-01T00:00:00Z", + "labelType": "DEFAULT", + }, + ) + + view = LabelView(label_data) + + assert view.id == "label-123" + assert view.author_email == "user@example.com" + assert view.author_id == "user-1" + assert view.created_at == "2024-01-01T00:00:00Z" + assert view.label_type == "DEFAULT" + + def test_label_view_author_properties(self): + """Test author-related properties.""" + label_data = cast( + LabelContract, + { + "id": "label-123", + "author": { + "id": "user-1", + "email": "user@example.com", + "name": "John Doe", + }, + }, + ) + + view = LabelView(label_data) + + assert view.author is not None + assert view.author.get("email") == "user@example.com" + assert view.author.get("name") == "John Doe" + assert view.author_email == "user@example.com" + assert view.author_id == "user-1" + + def test_label_view_missing_author(self): + """Test label view with no author.""" + label_data = cast(LabelContract, {"id": "label-123"}) + view = LabelView(label_data) + + assert view.author is 
None + assert view.author_email == "" + assert view.author_id == "" + + def test_label_view_display_name(self): + """Test display name property.""" + # With author email + label_data = cast( + LabelContract, + { + "id": "label-123", + "author": {"email": "user@example.com"}, + }, + ) + view = LabelView(label_data) + assert view.display_name == "user@example.com" + + # Without author + label_data = cast(LabelContract, {"id": "label-123"}) + view = LabelView(label_data) + assert view.display_name == "label-123" + + def test_label_view_is_prediction(self): + """Test is_prediction property.""" + # Prediction type + label_data = cast(LabelContract, {"id": "label-123", "labelType": "PREDICTION"}) + view = LabelView(label_data) + assert view.is_prediction is True + + # Inference type + label_data = cast(LabelContract, {"id": "label-123", "labelType": "INFERENCE"}) + view = LabelView(label_data) + assert view.is_prediction is True + + # Default type + label_data = cast(LabelContract, {"id": "label-123", "labelType": "DEFAULT"}) + view = LabelView(label_data) + assert view.is_prediction is False + + def test_label_view_is_review(self): + """Test is_review property.""" + # Review type + label_data = cast(LabelContract, {"id": "label-123", "labelType": "REVIEW"}) + view = LabelView(label_data) + assert view.is_review is True + + # Default type + label_data = cast(LabelContract, {"id": "label-123", "labelType": "DEFAULT"}) + view = LabelView(label_data) + assert view.is_review is False + + def test_label_view_quality_marks(self): + """Test quality mark properties.""" + label_data = cast( + LabelContract, + { + "id": "label-123", + "consensusMark": 0.95, + "honeypotMark": 0.87, + }, + ) + + view = LabelView(label_data) + + assert view.consensus_mark == 0.95 + assert view.honeypot_mark == 0.87 + + def test_label_view_timing_properties(self): + """Test timing-related properties.""" + label_data = cast( + LabelContract, + { + "id": "label-123", + "secondsToLabel": 120, + 
"createdAt": "2024-01-01T00:00:00Z", + "updatedAt": "2024-01-01T00:05:00Z", + }, + ) + + view = LabelView(label_data) + + assert view.seconds_to_label == 120 + assert view.created_at == "2024-01-01T00:00:00Z" + assert view.updated_at == "2024-01-01T00:05:00Z" + + def test_label_view_json_response(self): + """Test JSON response property.""" + label_data = cast( + LabelContract, + { + "id": "label-123", + "jsonResponse": { + "JOB_1": {"categories": [{"name": "CAT_A"}]}, + "JOB_2": {"text": "Some text"}, + }, + }, + ) + + view = LabelView(label_data) + + assert "JOB_1" in view.json_response + assert "JOB_2" in view.json_response + assert view.json_response["JOB_1"]["categories"][0]["name"] == "CAT_A" + + def test_label_view_to_dict(self): + """Test converting view back to dictionary.""" + label_data = cast( + LabelContract, + { + "id": "label-123", + "jsonResponse": {}, + "labelType": "DEFAULT", + }, + ) + + view = LabelView(label_data) + result = view.to_dict() + + assert result == label_data + assert result is label_data + + +class TestLabelHelpers: + """Test suite for label helper functions.""" + + def test_sort_labels_by_created_at_ascending(self): + """Test sorting labels by creation time in ascending order.""" + labels = [ + cast(LabelContract, {"id": "label-3", "createdAt": "2024-01-03T00:00:00Z"}), + cast(LabelContract, {"id": "label-1", "createdAt": "2024-01-01T00:00:00Z"}), + cast(LabelContract, {"id": "label-2", "createdAt": "2024-01-02T00:00:00Z"}), + ] + + sorted_labels = sort_labels_by_created_at(labels, reverse=False) + + assert sorted_labels[0].get("id") == "label-1" + assert sorted_labels[1].get("id") == "label-2" + assert sorted_labels[2].get("id") == "label-3" + + def test_sort_labels_by_created_at_descending(self): + """Test sorting labels by creation time in descending order.""" + labels = [ + cast(LabelContract, {"id": "label-1", "createdAt": "2024-01-01T00:00:00Z"}), + cast(LabelContract, {"id": "label-3", "createdAt": "2024-01-03T00:00:00Z"}), 
+ cast(LabelContract, {"id": "label-2", "createdAt": "2024-01-02T00:00:00Z"}), + ] + + sorted_labels = sort_labels_by_created_at(labels, reverse=True) + + assert sorted_labels[0].get("id") == "label-3" + assert sorted_labels[1].get("id") == "label-2" + assert sorted_labels[2].get("id") == "label-1" + + def test_sort_labels_with_missing_created_at(self): + """Test sorting labels when some lack createdAt.""" + labels = [ + cast(LabelContract, {"id": "label-2", "createdAt": "2024-01-02T00:00:00Z"}), + cast(LabelContract, {"id": "label-no-date"}), + cast(LabelContract, {"id": "label-1", "createdAt": "2024-01-01T00:00:00Z"}), + ] + + sorted_labels = sort_labels_by_created_at(labels) + + # Label without date should come first (empty string sorts first) + assert sorted_labels[0].get("id") == "label-no-date" + + def test_filter_labels_by_type_default(self): + """Test filtering labels by DEFAULT type.""" + labels = [ + cast(LabelContract, {"id": "label-1", "labelType": "DEFAULT"}), + cast(LabelContract, {"id": "label-2", "labelType": "REVIEW"}), + cast(LabelContract, {"id": "label-3", "labelType": "DEFAULT"}), + cast(LabelContract, {"id": "label-4", "labelType": "PREDICTION"}), + ] + + filtered = filter_labels_by_type(labels, "DEFAULT") + + assert len(filtered) == 2 + assert filtered[0].get("id") == "label-1" + assert filtered[1].get("id") == "label-3" + + def test_filter_labels_by_type_review(self): + """Test filtering labels by REVIEW type.""" + labels = [ + cast(LabelContract, {"id": "label-1", "labelType": "DEFAULT"}), + cast(LabelContract, {"id": "label-2", "labelType": "REVIEW"}), + cast(LabelContract, {"id": "label-3", "labelType": "REVIEW"}), + ] + + filtered = filter_labels_by_type(labels, "REVIEW") + + assert len(filtered) == 2 + assert filtered[0].get("id") == "label-2" + assert filtered[1].get("id") == "label-3" + + def test_filter_labels_by_type_no_matches(self): + """Test filtering when no labels match the type.""" + labels = [ + cast(LabelContract, {"id": 
"label-1", "labelType": "DEFAULT"}), + cast(LabelContract, {"id": "label-2", "labelType": "DEFAULT"}), + ] + + filtered = filter_labels_by_type(labels, "REVIEW") + + assert len(filtered) == 0 diff --git a/tests/unit/domain_v2/test_project.py b/tests/unit/domain_v2/test_project.py new file mode 100644 index 000000000..ebede1d9f --- /dev/null +++ b/tests/unit/domain_v2/test_project.py @@ -0,0 +1,395 @@ +"""Unit tests for Project domain contracts.""" + +from typing import cast + +from kili.domain_v2.project import ( + ProjectContract, + ProjectView, + get_ordered_steps, + get_step_by_name, + validate_project, +) + + +class TestProjectContract: + """Test suite for ProjectContract.""" + + def test_validate_project_with_valid_data(self): + """Test validating a valid project contract.""" + project_data = { + "id": "project-123", + "title": "My Project", + "description": "Test project", + "inputType": "IMAGE", + "jsonInterface": {}, + "workflowVersion": "V2", + "steps": [], + "roles": [], + "numberOfAssets": 100, + "archived": False, + "createdAt": "2024-01-01T00:00:00Z", + } + + result = validate_project(project_data) + assert result == project_data + + def test_validate_project_with_partial_data(self): + """Test validating a project with only some fields.""" + project_data = { + "id": "project-123", + "title": "My Project", + } + + result = validate_project(project_data) + assert result == project_data + + def test_validate_project_with_workflow_steps(self): + """Test validating a project with workflow v2 steps.""" + project_data = { + "id": "project-123", + "title": "My Project", + "workflowVersion": "V2", + "steps": [ + { + "id": "step-1", + "name": "Labeling", + "type": "DEFAULT", + "order": 0, + }, + { + "id": "step-2", + "name": "Review", + "type": "REVIEW", + "order": 1, + }, + ], + } + + result = validate_project(project_data) + assert result == project_data + assert len(result.get("steps", [])) == 2 + + +class TestProjectView: + """Test suite for ProjectView 
wrapper.""" + + def test_project_view_basic_properties(self): + """Test basic property access on ProjectView.""" + project_data = cast( + ProjectContract, + { + "id": "project-123", + "title": "My Project", + "description": "Test description", + "inputType": "IMAGE", + "numberOfAssets": 100, + "archived": False, + "starred": True, + }, + ) + + view = ProjectView(project_data) + + assert view.id == "project-123" + assert view.title == "My Project" + assert view.description == "Test description" + assert view.input_type == "IMAGE" + assert view.number_of_assets == 100 + assert view.archived is False + assert view.starred is True + + def test_project_view_display_name(self): + """Test display name property.""" + # With title + project_data = cast(ProjectContract, {"id": "project-123", "title": "My Project"}) + view = ProjectView(project_data) + assert view.display_name == "My Project" + + # Without title + project_data = cast(ProjectContract, {"id": "project-123"}) + view = ProjectView(project_data) + assert view.display_name == "project-123" + + def test_project_view_workflow_version(self): + """Test workflow version properties.""" + # V2 workflow + project_data = cast(ProjectContract, {"id": "project-123", "workflowVersion": "V2"}) + view = ProjectView(project_data) + assert view.workflow_version == "V2" + assert view.is_v2_workflow is True + + # V1 workflow + project_data = cast(ProjectContract, {"id": "project-123", "workflowVersion": "V1"}) + view = ProjectView(project_data) + assert view.is_v2_workflow is False + + def test_project_view_steps(self): + """Test steps property.""" + project_data = cast( + ProjectContract, + { + "id": "project-123", + "steps": [ + {"id": "step-1", "name": "Label", "order": 0}, + {"id": "step-2", "name": "Review", "order": 1}, + ], + }, + ) + + view = ProjectView(project_data) + + assert len(view.steps) == 2 + assert view.steps[0].get("name") == "Label" + assert view.steps[1].get("name") == "Review" + + def 
test_project_view_roles(self): + """Test roles property.""" + project_data = cast( + ProjectContract, + { + "id": "project-123", + "roles": [ + {"id": "role-1", "role": "ADMIN", "user": {"id": "user-1"}}, + {"id": "role-2", "role": "LABELER", "user": {"id": "user-2"}}, + ], + }, + ) + + view = ProjectView(project_data) + + assert len(view.roles) == 2 + assert view.roles[0].get("role") == "ADMIN" + + def test_project_view_asset_counts(self): + """Test asset count properties.""" + project_data = cast( + ProjectContract, + { + "id": "project-123", + "numberOfAssets": 100, + "numberOfRemainingAssets": 30, + "numberOfReviewedAssets": 50, + }, + ) + + view = ProjectView(project_data) + + assert view.number_of_assets == 100 + assert view.number_of_remaining_assets == 30 + assert view.number_of_reviewed_assets == 50 + + def test_project_view_progress_percentage(self): + """Test progress percentage calculation.""" + # 70% complete + project_data = cast( + ProjectContract, + { + "id": "project-123", + "numberOfAssets": 100, + "numberOfRemainingAssets": 30, + }, + ) + view = ProjectView(project_data) + assert view.progress_percentage == 70.0 + + # Empty project + project_data = cast( + ProjectContract, + { + "id": "project-123", + "numberOfAssets": 0, + "numberOfRemainingAssets": 0, + }, + ) + view = ProjectView(project_data) + assert view.progress_percentage == 0.0 + + # Fully complete + project_data = cast( + ProjectContract, + { + "id": "project-123", + "numberOfAssets": 100, + "numberOfRemainingAssets": 0, + }, + ) + view = ProjectView(project_data) + assert view.progress_percentage == 100.0 + + def test_project_view_has_honeypot(self): + """Test honeypot property.""" + project_data = cast(ProjectContract, {"id": "project-123", "useHoneypot": True}) + view = ProjectView(project_data) + assert view.has_honeypot is True + + project_data = cast(ProjectContract, {"id": "project-123", "useHoneypot": False}) + view = ProjectView(project_data) + assert view.has_honeypot is False 
+ + def test_project_view_timestamps(self): + """Test timestamp properties.""" + project_data = cast( + ProjectContract, + { + "id": "project-123", + "createdAt": "2024-01-01T00:00:00Z", + "updatedAt": "2024-01-15T10:30:00Z", + }, + ) + + view = ProjectView(project_data) + + assert view.created_at == "2024-01-01T00:00:00Z" + assert view.updated_at == "2024-01-15T10:30:00Z" + + def test_project_view_json_interface(self): + """Test JSON interface property.""" + project_data = cast( + ProjectContract, + { + "id": "project-123", + "jsonInterface": { + "jobs": { + "JOB_1": { + "mlTask": "CLASSIFICATION", + "content": {"categories": {"CAT_A": {}}}, + } + } + }, + }, + ) + + view = ProjectView(project_data) + + assert "jobs" in view.json_interface + assert "JOB_1" in view.json_interface["jobs"] + + def test_project_view_to_dict(self): + """Test converting view back to dictionary.""" + project_data = cast( + ProjectContract, + { + "id": "project-123", + "title": "My Project", + "inputType": "IMAGE", + }, + ) + + view = ProjectView(project_data) + result = view.to_dict() + + assert result == project_data + assert result is project_data + + def test_project_view_missing_fields(self): + """Test accessing missing fields returns appropriate defaults.""" + project_data = cast(ProjectContract, {"id": "project-123"}) + view = ProjectView(project_data) + + assert view.title == "" + assert view.description == "" + assert view.input_type is None + assert view.workflow_version is None + assert view.steps == [] + assert view.roles == [] + assert view.number_of_assets == 0 + assert view.created_at is None + assert view.updated_at is None + assert view.archived is False + assert view.starred is False + + +class TestProjectHelpers: + """Test suite for project helper functions.""" + + def test_get_step_by_name_found(self): + """Test finding a step by name.""" + project = cast( + ProjectContract, + { + "id": "project-123", + "steps": [ + {"id": "step-1", "name": "Labeling", "order": 0}, + 
{"id": "step-2", "name": "Review", "order": 1}, + {"id": "step-3", "name": "QA", "order": 2}, + ], + }, + ) + + step = get_step_by_name(project, "Review") + + assert step is not None + assert step.get("id") == "step-2" + assert step.get("name") == "Review" + + def test_get_step_by_name_not_found(self): + """Test step not found returns None.""" + project = cast( + ProjectContract, + { + "id": "project-123", + "steps": [ + {"id": "step-1", "name": "Labeling", "order": 0}, + ], + }, + ) + + step = get_step_by_name(project, "NonExistent") + + assert step is None + + def test_get_step_by_name_no_steps(self): + """Test with project having no steps.""" + project = cast(ProjectContract, {"id": "project-123"}) + + step = get_step_by_name(project, "Labeling") + + assert step is None + + def test_get_ordered_steps(self): + """Test getting steps ordered by their order field.""" + project = cast( + ProjectContract, + { + "id": "project-123", + "steps": [ + {"id": "step-2", "name": "Review", "order": 2}, + {"id": "step-1", "name": "Labeling", "order": 0}, + {"id": "step-3", "name": "QA", "order": 1}, + ], + }, + ) + + ordered = get_ordered_steps(project) + + assert len(ordered) == 3 + assert ordered[0].get("name") == "Labeling" + assert ordered[1].get("name") == "QA" + assert ordered[2].get("name") == "Review" + + def test_get_ordered_steps_empty(self): + """Test getting ordered steps from empty project.""" + project = cast(ProjectContract, {"id": "project-123"}) + + ordered = get_ordered_steps(project) + + assert ordered == [] + + def test_get_ordered_steps_missing_order(self): + """Test ordering steps when some lack order field.""" + project = cast( + ProjectContract, + { + "id": "project-123", + "steps": [ + {"id": "step-2", "name": "Review", "order": 1}, + {"id": "step-1", "name": "Labeling"}, # Missing order + ], + }, + ) + + ordered = get_ordered_steps(project) + + # Step without order defaults to 0 + assert ordered[0].get("name") == "Labeling" + assert 
ordered[1].get("name") == "Review" diff --git a/tests/unit/domain_v2/test_user.py b/tests/unit/domain_v2/test_user.py new file mode 100644 index 000000000..f9e48785e --- /dev/null +++ b/tests/unit/domain_v2/test_user.py @@ -0,0 +1,414 @@ +"""Unit tests for User domain contracts.""" + +from typing import cast + +from kili.domain_v2.user import ( + UserContract, + UserView, + filter_users_by_activated, + sort_users_by_email, + validate_user, +) + + +class TestUserContract: + """Test suite for UserContract.""" + + def test_validate_user_with_valid_data(self): + """Test validating a valid user contract.""" + user_data = { + "id": "user-123", + "email": "user@example.com", + "name": "John Doe", + "firstname": "John", + "lastname": "Doe", + "activated": True, + "createdAt": "2024-01-01T00:00:00Z", + "organizationId": "org-123", + } + + result = validate_user(user_data) + assert result == user_data + + def test_validate_user_with_partial_data(self): + """Test validating a user with only some fields.""" + user_data = { + "id": "user-123", + "email": "user@example.com", + } + + result = validate_user(user_data) + assert result == user_data + + def test_validate_user_with_organization_role(self): + """Test validating a user with organization role.""" + user_data = { + "id": "user-123", + "email": "admin@example.com", + "organizationRole": { + "id": "role-123", + "role": "ADMIN", + }, + } + + result = validate_user(user_data) + assert result == user_data + org_role = result.get("organizationRole") + assert org_role is not None + assert isinstance(org_role, dict) + assert org_role.get("role") == "ADMIN" + + +class TestUserView: + """Test suite for UserView wrapper.""" + + def test_user_view_basic_properties(self): + """Test basic property access on UserView.""" + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "name": "John Doe", + "firstname": "John", + "lastname": "Doe", + "activated": True, + }, + ) + + view = UserView(user_data) + 
+ assert view.id == "user-123" + assert view.email == "user@example.com" + assert view.name == "John Doe" + assert view.firstname == "John" + assert view.lastname == "Doe" + assert view.activated is True + + def test_user_view_display_name(self): + """Test display name property.""" + # With name + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "name": "John Doe", + }, + ) + view = UserView(user_data) + assert view.display_name == "John Doe" + + # Without name + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + }, + ) + view = UserView(user_data) + assert view.display_name == "user@example.com" + + # Without name and email + user_data = cast(UserContract, {"id": "user-123"}) + view = UserView(user_data) + assert view.display_name == "user-123" + + def test_user_view_full_name(self): + """Test full name property.""" + # With firstname and lastname + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "firstname": "John", + "lastname": "Doe", + }, + ) + view = UserView(user_data) + assert view.full_name == "John Doe" + + # Only firstname + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "firstname": "John", + }, + ) + view = UserView(user_data) + assert view.full_name == "John" + + # Only lastname + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "lastname": "Doe", + }, + ) + view = UserView(user_data) + assert view.full_name == "Doe" + + # No firstname/lastname, fallback to name + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "name": "John Doe", + }, + ) + view = UserView(user_data) + assert view.full_name == "John Doe" + + # No name info, fallback to email + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + }, + ) + view = UserView(user_data) + assert view.full_name == 
"user@example.com" + + def test_user_view_organization_role(self): + """Test organization role property.""" + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "organizationRole": { + "id": "role-123", + "role": "ADMIN", + }, + }, + ) + + view = UserView(user_data) + + assert view.organization_role is not None + assert isinstance(view.organization_role, dict) + assert view.organization_role.get("role") == "ADMIN" + + def test_user_view_is_admin(self): + """Test is_admin property.""" + # Admin user + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "admin@example.com", + "organizationRole": {"id": "role-123", "role": "ADMIN"}, + }, + ) + view = UserView(user_data) + assert view.is_admin is True + + # Non-admin user + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "organizationRole": {"id": "role-123", "role": "USER"}, + }, + ) + view = UserView(user_data) + assert view.is_admin is False + + # User without role + user_data = cast(UserContract, {"id": "user-123", "email": "user@example.com"}) + view = UserView(user_data) + assert view.is_admin is False + + def test_user_view_organization_id(self): + """Test organization ID property.""" + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "organizationId": "org-123", + }, + ) + + view = UserView(user_data) + assert view.organization_id == "org-123" + + def test_user_view_phone(self): + """Test phone property.""" + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "phone": "+1234567890", + }, + ) + + view = UserView(user_data) + assert view.phone == "+1234567890" + + # Without phone + user_data = cast(UserContract, {"id": "user-123", "email": "user@example.com"}) + view = UserView(user_data) + assert view.phone is None + + def test_user_view_timestamps(self): + """Test timestamp properties.""" + user_data = cast( + UserContract, + { + 
"id": "user-123", + "email": "user@example.com", + "createdAt": "2024-01-01T00:00:00Z", + "updatedAt": "2024-01-15T10:30:00Z", + "lastSeenAt": "2024-01-20T14:45:00Z", + }, + ) + + view = UserView(user_data) + + assert view.created_at == "2024-01-01T00:00:00Z" + assert view.updated_at == "2024-01-15T10:30:00Z" + assert view.last_seen_at == "2024-01-20T14:45:00Z" + + def test_user_view_to_dict(self): + """Test converting view back to dictionary.""" + user_data = cast( + UserContract, + { + "id": "user-123", + "email": "user@example.com", + "name": "John Doe", + }, + ) + + view = UserView(user_data) + result = view.to_dict() + + assert result == user_data + assert result is user_data + + def test_user_view_missing_fields(self): + """Test accessing missing fields returns appropriate defaults.""" + user_data = cast(UserContract, {"id": "user-123"}) + view = UserView(user_data) + + assert view.email == "" + assert view.name == "" + assert view.firstname == "" + assert view.lastname == "" + assert view.activated is False + assert view.organization_id == "" + assert view.organization_role is None + assert view.phone is None + assert view.created_at is None + assert view.updated_at is None + assert view.last_seen_at is None + + +class TestUserHelpers: + """Test suite for user helper functions.""" + + def test_sort_users_by_email_ascending(self): + """Test sorting users by email in ascending order.""" + users = [ + cast(UserContract, {"id": "user-3", "email": "charlie@example.com"}), + cast(UserContract, {"id": "user-1", "email": "alice@example.com"}), + cast(UserContract, {"id": "user-2", "email": "bob@example.com"}), + ] + + sorted_users = sort_users_by_email(users, reverse=False) + + assert sorted_users[0].get("email") == "alice@example.com" + assert sorted_users[1].get("email") == "bob@example.com" + assert sorted_users[2].get("email") == "charlie@example.com" + + def test_sort_users_by_email_descending(self): + """Test sorting users by email in descending order.""" + 
users = [ + cast(UserContract, {"id": "user-1", "email": "alice@example.com"}), + cast(UserContract, {"id": "user-3", "email": "charlie@example.com"}), + cast(UserContract, {"id": "user-2", "email": "bob@example.com"}), + ] + + sorted_users = sort_users_by_email(users, reverse=True) + + assert sorted_users[0].get("email") == "charlie@example.com" + assert sorted_users[1].get("email") == "bob@example.com" + assert sorted_users[2].get("email") == "alice@example.com" + + def test_sort_users_with_missing_email(self): + """Test sorting users when some lack email.""" + users = [ + cast(UserContract, {"id": "user-2", "email": "bob@example.com"}), + cast(UserContract, {"id": "user-no-email"}), + cast(UserContract, {"id": "user-1", "email": "alice@example.com"}), + ] + + sorted_users = sort_users_by_email(users) + + # User without email should come first (empty string sorts first) + assert sorted_users[0].get("id") == "user-no-email" + + def test_filter_users_by_activated_true(self): + """Test filtering for activated users.""" + users = [ + cast(UserContract, {"id": "user-1", "email": "user1@example.com", "activated": True}), + cast(UserContract, {"id": "user-2", "email": "user2@example.com", "activated": False}), + cast(UserContract, {"id": "user-3", "email": "user3@example.com", "activated": True}), + cast(UserContract, {"id": "user-4", "email": "user4@example.com", "activated": False}), + ] + + filtered = filter_users_by_activated(users, activated=True) + + assert len(filtered) == 2 + assert filtered[0].get("id") == "user-1" + assert filtered[1].get("id") == "user-3" + + def test_filter_users_by_activated_false(self): + """Test filtering for deactivated users.""" + users = [ + cast(UserContract, {"id": "user-1", "email": "user1@example.com", "activated": True}), + cast(UserContract, {"id": "user-2", "email": "user2@example.com", "activated": False}), + cast(UserContract, {"id": "user-3", "email": "user3@example.com", "activated": True}), + cast(UserContract, {"id": 
"user-4", "email": "user4@example.com", "activated": False}), + ] + + filtered = filter_users_by_activated(users, activated=False) + + assert len(filtered) == 2 + assert filtered[0].get("id") == "user-2" + assert filtered[1].get("id") == "user-4" + + def test_filter_users_by_activated_no_matches(self): + """Test filtering when no users match.""" + users = [ + cast(UserContract, {"id": "user-1", "email": "user1@example.com", "activated": True}), + cast(UserContract, {"id": "user-2", "email": "user2@example.com", "activated": True}), + ] + + filtered = filter_users_by_activated(users, activated=False) + + assert len(filtered) == 0 + + def test_filter_users_with_missing_activated(self): + """Test filtering users when some lack activated field.""" + users = [ + cast(UserContract, {"id": "user-1", "email": "user1@example.com", "activated": True}), + cast(UserContract, {"id": "user-2", "email": "user2@example.com"}), # Missing activated + cast(UserContract, {"id": "user-3", "email": "user3@example.com", "activated": False}), + ] + + # Filter for activated=True (user without field won't match) + filtered = filter_users_by_activated(users, activated=True) + assert len(filtered) == 1 + assert filtered[0].get("id") == "user-1" diff --git a/tests/unit/use_cases_v2/__init__.py b/tests/unit/use_cases_v2/__init__.py new file mode 100644 index 000000000..c13c5b061 --- /dev/null +++ b/tests/unit/use_cases_v2/__init__.py @@ -0,0 +1 @@ +"""Unit tests for use_cases_v2 module.""" diff --git a/tests/unit/use_cases_v2/test_interfaces.py b/tests/unit/use_cases_v2/test_interfaces.py new file mode 100644 index 000000000..25dbfa635 --- /dev/null +++ b/tests/unit/use_cases_v2/test_interfaces.py @@ -0,0 +1,736 @@ +"""Tests for repository interface definitions. + +This module tests that repository interfaces are properly defined and +that mock implementations can comply with the Protocol contracts. 
+""" + +from typing import Generator, List, Optional + +from kili.domain_v2.asset import AssetContract +from kili.domain_v2.label import LabelContract +from kili.domain_v2.project import ProjectContract +from kili.domain_v2.user import UserContract +from kili.use_cases_v2.interfaces import ( + IAssetRepository, + ILabelRepository, + IProjectRepository, + IUserRepository, + PaginationParams, +) + +# Mock Asset Repository Implementation + + +class MockAssetRepository: + """Mock implementation of IAssetRepository for testing.""" + + def __init__(self): + """Initialize with empty asset store.""" + self._assets: dict[str, AssetContract] = {} + self._next_id = 1 + + def get_by_id( + self, + asset_id: str, + project_id: str, + fields: Optional[List[str]] = None, + ) -> Optional[AssetContract]: + """Get asset by ID.""" + return self._assets.get(asset_id) + + def get_by_external_id( + self, + external_id: str, + project_id: str, + fields: Optional[List[str]] = None, + ) -> Optional[AssetContract]: + """Get asset by external ID.""" + for asset in self._assets.values(): + if asset.get("externalId") == external_id: + return asset + return None + + def list( + self, + project_id: str, + fields: Optional[List[str]] = None, + status_in: Optional[List[str]] = None, + external_id_in: Optional[List[str]] = None, + asset_id_in: Optional[List[str]] = None, + metadata_where: Optional[dict] = None, + created_at_gte: Optional[str] = None, + created_at_lte: Optional[str] = None, + pagination: Optional[PaginationParams] = None, + ) -> Generator[AssetContract, None, None]: + """List assets.""" + for asset in self._assets.values(): + if status_in and asset.get("status") not in status_in: + continue + if external_id_in and asset.get("externalId") not in external_id_in: + continue + if asset_id_in and asset.get("id") not in asset_id_in: + continue + yield asset + + def count( + self, + project_id: str, + status_in: Optional[List[str]] = None, + external_id_in: Optional[List[str]] = None, + 
metadata_where: Optional[dict] = None, + ) -> int: + """Count assets.""" + return len(list(self.list(project_id, status_in=status_in))) + + def create( + self, + project_id: str, + content: str, + external_id: str, + json_metadata: Optional[dict] = None, + ) -> AssetContract: + """Create an asset.""" + asset_id = str(self._next_id) + self._next_id += 1 + asset: AssetContract = { + "id": asset_id, + "externalId": external_id, + "content": content, + "jsonMetadata": json_metadata, + "status": "TODO", + "labels": [], + "isHoneypot": False, + "skipped": False, + "createdAt": "2024-01-01T00:00:00Z", + } + self._assets[asset_id] = asset + return asset + + def update_metadata( + self, + asset_id: str, + json_metadata: dict, + ) -> AssetContract: + """Update asset metadata.""" + asset = self._assets[asset_id] + asset["jsonMetadata"] = json_metadata + return asset + + def delete( + self, + asset_ids: List[str], + ) -> int: + """Delete assets.""" + count = 0 + for asset_id in asset_ids: + if asset_id in self._assets: + del self._assets[asset_id] + count += 1 + return count + + +# Mock Label Repository Implementation + + +class MockLabelRepository: + """Mock implementation of ILabelRepository for testing.""" + + def __init__(self): + """Initialize with empty label store.""" + self._labels: dict[str, LabelContract] = {} + self._next_id = 1 + + def get_by_id( + self, + label_id: str, + fields: Optional[List[str]] = None, + ) -> Optional[LabelContract]: + """Get label by ID.""" + return self._labels.get(label_id) + + def list( + self, + asset_id: Optional[str] = None, + project_id: Optional[str] = None, + fields: Optional[List[str]] = None, + label_type_in: Optional[List[str]] = None, + author_in: Optional[List[str]] = None, + created_at_gte: Optional[str] = None, + created_at_lte: Optional[str] = None, + pagination: Optional[PaginationParams] = None, + ) -> Generator[LabelContract, None, None]: + """List labels.""" + for label in self._labels.values(): + if label_type_in and 
label.get("labelType") not in label_type_in: + continue + if author_in and label.get("author", {}).get("id") not in author_in: + continue + yield label + + def count( + self, + asset_id: Optional[str] = None, + project_id: Optional[str] = None, + label_type_in: Optional[List[str]] = None, + author_in: Optional[List[str]] = None, + ) -> int: + """Count labels.""" + return len(list(self.list(label_type_in=label_type_in, author_in=author_in))) + + def create( + self, + asset_id: str, + json_response: dict, + label_type: str = "DEFAULT", + seconds_to_label: Optional[int] = None, + ) -> LabelContract: + """Create a label.""" + label_id = str(self._next_id) + self._next_id += 1 + label: LabelContract = { + "id": label_id, + "author": {"id": "user1", "email": "user@example.com"}, + "jsonResponse": json_response, + "createdAt": "2024-01-01T00:00:00Z", + "labelType": label_type, # type: ignore + "isLatestLabelForUser": True, + "isLatestDefaultLabelForUser": True, + "skipped": False, + } + self._labels[label_id] = label + return label + + def update( + self, + label_id: str, + json_response: dict, + ) -> LabelContract: + """Update a label.""" + label = self._labels[label_id] + label["jsonResponse"] = json_response + return label + + def delete( + self, + label_ids: List[str], + ) -> int: + """Delete labels.""" + count = 0 + for label_id in label_ids: + if label_id in self._labels: + del self._labels[label_id] + count += 1 + return count + + +# Mock Project Repository Implementation + + +class MockProjectRepository: + """Mock implementation of IProjectRepository for testing.""" + + def __init__(self): + """Initialize with empty project store.""" + self._projects: dict[str, ProjectContract] = {} + self._next_id = 1 + + def get_by_id( + self, + project_id: str, + fields: Optional[List[str]] = None, + ) -> Optional[ProjectContract]: + """Get project by ID.""" + return self._projects.get(project_id) + + def list( + self, + fields: Optional[List[str]] = None, + archived: 
Optional[bool] = None, + starred: Optional[bool] = None, + input_type_in: Optional[List[str]] = None, + created_at_gte: Optional[str] = None, + created_at_lte: Optional[str] = None, + pagination: Optional[PaginationParams] = None, + ) -> Generator[ProjectContract, None, None]: + """List projects.""" + for project in self._projects.values(): + if archived is not None and project.get("archived") != archived: + continue + if starred is not None and project.get("starred") != starred: + continue + if input_type_in and project.get("inputType") not in input_type_in: + continue + yield project + + def count( + self, + archived: Optional[bool] = None, + starred: Optional[bool] = None, + input_type_in: Optional[List[str]] = None, + ) -> int: + """Count projects.""" + return len(list(self.list(archived=archived, starred=starred))) + + def create( + self, + title: str, + description: str, + input_type: str, + json_interface: dict, + ) -> ProjectContract: + """Create a project.""" + project_id = str(self._next_id) + self._next_id += 1 + project: ProjectContract = { + "id": project_id, + "title": title, + "description": description, + "inputType": input_type, # type: ignore + "jsonInterface": json_interface, + "workflowVersion": "V2", + "numberOfAssets": 0, + "archived": False, + "starred": False, + "createdAt": "2024-01-01T00:00:00Z", + "steps": [], + "roles": [], + "complianceTags": [], + "useHoneypot": False, + "readPermissionsForAssetsAndLabels": True, + "shouldRelaunchKpiComputation": False, + } + self._projects[project_id] = project + return project + + def update( + self, + project_id: str, + title: Optional[str] = None, + description: Optional[str] = None, + json_interface: Optional[dict] = None, + ) -> ProjectContract: + """Update a project.""" + project = self._projects[project_id] + if title is not None: + project["title"] = title + if description is not None: + project["description"] = description + if json_interface is not None: + project["jsonInterface"] = 
json_interface + return project + + def archive( + self, + project_id: str, + ) -> ProjectContract: + """Archive a project.""" + project = self._projects[project_id] + project["archived"] = True + return project + + def delete( + self, + project_ids: List[str], + ) -> int: + """Delete projects.""" + count = 0 + for project_id in project_ids: + if project_id in self._projects: + del self._projects[project_id] + count += 1 + return count + + +# Mock User Repository Implementation + + +class MockUserRepository: + """Mock implementation of IUserRepository for testing.""" + + def __init__(self): + """Initialize with empty user store.""" + self._users: dict[str, UserContract] = {} + self._next_id = 1 + + def get_by_id( + self, + user_id: str, + fields: Optional[List[str]] = None, + ) -> Optional[UserContract]: + """Get user by ID.""" + return self._users.get(user_id) + + def get_by_email( + self, + email: str, + fields: Optional[List[str]] = None, + ) -> Optional[UserContract]: + """Get user by email.""" + for user in self._users.values(): + if user.get("email") == email: + return user + return None + + def list( + self, + organization_id: str, + fields: Optional[List[str]] = None, + activated: Optional[bool] = None, + email_contains: Optional[str] = None, + pagination: Optional[PaginationParams] = None, + ) -> Generator[UserContract, None, None]: + """List users.""" + for user in self._users.values(): + if activated is not None and user.get("activated") != activated: + continue + if email_contains and email_contains not in user.get("email", ""): + continue + yield user + + def count( + self, + organization_id: str, + activated: Optional[bool] = None, + ) -> int: + """Count users.""" + return len(list(self.list(organization_id, activated=activated))) + + def create( + self, + organization_id: str, + email: str, + firstname: str, + lastname: str, + role: str = "USER", + ) -> UserContract: + """Create a user.""" + user_id = str(self._next_id) + self._next_id += 1 + user: 
UserContract = { + "id": user_id, + "email": email, + "name": f"{firstname} {lastname}", + "firstname": firstname, + "lastname": lastname, + "activated": True, + "organizationId": organization_id, + "organizationRole": {"id": "role1", "role": role}, # type: ignore + "createdAt": "2024-01-01T00:00:00Z", + "hubspotSubscriptionStatus": "SUBSCRIBED", + "apiKey": "key123", + } + self._users[user_id] = user + return user + + def update( + self, + user_id: str, + firstname: Optional[str] = None, + lastname: Optional[str] = None, + activated: Optional[bool] = None, + ) -> UserContract: + """Update a user.""" + user = self._users[user_id] + if firstname is not None: + user["firstname"] = firstname + if lastname is not None: + user["lastname"] = lastname + if activated is not None: + user["activated"] = activated + return user + + +# Protocol Compliance Tests + + +def test_asset_repository_protocol_compliance(): + """Test that MockAssetRepository complies with IAssetRepository protocol.""" + repo: IAssetRepository = MockAssetRepository() + + # Test create + asset = repo.create( + project_id="proj1", + content="https://example.com/image.jpg", + external_id="asset-1", + json_metadata={"key": "value"}, + ) + asset_id = asset.get("id") + assert asset_id is not None + assert asset.get("externalId") == "asset-1" + assert asset.get("content") == "https://example.com/image.jpg" + + # Test get_by_id + retrieved = repo.get_by_id(asset_id, "proj1") + assert retrieved is not None + assert retrieved.get("id") == asset_id + + # Test get_by_external_id + retrieved_by_ext = repo.get_by_external_id("asset-1", "proj1") + assert retrieved_by_ext is not None + assert retrieved_by_ext.get("externalId") == "asset-1" + + # Test list + assets = list(repo.list("proj1")) + assert len(assets) == 1 + + # Test count + count = repo.count("proj1") + assert count == 1 + + # Test update_metadata + updated = repo.update_metadata(asset_id, {"new_key": "new_value"}) + assert updated.get("jsonMetadata") == 
{"new_key": "new_value"} + + # Test delete + deleted = repo.delete([asset_id]) + assert deleted == 1 + assert repo.count("proj1") == 0 + + +def test_label_repository_protocol_compliance(): + """Test that MockLabelRepository complies with ILabelRepository protocol.""" + repo: ILabelRepository = MockLabelRepository() + + # Test create + label = repo.create( + asset_id="asset1", + json_response={"annotation": "value"}, + label_type="DEFAULT", + ) + label_id = label.get("id") + assert label_id is not None + assert label.get("jsonResponse") == {"annotation": "value"} + + # Test get_by_id + retrieved = repo.get_by_id(label_id) + assert retrieved is not None + assert retrieved.get("id") == label_id + + # Test list + labels = list(repo.list()) + assert len(labels) == 1 + + # Test count + count = repo.count() + assert count == 1 + + # Test update + updated = repo.update(label_id, {"updated": "annotation"}) + assert updated.get("jsonResponse") == {"updated": "annotation"} + + # Test delete + deleted = repo.delete([label_id]) + assert deleted == 1 + assert repo.count() == 0 + + +def test_project_repository_protocol_compliance(): + """Test that MockProjectRepository complies with IProjectRepository protocol.""" + repo: IProjectRepository = MockProjectRepository() + + # Test create + project = repo.create( + title="Test Project", + description="A test project", + input_type="IMAGE", + json_interface={"jobs": []}, + ) + project_id = project.get("id") + assert project_id is not None + assert project.get("title") == "Test Project" + + # Test get_by_id + retrieved = repo.get_by_id(project_id) + assert retrieved is not None + assert retrieved.get("id") == project_id + + # Test list + projects = list(repo.list()) + assert len(projects) == 1 + + # Test count + count = repo.count() + assert count == 1 + + # Test update + updated = repo.update(project_id, title="Updated Title") + assert updated.get("title") == "Updated Title" + + # Test archive + archived = repo.archive(project_id) + 
assert archived.get("archived") is True + + # Test delete + deleted = repo.delete([project_id]) + assert deleted == 1 + assert repo.count(archived=True) == 0 + + +def test_user_repository_protocol_compliance(): + """Test that MockUserRepository complies with IUserRepository protocol.""" + repo: IUserRepository = MockUserRepository() + + # Test create + user = repo.create( + organization_id="org1", + email="test@example.com", + firstname="John", + lastname="Doe", + role="USER", + ) + user_id = user.get("id") + assert user_id is not None + assert user.get("email") == "test@example.com" + + # Test get_by_id + retrieved = repo.get_by_id(user_id) + assert retrieved is not None + assert retrieved.get("id") == user_id + + # Test get_by_email + retrieved_by_email = repo.get_by_email("test@example.com") + assert retrieved_by_email is not None + assert retrieved_by_email.get("email") == "test@example.com" + + # Test list + users = list(repo.list("org1")) + assert len(users) == 1 + + # Test count + count = repo.count("org1") + assert count == 1 + + # Test update + updated = repo.update(user_id, firstname="Jane") + assert updated.get("firstname") == "Jane" + + +def test_pagination_params(): + """Test PaginationParams initialization.""" + # Default params + params = PaginationParams() + assert params.skip == 0 + assert params.first is None + assert params.batch_size == 100 + + # Custom params + params = PaginationParams(skip=10, first=50, batch_size=25) + assert params.skip == 10 + assert params.first == 50 + assert params.batch_size == 25 + + +def test_asset_repository_filtering(): + """Test asset repository filtering functionality.""" + repo: IAssetRepository = MockAssetRepository() + + # Create multiple assets with different statuses + asset1 = repo.create("proj1", "content1", "asset-1", {"key": "value1"}) + asset2 = repo.create("proj1", "content2", "asset-2", {"key": "value2"}) + asset3 = repo.create("proj1", "content3", "asset-3", {"key": "value3"}) + + # Manually set 
status to test filtering (since create sets all to TODO) + asset2_id = asset2.get("id") + assert asset2_id is not None + repo._assets[asset2_id]["status"] = "LABELED" + + # Filter by status + todo_assets = list(repo.list("proj1", status_in=["TODO"])) + assert len(todo_assets) == 2 + + # Filter by external_id + filtered = list(repo.list("proj1", external_id_in=["asset-1", "asset-2"])) + assert len(filtered) == 2 + + +def test_label_repository_filtering(): + """Test label repository filtering functionality.""" + repo: ILabelRepository = MockLabelRepository() + + # Create multiple labels with different types + repo.create("asset1", {"data": 1}, "DEFAULT") + repo.create("asset1", {"data": 2}, "REVIEW") + repo.create("asset2", {"data": 3}, "DEFAULT") + + # Filter by label type + default_labels = list(repo.list(label_type_in=["DEFAULT"])) + assert len(default_labels) == 2 + + review_labels = list(repo.list(label_type_in=["REVIEW"])) + assert len(review_labels) == 1 + + +def test_project_repository_filtering(): + """Test project repository filtering functionality.""" + repo: IProjectRepository = MockProjectRepository() + + # Create multiple projects + proj1 = repo.create("Project 1", "Desc 1", "IMAGE", {}) + repo.create("Project 2", "Desc 2", "TEXT", {}) + proj1_id = proj1.get("id") + assert proj1_id is not None + repo.archive(proj1_id) + + # Filter by archived status + archived_projects = list(repo.list(archived=True)) + assert len(archived_projects) == 1 + + active_projects = list(repo.list(archived=False)) + assert len(active_projects) == 1 + + # Filter by input type + image_projects = list(repo.list(input_type_in=["IMAGE"])) + assert len(image_projects) == 1 + + +def test_user_repository_filtering(): + """Test user repository filtering functionality.""" + repo: IUserRepository = MockUserRepository() + + # Create multiple users + user1 = repo.create("org1", "alice@example.com", "Alice", "Smith") + repo.create("org1", "bob@example.com", "Bob", "Jones") + user1_id = 
user1.get("id") + assert user1_id is not None + repo.update(user1_id, activated=False) + + # Filter by activated status + active_users = list(repo.list("org1", activated=True)) + assert len(active_users) == 1 + + inactive_users = list(repo.list("org1", activated=False)) + assert len(inactive_users) == 1 + + # Filter by email substring + alice_users = list(repo.list("org1", email_contains="alice")) + assert len(alice_users) == 1 + + +def test_repository_returns_correct_types(): + """Test that repositories return correct TypedDict types.""" + asset_repo: IAssetRepository = MockAssetRepository() + label_repo: ILabelRepository = MockLabelRepository() + project_repo: IProjectRepository = MockProjectRepository() + user_repo: IUserRepository = MockUserRepository() + + # Create entities + asset = asset_repo.create("proj1", "content", "ext-1") + label = label_repo.create("asset1", {"data": "test"}) + project = project_repo.create("Title", "Desc", "IMAGE", {}) + user = user_repo.create("org1", "user@example.com", "First", "Last") + + # Verify types + assert isinstance(asset, dict) + assert isinstance(label, dict) + assert isinstance(project, dict) + assert isinstance(user, dict) + + # Verify required fields + assert "id" in asset + assert "externalId" in asset + assert "id" in label + assert "jsonResponse" in label + assert "id" in project + assert "title" in project + assert "id" in user + assert "email" in user diff --git a/tests_v2/__init__.py b/tests_v2/__init__.py new file mode 100644 index 000000000..fccdfc7f9 --- /dev/null +++ b/tests_v2/__init__.py @@ -0,0 +1,121 @@ +"""Integration test utilities for domain_v2 View objects. + +This module provides shared utilities and helper functions for testing +the domain_v2 View objects (AssetView, LabelView, ProjectView, UserView) +against the real Kili API. 
+ +The View objects wrap dictionaries returned from the Kili API and provide +ergonomic property access while maintaining backward compatibility with +dictionary representations. + +Test Configuration: + API_KEY: + ENDPOINT: http://localhost:4001/api/label/v2/graphql + +Example: + >>> from tests_v2 import assert_is_view, assert_view_has_dict_compatibility + >>> from kili.domain_v2.asset import AssetView + >>> + >>> # Test that object is a View instance + >>> assert_is_view(obj, AssetView) + >>> + >>> # Test View dictionary compatibility + >>> assert_view_has_dict_compatibility(asset_view) +""" + +from typing import Any, Type + + +def assert_is_view(obj: Any, view_class: Type) -> None: + """Assert that an object is an instance of a specific View class. + + This function verifies that: + - The object is an instance of the expected View class + - The object has the required _data attribute + - The object has the to_dict() method + + Args: + obj: The object to check + view_class: The expected View class (e.g., AssetView, LabelView) + + Raises: + AssertionError: If the object is not a valid View instance + + Example: + >>> from kili.domain_v2.asset import AssetView + >>> assert_is_view(asset_obj, AssetView) + """ + assert isinstance( + obj, view_class + ), f"Expected instance of {view_class.__name__}, got {type(obj).__name__}" + + # Verify View has required structure + assert hasattr(obj, "_data"), f"{view_class.__name__} instance missing _data attribute" + assert hasattr(obj, "to_dict"), f"{view_class.__name__} instance missing to_dict() method" + + # Verify _data is a dictionary + assert isinstance( + obj._data, + dict, # pylint: disable=protected-access + ), f"{view_class.__name__}._data should be a dictionary" + + +def assert_view_has_dict_compatibility(view: Any) -> None: + """Assert that a View object maintains dictionary compatibility. 
+ + This function verifies that: + - The View has a to_dict() method + - The to_dict() method returns a dictionary + - The returned dictionary is the same as the internal _data + + Args: + view: The View object to check + + Raises: + AssertionError: If the View doesn't have proper dictionary compatibility + + Example: + >>> assert_view_has_dict_compatibility(asset_view) + """ + # Verify to_dict() exists and returns a dictionary + assert hasattr(view, "to_dict"), "View object missing to_dict() method" + + dict_repr = view.to_dict() + assert isinstance(dict_repr, dict), f"to_dict() should return a dict, got {type(dict_repr)}" + + # Verify to_dict() returns the same reference as _data (zero-copy) + assert hasattr(view, "_data"), "View object missing _data attribute" + assert ( + dict_repr is view._data # pylint: disable=protected-access + ), "to_dict() should return the same reference as _data (not a copy)" + + +def assert_view_property_access(view: Any, property_name: str, expected_value: Any = None) -> None: + """Assert that a View property is accessible and optionally matches an expected value. 
+ + Args: + view: The View object to check + property_name: Name of the property to access + expected_value: Optional expected value for the property + + Raises: + AssertionError: If the property is not accessible or doesn't match expected value + + Example: + >>> assert_view_property_access(asset_view, "id") + >>> assert_view_property_access(asset_view, "external_id", "asset-1") + """ + assert hasattr(view, property_name), f"View object missing property '{property_name}'" + + # Try to access the property + try: + actual_value = getattr(view, property_name) + except Exception as exc: # pylint: disable=broad-except + raise AssertionError(f"Error accessing property '{property_name}': {exc}") from exc + + # If expected value provided, verify it matches + if expected_value is not None: + assert actual_value == expected_value, ( + f"Property '{property_name}' has value {actual_value!r}, " + f"expected {expected_value!r}" + ) diff --git a/tests_v2/conftest.py b/tests_v2/conftest.py new file mode 100644 index 000000000..e8f6c60b5 --- /dev/null +++ b/tests_v2/conftest.py @@ -0,0 +1,121 @@ +"""Pytest configuration and shared fixtures for domain_v2 integration tests. + +This module provides fixtures for testing domain_v2 View objects against +the real Kili API using test credentials. + +Test Configuration: + API_KEY: + ENDPOINT: http://localhost:4001/api/label/v2/graphql +""" + +import os +from typing import Generator + +import pytest + +from kili.client_domain import Kili + +# Test configuration constants +TEST_API_KEY = "" +TEST_ENDPOINT = "http://localhost:4001/api/label/v2/graphql" + + +@pytest.fixture(scope="session") +def api_key() -> str: + """Provide the API key for test authentication. + + Returns: + Test API key for integration tests + + Example: + >>> def test_with_api_key(api_key): + ... assert api_key == TEST_API_KEY + """ + return TEST_API_KEY + + +@pytest.fixture(scope="session") +def api_endpoint() -> str: + """Provide the API endpoint for test server. 
+ + Returns: + Test API endpoint URL + + Example: + >>> def test_with_endpoint(api_endpoint): + ... assert api_endpoint == TEST_ENDPOINT + """ + return TEST_ENDPOINT + + +@pytest.fixture(scope="session") +def kili_client(api_key: str, api_endpoint: str) -> Generator[Kili, None, None]: + """Provide a configured Kili client for integration tests. + + This fixture creates a Kili client instance using the test credentials + and yields it for use in tests. The client uses the domain API with + namespace organization. + + Args: + api_key: Test API key (from api_key fixture) + api_endpoint: Test endpoint URL (from api_endpoint fixture) + + Yields: + Configured Kili client instance + + Example: + >>> def test_assets(kili_client): + ... assets = kili_client.assets.list(first=10) + ... assert isinstance(assets, list) + """ + # Temporarily override environment variables if needed + original_key = os.environ.get("KILI_API_KEY") + original_endpoint = os.environ.get("KILI_API_ENDPOINT") + + try: + # Create client with test credentials + client = Kili(api_key=api_key, api_endpoint=api_endpoint) + yield client + finally: + # Restore original environment variables + if original_key is not None: + os.environ["KILI_API_KEY"] = original_key + elif "KILI_API_KEY" in os.environ: + del os.environ["KILI_API_KEY"] + + if original_endpoint is not None: + os.environ["KILI_API_ENDPOINT"] = original_endpoint + elif "KILI_API_ENDPOINT" in os.environ: + del os.environ["KILI_API_ENDPOINT"] + + +@pytest.fixture() +def skip_if_no_data(kili_client: Kili, entity_type: str): + """Skip test if no test data exists for the entity type. + + This fixture can be parametrized to check for specific entity types + (projects, assets, labels, users) and skip tests gracefully when no + data is available. 
+ + Args: + kili_client: Configured Kili client + entity_type: Type of entity to check ('projects', 'assets', 'labels', 'users') + + Raises: + pytest.skip: If no data exists for the entity type + + Example: + >>> @pytest.mark.parametrize("entity_type", ["projects"]) + >>> def test_projects(kili_client, skip_if_no_data): + ... # Test will be skipped if no projects exist + ... projects = kili_client.projects.list(first=1) + """ + # This is a marker fixture - actual implementation in individual tests + # Each test can choose to use this pattern to skip when no data exists + + +# Configure pytest to show more detailed output +def pytest_configure(config): + """Configure pytest for integration tests.""" + config.addinivalue_line("markers", "integration: mark test as integration test") + config.addinivalue_line("markers", "requires_data: mark test as requiring existing test data") diff --git a/tests_v2/test_assets_view.py b/tests_v2/test_assets_view.py new file mode 100644 index 000000000..c8e003fcd --- /dev/null +++ b/tests_v2/test_assets_view.py @@ -0,0 +1,323 @@ +"""Integration tests for AssetView objects returned by the assets namespace. + +This test file validates that the assets.list() method correctly returns +AssetView objects instead of dictionaries, and that these objects provide +proper property access and backward compatibility. 
+ +Test Strategy: + - Verify list() returns AssetView objects in all modes (list, generator, DataFrame) + - Test AssetView property access for common properties + - Validate backward compatibility with dictionary interface via to_dict() + - Ensure DataFrame mode remains unchanged +""" + +import pytest + +from kili.domain_v2.asset import AssetView +from tests_v2 import ( + assert_is_view, + assert_view_has_dict_compatibility, + assert_view_property_access, +) + + +@pytest.mark.integration() +def test_list_returns_asset_views(kili_client): + """Test that assets.list() in list mode returns AssetView objects.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get assets in list mode + assets = kili_client.assets.list(project_id=project_id, first=5, as_generator=False) + + # Verify we get a list + assert isinstance(assets, list), "assets.list() with as_generator=False should return a list" + + # Skip if no assets + if not assets: + pytest.skip(f"No assets available in project {project_id}") + + # Verify each item is an AssetView + for asset in assets: + assert_is_view(asset, AssetView) + + # Verify we can access basic properties + assert hasattr(asset, "id") + assert hasattr(asset, "external_id") + assert hasattr(asset, "display_name") + + +@pytest.mark.integration() +def test_list_generator_returns_asset_views(kili_client): + """Test that assets.list() in generator mode returns AssetView objects.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get assets in generator mode + assets_gen = 
kili_client.assets.list(project_id=project_id, first=5, as_generator=True) + + # Take first 5 items from generator (or fewer if less available) + assets_from_gen = [] + for i, asset in enumerate(assets_gen): + if i >= 5: + break + assets_from_gen.append(asset) + + # Skip if no assets + if not assets_from_gen: + pytest.skip(f"No assets available in project {project_id}") + + # Verify each yielded item is an AssetView + for asset in assets_from_gen: + assert_is_view(asset, AssetView) + + # Verify we can access basic properties + assert hasattr(asset, "id") + assert hasattr(asset, "external_id") + + +@pytest.mark.integration() +def test_asset_view_properties(kili_client): + """Test that AssetView provides access to all expected properties.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get first asset + assets = kili_client.assets.list(project_id=project_id, first=1, as_generator=False) + + if not assets: + pytest.skip(f"No assets available in project {project_id}") + + asset = assets[0] + + # Verify AssetView type + assert_is_view(asset, AssetView) + + # Test core properties exist and are accessible + assert_view_property_access(asset, "id") + assert_view_property_access(asset, "external_id") + assert_view_property_access(asset, "content") + assert_view_property_access(asset, "display_name") + + # Test that id is not empty + assert asset.id, "Asset id should not be empty" + + # Test display_name logic (should be external_id if available, else id) + if asset.external_id: + assert asset.display_name == asset.external_id + else: + assert asset.display_name == asset.id + + # Test optional properties + assert_view_property_access(asset, "metadata") + assert_view_property_access(asset, "labels") + assert_view_property_access(asset, "latest_label") + 
assert_view_property_access(asset, "status") + assert_view_property_access(asset, "current_step") + assert_view_property_access(asset, "is_honeypot") + assert_view_property_access(asset, "skipped") + assert_view_property_access(asset, "created_at") + + # Test computed properties + assert_view_property_access(asset, "has_labels") + assert_view_property_access(asset, "label_count") + + # Verify labels is a list + assert isinstance(asset.labels, list), "labels property should return a list" + + # Verify label_count matches labels length + assert asset.label_count == len(asset.labels), "label_count should match length of labels" + + # Verify has_labels is consistent with label_count + assert asset.has_labels == (asset.label_count > 0), "has_labels should match label_count > 0" + + +@pytest.mark.integration() +def test_asset_view_dict_compatibility(kili_client): + """Test that AssetView maintains backward compatibility via to_dict().""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get first asset + assets = kili_client.assets.list(project_id=project_id, first=1, as_generator=False) + + if not assets: + pytest.skip(f"No assets available in project {project_id}") + + asset = assets[0] + + # Verify AssetView type + assert_is_view(asset, AssetView) + + # Test dictionary compatibility + assert_view_has_dict_compatibility(asset) + + # Get dictionary representation + asset_dict = asset.to_dict() + + # Verify it's a dictionary + assert isinstance(asset_dict, dict), "to_dict() should return a dictionary" + + # Verify dictionary has expected keys + assert "id" in asset_dict, "Dictionary should have 'id' key" + + # Verify dictionary values match property values + if "externalId" in asset_dict: + assert ( + asset_dict["externalId"] == asset.external_id + ), 
"Dictionary externalId should match property" + + if "content" in asset_dict: + assert asset_dict["content"] == asset.content, "Dictionary content should match property" + + # Verify to_dict() returns the same reference (zero-copy) + assert asset_dict is asset._data, "to_dict() should return the same reference as _data" + + +@pytest.mark.integration() +def test_asset_view_with_dataframe(kili_client): + """Test that DataFrame mode still returns DataFrame (unchanged behavior).""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Check if pandas is available + try: + import pandas as pd + except ImportError: + pytest.skip("pandas not available, skipping DataFrame test") + + # Get assets in DataFrame mode + assets_df = kili_client.assets.list( + project_id=project_id, first=5, as_generator=False, format="pandas" + ) + + # Verify we get a DataFrame + assert isinstance( + assets_df, pd.DataFrame + ), "assets.list() with format='pandas' should return DataFrame" + + # Verify DataFrame has expected structure + if not assets_df.empty: + # DataFrame should have 'id' column + assert "id" in assets_df.columns, "DataFrame should have 'id' column" + + +@pytest.mark.integration() +def test_asset_view_filtering(kili_client): + """Test that AssetView objects work correctly with filtering.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get all assets + all_assets = kili_client.assets.list(project_id=project_id, first=10, as_generator=False) + + if not all_assets: + pytest.skip(f"No assets available in project {project_id}") + + # Get the first 
asset's ID + first_asset_id = all_assets[0].id + + # Query for specific asset by ID + filtered_assets = kili_client.assets.list( + project_id=project_id, asset_id=first_asset_id, as_generator=False + ) + + # Verify we got results + assert len(filtered_assets) > 0, "Should get at least one asset with specific asset_id" + + # Verify each result is an AssetView + for asset in filtered_assets: + assert_is_view(asset, AssetView) + + # Verify it's the correct asset + assert asset.id == first_asset_id, "Filtered asset should have the requested ID" + + +@pytest.mark.integration() +def test_asset_view_empty_results(kili_client): + """Test that empty results are handled correctly.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Query with a filter that should return no results + # Using a non-existent asset ID + empty_assets = kili_client.assets.list( + project_id=project_id, asset_id="non-existent-asset-id-12345", as_generator=False + ) + + # Verify we get an empty list + assert isinstance(empty_assets, list), "Should return a list even when no results" + assert len(empty_assets) == 0, "Should return empty list for non-existent asset" + + +@pytest.mark.integration() +def test_asset_view_with_fields_parameter(kili_client): + """Test that AssetView works correctly with custom fields parameter.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Query with specific fields + assets = kili_client.assets.list( + project_id=project_id, first=1, fields=["id", "externalId", "content"], as_generator=False + ) + + if not assets: + pytest.skip(f"No 
assets available in project {project_id}") + + asset = assets[0] + + # Verify it's still an AssetView + assert_is_view(asset, AssetView) + + # Verify requested fields are accessible + assert_view_property_access(asset, "id") + assert_view_property_access(asset, "external_id") + assert_view_property_access(asset, "content") diff --git a/tests_v2/test_connections_view.py b/tests_v2/test_connections_view.py new file mode 100644 index 000000000..24ddeef4d --- /dev/null +++ b/tests_v2/test_connections_view.py @@ -0,0 +1,356 @@ +"""Integration tests for ConnectionView objects returned by the connections namespace. + +This test file validates that the connections.list() method correctly returns +ConnectionView objects instead of dictionaries, and that these objects provide +proper property access and backward compatibility. + +Test Strategy: + - Verify list() returns ConnectionView objects in all modes (list, generator) + - Test ConnectionView property access for common properties + - Validate backward compatibility with dictionary interface via to_dict() + - Test filtering by project_id and cloud_storage_integration_id + - Verify computed properties (folder_count, display_name) + - Ensure mutation methods still return dicts (unchanged) +""" + +import pytest + +from kili.domain_v2.connection import ConnectionView +from tests_v2 import ( + assert_is_view, + assert_view_has_dict_compatibility, + assert_view_property_access, +) + + +@pytest.mark.integration() +def test_list_returns_connection_views(kili_client): + """Test that connections.list() in list mode returns ConnectionView objects.""" + # Get all projects to find one with connections + projects = kili_client.projects.list(first=10, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing connections") + + # Try to get connections for each project until we find some + connections = [] + for project in projects: + connections = kili_client.connections.list( + project_id=project.id, 
first=5, as_generator=False + ) + if connections: + break + + # Skip if no connections found + if not connections: + pytest.skip("No connections available for testing") + + # Verify we get a list + assert isinstance( + connections, list + ), "connections.list() with as_generator=False should return a list" + + # Verify each item is a ConnectionView + for connection in connections: + assert_is_view(connection, ConnectionView) + + # Verify we can access basic properties + assert hasattr(connection, "id") + assert hasattr(connection, "project_id") + assert hasattr(connection, "number_of_assets") + assert hasattr(connection, "selected_folders") + + +@pytest.mark.integration() +def test_list_generator_returns_connection_views(kili_client): + """Test that connections.list() in generator mode returns ConnectionView objects.""" + # Get all projects to find one with connections + projects = kili_client.projects.list(first=10, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing connections") + + # Try to get connections for each project until we find some + project_id_with_connections = None + for project in projects: + test_connections = kili_client.connections.list( + project_id=project.id, first=1, as_generator=False + ) + if test_connections: + project_id_with_connections = project.id + break + + if not project_id_with_connections: + pytest.skip("No connections available for testing") + + # Get connections in generator mode + connections_gen = kili_client.connections.list( + project_id=project_id_with_connections, first=5, as_generator=True + ) + + # Take first 5 items from generator (or fewer if less available) + connections_from_gen = [] + for i, connection in enumerate(connections_gen): + if i >= 5: + break + connections_from_gen.append(connection) + + # Skip if no connections + if not connections_from_gen: + pytest.skip("No connections available for testing") + + # Verify each yielded item is a ConnectionView + for connection in 
connections_from_gen: + assert_is_view(connection, ConnectionView) + + # Verify we can access basic properties + assert hasattr(connection, "id") + assert hasattr(connection, "project_id") + assert hasattr(connection, "number_of_assets") + + +@pytest.mark.integration() +def test_connection_view_properties(kili_client): + """Test that ConnectionView provides access to all expected properties.""" + # Get all projects to find one with connections + projects = kili_client.projects.list(first=10, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing connections") + + # Try to get connections for each project until we find some + connection = None + for project in projects: + connections = kili_client.connections.list( + project_id=project.id, first=1, as_generator=False + ) + if connections: + connection = connections[0] + break + + if not connection: + pytest.skip("No connections available for testing") + + # Verify ConnectionView type + assert_is_view(connection, ConnectionView) + + # Test core properties exist and are accessible + assert_view_property_access(connection, "id") + assert_view_property_access(connection, "project_id") + assert_view_property_access(connection, "number_of_assets") + assert_view_property_access(connection, "selected_folders") + + # Test that id is not empty + assert connection.id, "Connection id should not be empty" + + # Test that project_id is not empty + assert connection.project_id, "Connection project_id should not be empty" + + # Test computed properties + assert_view_property_access(connection, "folder_count") + assert_view_property_access(connection, "display_name") + + # Test that folder_count is the length of selected_folders + assert connection.folder_count == len( + connection.selected_folders + ), "folder_count should equal the number of selected_folders" + + # Test that number_of_assets is non-negative + assert connection.number_of_assets >= 0, "number_of_assets should be non-negative" + + # 
Test that selected_folders is a list + assert isinstance(connection.selected_folders, list), "selected_folders should be a list" + + # Test optional properties + assert_view_property_access(connection, "last_checked") + + # Test display_name returns id + assert connection.display_name == connection.id, "display_name should return the connection id" + + +@pytest.mark.integration() +def test_connection_view_dict_compatibility(kili_client): + """Test that ConnectionView maintains backward compatibility via to_dict().""" + # Get all projects to find one with connections + projects = kili_client.projects.list(first=10, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing connections") + + # Try to get connections for each project until we find some + connection = None + for project in projects: + connections = kili_client.connections.list( + project_id=project.id, first=1, as_generator=False + ) + if connections: + connection = connections[0] + break + + if not connection: + pytest.skip("No connections available for testing") + + # Verify ConnectionView type + assert_is_view(connection, ConnectionView) + + # Test dictionary compatibility + assert_view_has_dict_compatibility(connection) + + # Get dictionary representation + connection_dict = connection.to_dict() + + # Verify it's a dictionary + assert isinstance(connection_dict, dict), "to_dict() should return a dictionary" + + # Verify dictionary has expected keys + assert "id" in connection_dict, "Dictionary should have 'id' key" + + # Verify dictionary values match property values + if "projectId" in connection_dict: + assert ( + connection_dict["projectId"] == connection.project_id + ), "Dictionary projectId should match property" + + if "numberOfAssets" in connection_dict: + assert ( + connection_dict["numberOfAssets"] == connection.number_of_assets + ), "Dictionary numberOfAssets should match property" + + if "selectedFolders" in connection_dict: + assert ( + 
connection_dict["selectedFolders"] == connection.selected_folders + ), "Dictionary selectedFolders should match property" + + # Verify to_dict() returns the same reference (zero-copy) + assert ( + connection_dict is connection._data + ), "to_dict() should return the same reference as _data" + + +@pytest.mark.integration() +def test_connection_view_filtering(kili_client): + """Test that ConnectionView objects work correctly with filtering.""" + # Get all projects to find one with connections + projects = kili_client.projects.list(first=10, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing connections") + + # Try to get connections for each project + project_id_with_connections = None + for project in projects: + connections = kili_client.connections.list( + project_id=project.id, first=1, as_generator=False + ) + if connections: + project_id_with_connections = project.id + break + + if not project_id_with_connections: + pytest.skip("No connections available for testing") + + # Test filtering by project_id + connections_by_project = kili_client.connections.list( + project_id=project_id_with_connections, first=10, as_generator=False + ) + + # Verify results are ConnectionView objects + assert len(connections_by_project) > 0, "Should find connections for the project" + for connection in connections_by_project: + assert_is_view(connection, ConnectionView) + # All should have the same project_id + assert ( + connection.project_id == project_id_with_connections + ), "Filtered connections should belong to the specified project" + + +@pytest.mark.integration() +def test_connection_view_empty_results(kili_client): + """Test that empty results are handled correctly.""" + # Get all projects to get a valid project_id + projects = kili_client.projects.list(first=1, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing") + + project_id = projects[0].id + + # Query with a project filter - may or may not 
have connections + # This tests that empty results return an empty list + connections = kili_client.connections.list(project_id=project_id, as_generator=False) + + # Verify we get a list (even if empty) + assert isinstance(connections, list), "Should return a list even when no results" + + +@pytest.mark.integration() +def test_connection_view_with_fields_parameter(kili_client): + """Test that ConnectionView works correctly with custom fields parameter.""" + # Get all projects to find one with connections + projects = kili_client.projects.list(first=10, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing connections") + + # Try to get connections for each project until we find some + connection = None + for project in projects: + connections = kili_client.connections.list( + project_id=project.id, + first=1, + fields=["id", "projectId", "numberOfAssets", "selectedFolders", "lastChecked"], + as_generator=False, + ) + if connections: + connection = connections[0] + break + + if not connection: + pytest.skip("No connections available for testing") + + # Verify it's still a ConnectionView + assert_is_view(connection, ConnectionView) + + # Verify requested fields are accessible + assert_view_property_access(connection, "id") + assert_view_property_access(connection, "project_id") + assert_view_property_access(connection, "number_of_assets") + assert_view_property_access(connection, "selected_folders") + assert_view_property_access(connection, "last_checked") + + +@pytest.mark.integration() +def test_connection_view_folder_count(kili_client): + """Test that folder_count property works correctly.""" + # Get all projects to find one with connections + projects = kili_client.projects.list(first=10, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing connections") + + # Try to get connections for each project until we find some + connection = None + for project in projects: + connections = 
kili_client.connections.list( + project_id=project.id, first=1, as_generator=False + ) + if connections: + connection = connections[0] + break + + if not connection: + pytest.skip("No connections available for testing") + + # Verify ConnectionView type + assert_is_view(connection, ConnectionView) + + # Test folder_count property + assert_view_property_access(connection, "folder_count") + + # Verify folder_count matches the length of selected_folders + assert connection.folder_count == len( + connection.selected_folders + ), "folder_count should equal the number of selected_folders" + + # Verify folder_count is non-negative + assert connection.folder_count >= 0, "folder_count should be non-negative" diff --git a/tests_v2/test_domain_v2_contracts.ipynb b/tests_v2/test_domain_v2_contracts.ipynb new file mode 100644 index 000000000..56f450f83 --- /dev/null +++ b/tests_v2/test_domain_v2_contracts.ipynb @@ -0,0 +1,519 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Testing Domain V2 TypedDict Contracts\n", + "\n", + "This notebook demonstrates the new TypedDict-based domain architecture in the Kili Python SDK.\n", + "\n", + "## Features Tested:\n", + "- TypedDict contracts with validation\n", + "- View wrappers for ergonomic access\n", + "- DataFrame adapters\n", + "- Domain helper functions\n", + "\n", + "## Setup\n", + "\n", + "Using the local test API:\n", + "- API_KEY: ``\n", + "- ENDPOINT: `http://localhost:4001/api/label/v2/graphql`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Import the new domain_v2 contracts and utilities\n", + "import json\n", + "from typing import List\n", + "\n", + "# For testing with the API\n", + "from kili.client import Kili\n", + "from kili.domain_v2 import (\n", + " AssetContract,\n", + " AssetView,\n", + " LabelView,\n", + " ProjectView,\n", + " UserContract,\n", + " UserView,\n", + " validate_asset,\n", + " 
validate_label,\n", + " validate_project,\n", + " validate_user,\n", + ")\n", + "from kili.domain_v2.adapters import ContractValidator, DataFrameAdapter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configure API connection\n", + "API_KEY = \"\"\n", + "ENDPOINT = \"http://localhost:4001/api/label/v2/graphql\"\n", + "\n", + "# Initialize Kili client\n", + "kili = Kili(api_key=API_KEY, api_endpoint=ENDPOINT)\n", + "\n", + "print(\"โœ… Connected to Kili API\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Asset Contracts and Validation\n", + "\n", + "Test the AssetContract TypedDict and validation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example: Create a sample asset contract\n", + "sample_asset: AssetContract = {\n", + " \"id\": \"asset-test-123\",\n", + " \"externalId\": \"test-image-001\",\n", + " \"content\": \"https://example.com/image.jpg\",\n", + " \"status\": \"TODO\",\n", + " \"labels\": [],\n", + " \"createdAt\": \"2025-01-15T10:00:00Z\",\n", + " \"updatedAt\": \"2025-01-15T10:00:00Z\",\n", + "}\n", + "\n", + "# Validate the asset contract\n", + "validated_asset = validate_asset(sample_asset)\n", + "print(\"โœ… Asset contract validated successfully\")\n", + "print(json.dumps(validated_asset, indent=2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create an AssetView for ergonomic access\n", + "asset_view = AssetView(validated_asset)\n", + "\n", + "print(f\"Display Name: {asset_view.display_name}\")\n", + "print(f\"Has Labels: {asset_view.has_labels}\")\n", + "print(f\"Is Labeled: {asset_view.is_labeled}\")\n", + "print(f\"Label Count: {asset_view.label_count}\")\n", + "\n", + "# Convert back to dict (returns reference, no copy)\n", + "asset_dict = asset_view.to_dict()\n", + "print(f\"\\nOriginal dict ID 
matches: {asset_dict['id'] == sample_asset['id']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Fetch Real Assets from API\n", + "\n", + "Query assets from the test API and convert to TypedDict contracts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fetch projects to get a project ID\n", + "try:\n", + " projects = kili.projects(first=1)\n", + " if projects:\n", + " project_id = projects[0][\"id\"]\n", + " print(f\"โœ… Found project: {project_id}\")\n", + " else:\n", + " print(\"โš ๏ธ No projects found. Using sample data.\")\n", + " project_id = None\n", + "except Exception as e:\n", + " print(f\"โš ๏ธ Could not fetch projects: {e}\")\n", + " print(\"Using sample data instead.\")\n", + " project_id = None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fetch assets if we have a project\n", + "if project_id:\n", + " try:\n", + " assets_data = kili.assets(\n", + " project_id=project_id,\n", + " first=5,\n", + " fields=[\n", + " \"id\",\n", + " \"externalId\",\n", + " \"content\",\n", + " \"status\",\n", + " \"createdAt\",\n", + " \"updatedAt\",\n", + " \"labels.id\",\n", + " \"labels.author.email\",\n", + " ],\n", + " )\n", + "\n", + " print(f\"โœ… Fetched {len(assets_data)} assets from API\")\n", + "\n", + " # Validate assets as AssetContracts\n", + " validated_assets: List[AssetContract] = []\n", + " for asset in assets_data:\n", + " try:\n", + " validated = validate_asset(asset) # type: ignore\n", + " validated_assets.append(validated)\n", + " except Exception as e:\n", + " print(f\"โš ๏ธ Validation failed for asset {asset.get('id')}: {e}\")\n", + "\n", + " print(f\"โœ… Validated {len(validated_assets)} assets\")\n", + "\n", + " # Display first asset\n", + " if validated_assets:\n", + " first_asset = AssetView(validated_assets[0])\n", + " print(\"\\nFirst Asset:\")\n", + " print(f\" 
ID: {first_asset.id}\")\n", + " print(f\" External ID: {first_asset.external_id}\")\n", + " print(f\" Display Name: {first_asset.display_name}\")\n", + " print(f\" Status: {first_asset.status}\")\n", + " print(f\" Label Count: {first_asset.label_count}\")\n", + "\n", + " except Exception as e:\n", + " print(f\"โš ๏ธ Could not fetch assets: {e}\")\n", + " validated_assets = []\n", + "else:\n", + " validated_assets = []" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. DataFrame Adapters\n", + "\n", + "Test converting contracts to pandas DataFrames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use real assets if available, otherwise create sample data\n", + "if validated_assets:\n", + " test_assets = validated_assets\n", + "else:\n", + " # Create sample data for testing\n", + " test_assets = [\n", + " validate_asset(\n", + " {\n", + " \"id\": f\"asset-{i}\",\n", + " \"externalId\": f\"test-{i:03d}\",\n", + " \"content\": f\"https://example.com/image-{i}.jpg\",\n", + " \"status\": \"TODO\" if i % 2 == 0 else \"ONGOING\",\n", + " \"labels\": [],\n", + " \"createdAt\": \"2025-01-15T10:00:00Z\",\n", + " }\n", + " )\n", + " for i in range(5)\n", + " ]\n", + "\n", + "# Convert to DataFrame\n", + "adapter = DataFrameAdapter()\n", + "df = adapter.to_dataframe(test_assets)\n", + "\n", + "print(\"โœ… Converted assets to DataFrame\")\n", + "print(f\"\\nDataFrame shape: {df.shape}\")\n", + "print(f\"\\nColumns: {list(df.columns)}\")\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Convert back from DataFrame to contracts\n", + "reconstructed_assets = adapter.from_dataframe(df)\n", + "\n", + "print(f\"โœ… Converted DataFrame back to {len(reconstructed_assets)} AssetContracts\")\n", + "print(\"\\nFirst reconstructed asset:\")\n", + "print(json.dumps(reconstructed_assets[0], indent=2, 
default=str))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Label Contracts and Helper Functions\n", + "\n", + "Test label contracts and domain helper functions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from kili.domain_v2.label import filter_labels_by_type, sort_labels_by_created_at\n", + "\n", + "# Create sample labels\n", + "sample_labels = [\n", + " validate_label(\n", + " {\n", + " \"id\": \"label-1\",\n", + " \"labelType\": \"DEFAULT\",\n", + " \"createdAt\": \"2025-01-15T10:00:00Z\",\n", + " \"author\": {\"email\": \"user1@example.com\"},\n", + " }\n", + " ),\n", + " validate_label(\n", + " {\n", + " \"id\": \"label-2\",\n", + " \"labelType\": \"REVIEW\",\n", + " \"createdAt\": \"2025-01-15T09:00:00Z\",\n", + " \"author\": {\"email\": \"user2@example.com\"},\n", + " }\n", + " ),\n", + " validate_label(\n", + " {\n", + " \"id\": \"label-3\",\n", + " \"labelType\": \"DEFAULT\",\n", + " \"createdAt\": \"2025-01-15T11:00:00Z\",\n", + " \"author\": {\"email\": \"user3@example.com\"},\n", + " }\n", + " ),\n", + "]\n", + "\n", + "print(\"Created 3 sample labels\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sort labels by creation date\n", + "sorted_labels = sort_labels_by_created_at(sample_labels)\n", + "\n", + "print(\"Labels sorted by creation date:\")\n", + "for label in sorted_labels:\n", + " view = LabelView(label)\n", + " print(f\" {view.id}: {view.created_at} ({view.label_type})\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Filter labels by type\n", + "default_labels = filter_labels_by_type(sample_labels, \"DEFAULT\")\n", + "\n", + "print(f\"\\nFiltered to DEFAULT labels: {len(default_labels)} found\")\n", + "for label in default_labels:\n", + " view = LabelView(label)\n", + " print(f\" {view.id}: 
{view.label_type}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Project and User Contracts\n", + "\n", + "Test project and user contracts with real API data if available." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fetch a project from the API\n", + "if project_id:\n", + " try:\n", + " project_data = kili.projects(\n", + " project_id=project_id,\n", + " fields=[\n", + " \"id\",\n", + " \"title\",\n", + " \"description\",\n", + " \"inputType\",\n", + " \"createdAt\",\n", + " \"updatedAt\",\n", + " ],\n", + " )\n", + "\n", + " if project_data:\n", + " project_contract = validate_project(project_data[0]) # type: ignore\n", + " project_view = ProjectView(project_contract)\n", + "\n", + " print(\"โœ… Fetched and validated project:\")\n", + " print(f\" Title: {project_view.title}\")\n", + " print(f\" Description: {project_view.description or 'N/A'}\")\n", + " print(f\" Input Type: {project_view.input_type}\")\n", + " print(f\" Display Name: {project_view.display_name}\")\n", + " except Exception as e:\n", + " print(f\"โš ๏ธ Could not fetch project: {e}\")\n", + "else:\n", + " print(\"โš ๏ธ No project ID available\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a sample user contract\n", + "sample_user: UserContract = {\n", + " \"id\": \"user-123\",\n", + " \"email\": \"test@example.com\",\n", + " \"firstname\": \"Test\",\n", + " \"lastname\": \"User\",\n", + " \"activated\": True,\n", + "}\n", + "\n", + "validated_user = validate_user(sample_user)\n", + "user_view = UserView(validated_user)\n", + "\n", + "print(\"โœ… Created sample user:\")\n", + "print(f\" Email: {user_view.email}\")\n", + "print(f\" Full Name: {user_view.full_name}\")\n", + "print(f\" Display Name: {user_view.display_name}\")\n", + "print(f\" Activated: {user_view.is_activated}\")" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## 6. Batch Validation with ContractValidator\n", + "\n", + "Test batch validation with error collection." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a mix of valid and invalid asset data\n", + "mixed_data = [\n", + " {\"id\": \"asset-1\", \"externalId\": \"test-1\", \"content\": \"url1\"}, # Valid\n", + " {\"id\": 123, \"externalId\": \"test-2\"}, # Invalid: id should be string\n", + " {\"id\": \"asset-3\", \"externalId\": \"test-3\", \"content\": \"url3\"}, # Valid\n", + " {\"externalId\": \"test-4\"}, # Invalid: missing id\n", + "]\n", + "\n", + "# Batch validate with error collection\n", + "validator = ContractValidator()\n", + "results = validator.validate_batch(mixed_data, AssetContract)\n", + "\n", + "print(\"Validation results:\")\n", + "print(f\" Valid: {len(results['valid'])}\")\n", + "print(f\" Invalid: {len(results['invalid'])}\")\n", + "print(\"\\nErrors:\")\n", + "for error in results[\"errors\"]:\n", + " print(f\" Index {error['index']}: {error['error']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Performance Test\n", + "\n", + "Test performance with a larger dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "# Create a large dataset\n", + "large_dataset = [\n", + " {\n", + " \"id\": f\"asset-{i}\",\n", + " \"externalId\": f\"test-{i:05d}\",\n", + " \"content\": f\"https://example.com/image-{i}.jpg\",\n", + " \"status\": \"TODO\",\n", + " }\n", + " for i in range(1000)\n", + "]\n", + "\n", + "# Time validation\n", + "start = time.time()\n", + "validated = [validate_asset(asset) for asset in large_dataset] # type: ignore\n", + "validation_time = time.time() - start\n", + "\n", + "# Time DataFrame conversion\n", + "start = time.time()\n", + "df = adapter.to_dataframe(validated)\n", + "df_time = time.time() - start\n", + "\n", + "print(\"Performance Test (1000 assets):\")\n", + "print(f\" Validation: {validation_time:.3f}s\")\n", + "print(f\" DataFrame conversion: {df_time:.3f}s\")\n", + "print(f\" Total: {validation_time + df_time:.3f}s\")\n", + "print(f\"\\nDataFrame shape: {df.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrated:\n", + "\n", + "1. โœ… **TypedDict Contracts**: Type-safe dictionary schemas with `total=False`\n", + "2. โœ… **Runtime Validation**: Using `typeguard.check_type` for contract validation\n", + "3. โœ… **View Wrappers**: Ergonomic property access with frozen dataclasses\n", + "4. โœ… **DataFrame Adapters**: Seamless pandas integration\n", + "5. โœ… **Domain Helpers**: Utility functions for sorting and filtering\n", + "6. โœ… **Batch Validation**: Error collection for multiple items\n", + "7. 
โœ… **Performance**: Fast validation and conversion (tested with 1000 items)\n", + "\n", + "The new domain_v2 architecture provides:\n", + "- Type safety without runtime overhead (TypedDict)\n", + "- Ergonomic API with View wrappers\n", + "- Easy integration with pandas\n", + "- Backward compatibility with existing dict-based code" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests_v2/test_infrastructure.py b/tests_v2/test_infrastructure.py new file mode 100644 index 000000000..9587ba04b --- /dev/null +++ b/tests_v2/test_infrastructure.py @@ -0,0 +1,78 @@ +"""Test infrastructure validation for domain_v2 integration tests. + +This test file validates that the testing infrastructure is set up correctly +and that basic patterns work as expected. +""" + + +from tests_v2 import assert_is_view, assert_view_has_dict_compatibility + + +def test_fixtures_available(kili_client, api_key, api_endpoint): + """Test that all fixtures are available and configured correctly.""" + # Verify fixtures are provided + assert kili_client is not None, "kili_client fixture should be available" + assert api_key is not None, "api_key fixture should be available" + assert api_endpoint is not None, "api_endpoint fixture should be available" + + # Verify API key format + assert isinstance(api_key, str), "API key should be a string" + assert len(api_key) > 0, "API key should not be empty" + + # Verify endpoint format + assert isinstance(api_endpoint, str), "Endpoint should be a string" + assert api_endpoint.startswith("http"), "Endpoint should be a URL" + + +def test_kili_client_structure(kili_client): + """Test that kili_client has expected domain namespaces.""" + # Verify client has domain namespaces + assert hasattr(kili_client, "assets"), "Client should have assets namespace" + assert hasattr(kili_client, "labels"), "Client should have labels namespace" + assert hasattr(kili_client, "projects"), "Client 
should have projects namespace" + assert hasattr(kili_client, "users"), "Client should have users namespace" + + +def test_utility_functions_importable(): + """Test that utility functions are importable and callable.""" + # These imports should not raise + from tests_v2 import ( # pylint: disable=import-outside-toplevel + assert_is_view, + assert_view_has_dict_compatibility, + assert_view_property_access, + ) + + # Verify they are callable + assert callable(assert_is_view) + assert callable(assert_view_has_dict_compatibility) + assert callable(assert_view_property_access) + + +def test_view_utilities_with_sample_data(): + """Test that View utilities work with sample data.""" + from kili.domain_v2.asset import ( # pylint: disable=import-outside-toplevel + AssetContract, + AssetView, + ) + + # Create sample data + sample_data: AssetContract = { + "id": "test-id-123", + "externalId": "test-external-id", + "content": "https://example.com/test.jpg", + "labels": [], + "isHoneypot": False, + "skipped": False, + } + + # Create View + view = AssetView(sample_data) + + # Test utilities + assert_is_view(view, AssetView) + assert_view_has_dict_compatibility(view) + + # Verify basic properties + assert view.id == "test-id-123" + assert view.external_id == "test-external-id" + assert view.display_name == "test-external-id" diff --git a/tests_v2/test_integrations_view.py b/tests_v2/test_integrations_view.py new file mode 100644 index 000000000..25f495216 --- /dev/null +++ b/tests_v2/test_integrations_view.py @@ -0,0 +1,374 @@ +"""Integration tests for IntegrationView objects returned by the integrations namespace. + +This test file validates that the integrations.list() method correctly returns +IntegrationView objects instead of dictionaries, and that these objects provide +proper property access and backward compatibility. 
+ +Test Strategy: + - Verify list() returns IntegrationView objects in all modes (list, generator) + - Test IntegrationView property access for common properties + - Validate backward compatibility with dictionary interface via to_dict() + - Test filtering by platform and status + - Verify computed properties (is_connected, is_checking, has_error, is_active) + - Ensure mutation methods still return dicts (unchanged) +""" + +import pytest + +from kili.domain_v2.integration import IntegrationView +from tests_v2 import ( + assert_is_view, + assert_view_has_dict_compatibility, + assert_view_property_access, +) + + +@pytest.mark.integration() +def test_list_returns_integration_views(kili_client): + """Test that integrations.list() in list mode returns IntegrationView objects.""" + # Get integrations in list mode + integrations = kili_client.integrations.list(first=5, as_generator=False) + + # Verify we get a list + assert isinstance( + integrations, list + ), "integrations.list() with as_generator=False should return a list" + + # Skip if no integrations + if not integrations: + pytest.skip("No integrations available for testing") + + # Verify each item is an IntegrationView + for integration in integrations: + assert_is_view(integration, IntegrationView) + + # Verify we can access basic properties + assert hasattr(integration, "id") + assert hasattr(integration, "name") + assert hasattr(integration, "platform") + assert hasattr(integration, "status") + assert hasattr(integration, "organization_id") + + +@pytest.mark.integration() +def test_list_generator_returns_integration_views(kili_client): + """Test that integrations.list() in generator mode returns IntegrationView objects.""" + # Get integrations in generator mode + integrations_gen = kili_client.integrations.list(first=5, as_generator=True) + + # Take first 5 items from generator (or fewer if less available) + integrations_from_gen = [] + for i, integration in enumerate(integrations_gen): + if i >= 5: + break + 
integrations_from_gen.append(integration) + + # Skip if no integrations + if not integrations_from_gen: + pytest.skip("No integrations available for testing") + + # Verify each yielded item is an IntegrationView + for integration in integrations_from_gen: + assert_is_view(integration, IntegrationView) + + # Verify we can access basic properties + assert hasattr(integration, "id") + assert hasattr(integration, "name") + assert hasattr(integration, "platform") + assert hasattr(integration, "status") + + +@pytest.mark.integration() +def test_integration_view_properties(kili_client): + """Test that IntegrationView provides access to all expected properties.""" + # Get first integration + integrations = kili_client.integrations.list(first=1, as_generator=False) + + if not integrations: + pytest.skip("No integrations available for testing") + + integration = integrations[0] + + # Verify IntegrationView type + assert_is_view(integration, IntegrationView) + + # Test core properties exist and are accessible + assert_view_property_access(integration, "id") + assert_view_property_access(integration, "name") + assert_view_property_access(integration, "platform") + assert_view_property_access(integration, "status") + assert_view_property_access(integration, "organization_id") + + # Test that id is not empty + assert integration.id, "Integration id should not be empty" + + # Test that name is not empty + assert integration.name, "Integration name should not be empty" + + # Test that platform is valid + assert integration.platform in [ + "AWS", + "AZURE", + "GCP", + "S3", + None, + ], "platform should be one of AWS, AZURE, GCP, S3, or None" + + # Test that status is valid + assert integration.status in [ + "CONNECTED", + "CHECKING", + "ERROR", + None, + ], "status should be one of CONNECTED, CHECKING, ERROR, or None" + + # Test computed properties + assert_view_property_access(integration, "is_connected") + assert_view_property_access(integration, "is_checking") + 
assert_view_property_access(integration, "has_error") + assert_view_property_access(integration, "is_active") + assert_view_property_access(integration, "display_name") + + # Test that exactly one status property is True + status_properties = [integration.is_connected, integration.is_checking, integration.has_error] + # At most one should be True (or none if status is None) + assert ( + sum(status_properties) <= 1 + ), "At most one of is_connected, is_checking, has_error should be True" + + # Test that is_active is an alias for is_connected + assert ( + integration.is_active == integration.is_connected + ), "is_active should be an alias for is_connected" + + # Test display_name (should be name or id) + assert integration.display_name, "display_name should not be empty" + if integration.name: + assert integration.display_name == integration.name + else: + assert integration.display_name == integration.id + + +@pytest.mark.integration() +def test_integration_view_dict_compatibility(kili_client): + """Test that IntegrationView maintains backward compatibility via to_dict().""" + # Get first integration + integrations = kili_client.integrations.list(first=1, as_generator=False) + + if not integrations: + pytest.skip("No integrations available for testing") + + integration = integrations[0] + + # Verify IntegrationView type + assert_is_view(integration, IntegrationView) + + # Test dictionary compatibility + assert_view_has_dict_compatibility(integration) + + # Get dictionary representation + integration_dict = integration.to_dict() + + # Verify it's a dictionary + assert isinstance(integration_dict, dict), "to_dict() should return a dictionary" + + # Verify dictionary has expected keys + assert "id" in integration_dict, "Dictionary should have 'id' key" + + # Verify dictionary values match property values + if "name" in integration_dict: + assert integration_dict["name"] == integration.name, "Dictionary name should match property" + + if "platform" in integration_dict: 
+ assert ( + integration_dict["platform"] == integration.platform + ), "Dictionary platform should match property" + + if "status" in integration_dict: + assert ( + integration_dict["status"] == integration.status + ), "Dictionary status should match property" + + if "organizationId" in integration_dict: + assert ( + integration_dict["organizationId"] == integration.organization_id + ), "Dictionary organizationId should match property" + + # Verify to_dict() returns the same reference (zero-copy) + assert ( + integration_dict is integration._data + ), "to_dict() should return the same reference as _data" + + +@pytest.mark.integration() +def test_integration_view_filtering(kili_client): + """Test that IntegrationView objects work correctly with filtering.""" + # Get all integrations + all_integrations = kili_client.integrations.list(first=10, as_generator=False) + + if not all_integrations: + pytest.skip("No integrations available for testing") + + # Get the first integration's platform and status for filtering + first_integration = all_integrations[0] + + # Test filtering by platform (if platform is set) + if first_integration.platform: + platform_integrations = kili_client.integrations.list( + platform=first_integration.platform, first=10, as_generator=False + ) + + # Verify results are IntegrationView objects + for integration in platform_integrations: + assert_is_view(integration, IntegrationView) + # All should have the same platform + assert ( + integration.platform == first_integration.platform + ), "Filtered integrations should have the specified platform" + + # Test filtering by status (if status is set) + if first_integration.status: + status_integrations = kili_client.integrations.list( + status=first_integration.status, first=10, as_generator=False + ) + + # Verify results are IntegrationView objects + for integration in status_integrations: + assert_is_view(integration, IntegrationView) + # All should have the same status + assert ( + integration.status 
== first_integration.status + ), "Filtered integrations should have the specified status" + + +@pytest.mark.integration() +def test_integration_view_empty_results(kili_client): + """Test that empty results are handled correctly.""" + # Query all integrations - may or may not be empty + # This tests that empty results return an empty list + integrations = kili_client.integrations.list(as_generator=False) + + # Verify we get a list (even if empty) + assert isinstance(integrations, list), "Should return a list even when no results" + + +@pytest.mark.integration() +def test_integration_view_status_properties(kili_client): + """Test that status-related computed properties work correctly.""" + # Get all integrations + integrations = kili_client.integrations.list(first=10, as_generator=False) + + if not integrations: + pytest.skip("No integrations available for testing") + + for integration in integrations: + # Verify IntegrationView type + assert_is_view(integration, IntegrationView) + + # Test status properties based on the actual status + if integration.status == "CONNECTED": + assert integration.is_connected is True + assert integration.is_checking is False + assert integration.has_error is False + elif integration.status == "CHECKING": + assert integration.is_connected is False + assert integration.is_checking is True + assert integration.has_error is False + elif integration.status == "ERROR": + assert integration.is_connected is False + assert integration.is_checking is False + assert integration.has_error is True + else: + # Status is None or unexpected + assert integration.is_connected is False + assert integration.is_checking is False + assert integration.has_error is False + + # Verify is_active is an alias for is_connected + assert integration.is_active == integration.is_connected + + +@pytest.mark.integration() +def test_integration_view_with_fields_parameter(kili_client): + """Test that IntegrationView works correctly with custom fields parameter.""" + # 
Query with specific fields + integrations = kili_client.integrations.list( + first=1, fields=["id", "name", "platform", "status", "organizationId"], as_generator=False + ) + + if not integrations: + pytest.skip("No integrations available for testing") + + integration = integrations[0] + + # Verify it's still an IntegrationView + assert_is_view(integration, IntegrationView) + + # Verify requested fields are accessible + assert_view_property_access(integration, "id") + assert_view_property_access(integration, "name") + assert_view_property_access(integration, "platform") + assert_view_property_access(integration, "status") + assert_view_property_access(integration, "organization_id") + + +@pytest.mark.integration() +def test_integration_count_method(kili_client): + """Test that integrations.count() works correctly and returns an integer.""" + # Count all integrations + total_count = kili_client.integrations.count() + + # Verify result is an integer + assert isinstance(total_count, int), "count() should return an integer" + assert total_count >= 0, "count() should return a non-negative integer" + + # If there are integrations, test filtered counts + if total_count > 0: + # Get first integration to test filtered counts + integrations = kili_client.integrations.list(first=1, as_generator=False) + if integrations: + first_integration = integrations[0] + + # Count by platform (if platform is set) + if first_integration.platform: + platform_count = kili_client.integrations.count(platform=first_integration.platform) + assert isinstance(platform_count, int), "count() should return an integer" + assert platform_count >= 1, "Should count at least the first integration" + assert platform_count <= total_count, "Filtered count should not exceed total count" + + # Count by status (if status is set) + if first_integration.status: + status_count = kili_client.integrations.count(status=first_integration.status) + assert isinstance(status_count, int), "count() should return an 
integer" + assert status_count >= 1, "Should count at least the first integration" + assert status_count <= total_count, "Filtered count should not exceed total count" + + +@pytest.mark.integration() +def test_integration_view_display_name(kili_client): + """Test that display_name property works correctly.""" + # Get all integrations + integrations = kili_client.integrations.list(first=5, as_generator=False) + + if not integrations: + pytest.skip("No integrations available for testing") + + for integration in integrations: + # Verify IntegrationView type + assert_is_view(integration, IntegrationView) + + # Test display_name property + assert_view_property_access(integration, "display_name") + + # Verify display_name logic + if integration.name: + assert ( + integration.display_name == integration.name + ), "display_name should return name when name is set" + else: + assert ( + integration.display_name == integration.id + ), "display_name should return id when name is not set" + + # Verify display_name is never empty + assert integration.display_name, "display_name should not be empty" diff --git a/tests_v2/test_issues_view.py b/tests_v2/test_issues_view.py new file mode 100644 index 000000000..8b3fa3b00 --- /dev/null +++ b/tests_v2/test_issues_view.py @@ -0,0 +1,418 @@ +"""Integration tests for IssueView objects returned by the issues namespace. + +This test file validates that the issues.list() method correctly returns +IssueView objects instead of dictionaries, and that these objects provide +proper property access and backward compatibility. 
+ +Test Strategy: + - Verify list() returns IssueView objects in all modes (list, generator) + - Test IssueView property access for common properties + - Validate backward compatibility with dictionary interface via to_dict() + - Test filtering by asset_id, issue_type, and status + - Verify status check properties (is_open, is_solved, is_cancelled, is_question) + - Ensure mutation methods still return dicts (unchanged) +""" + +import pytest + +from kili.domain_v2.issue import IssueView +from tests_v2 import ( + assert_is_view, + assert_view_has_dict_compatibility, + assert_view_property_access, +) + + +@pytest.mark.integration() +def test_list_returns_issue_views(kili_client): + """Test that issues.list() in list mode returns IssueView objects.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Projects now return ProjectView objects + project_id = projects[0].id + + # Get issues in list mode + issues = kili_client.issues.list(project_id=project_id, first=5, as_generator=False) + + # Verify we get a list + assert isinstance(issues, list), "issues.list() with as_generator=False should return a list" + + # Skip if no issues + if not issues: + pytest.skip(f"No issues available in project {project_id}") + + # Verify each item is an IssueView + for issue in issues: + assert_is_view(issue, IssueView) + + # Verify we can access basic properties + assert hasattr(issue, "id") + assert hasattr(issue, "status") + assert hasattr(issue, "type") + assert hasattr(issue, "asset_id") + + +@pytest.mark.integration() +def test_list_generator_returns_issue_views(kili_client): + """Test that issues.list() in generator mode returns IssueView objects.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # 
Projects now return ProjectView objects + project_id = projects[0].id + + # Get issues in generator mode + issues_gen = kili_client.issues.list(project_id=project_id, first=5, as_generator=True) + + # Take first 5 items from generator (or fewer if less available) + issues_from_gen = [] + for i, issue in enumerate(issues_gen): + if i >= 5: + break + issues_from_gen.append(issue) + + # Skip if no issues + if not issues_from_gen: + pytest.skip(f"No issues available in project {project_id}") + + # Verify each yielded item is an IssueView + for issue in issues_from_gen: + assert_is_view(issue, IssueView) + + # Verify we can access basic properties + assert hasattr(issue, "id") + assert hasattr(issue, "status") + assert hasattr(issue, "type") + + +@pytest.mark.integration() +def test_issue_view_properties(kili_client): + """Test that IssueView provides access to all expected properties.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Projects now return ProjectView objects + project_id = projects[0].id + + # Get first issue + issues = kili_client.issues.list(project_id=project_id, first=1, as_generator=False) + + if not issues: + pytest.skip(f"No issues available in project {project_id}") + + issue = issues[0] + + # Verify IssueView type + assert_is_view(issue, IssueView) + + # Test core properties exist and are accessible + assert_view_property_access(issue, "id") + assert_view_property_access(issue, "status") + assert_view_property_access(issue, "type") + assert_view_property_access(issue, "asset_id") + assert_view_property_access(issue, "created_at") + assert_view_property_access(issue, "display_name") + + # Test that id is not empty + assert issue.id, "Issue id should not be empty" + + # Test display_name (should be id) + assert issue.display_name == issue.id + + # Test optional properties + 
assert_view_property_access(issue, "has_been_seen") + + # Test computed status properties + assert_view_property_access(issue, "is_open") + assert_view_property_access(issue, "is_solved") + assert_view_property_access(issue, "is_cancelled") + assert_view_property_access(issue, "is_question") + + # Verify status is one of the valid values + assert issue.status in ( + "OPEN", + "SOLVED", + "CANCELLED", + ), f"Issue status should be OPEN, SOLVED, or CANCELLED, got {issue.status}" + + # Verify type is one of the valid values + assert issue.type in ( + "ISSUE", + "QUESTION", + ), f"Issue type should be ISSUE or QUESTION, got {issue.type}" + + +@pytest.mark.integration() +def test_issue_view_dict_compatibility(kili_client): + """Test that IssueView maintains backward compatibility via to_dict().""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Projects now return ProjectView objects + project_id = projects[0].id + + # Get first issue + issues = kili_client.issues.list(project_id=project_id, first=1, as_generator=False) + + if not issues: + pytest.skip(f"No issues available in project {project_id}") + + issue = issues[0] + + # Verify IssueView type + assert_is_view(issue, IssueView) + + # Test dictionary compatibility + assert_view_has_dict_compatibility(issue) + + # Get dictionary representation + issue_dict = issue.to_dict() + + # Verify it's a dictionary + assert isinstance(issue_dict, dict), "to_dict() should return a dictionary" + + # Verify dictionary has expected keys + assert "id" in issue_dict, "Dictionary should have 'id' key" + + # Verify dictionary values match property values + if "status" in issue_dict: + assert issue_dict["status"] == issue.status, "Dictionary status should match property" + + if "type" in issue_dict: + assert issue_dict["type"] == issue.type, "Dictionary type should match property" + + if 
"assetId" in issue_dict: + assert issue_dict["assetId"] == issue.asset_id, "Dictionary assetId should match property" + + # Verify to_dict() returns the same reference (zero-copy) + assert issue_dict is issue._data, "to_dict() should return the same reference as _data" + + +@pytest.mark.integration() +def test_issue_view_filtering(kili_client): + """Test that IssueView objects work correctly with filtering.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Projects now return ProjectView objects + project_id = projects[0].id + + # Get all issues + all_issues = kili_client.issues.list(project_id=project_id, first=10, as_generator=False) + + if not all_issues: + pytest.skip(f"No issues available in project {project_id}") + + # Get the first issue's asset_id + first_asset_id = all_issues[0].asset_id + + # Query for issues by asset_id + filtered_issues = kili_client.issues.list( + project_id=project_id, asset_id=first_asset_id, as_generator=False + ) + + # Verify we got results + assert len(filtered_issues) > 0, "Should get at least one issue with specific asset_id" + + # Verify each result is an IssueView + for issue in filtered_issues: + assert_is_view(issue, IssueView) + + # Verify it has the correct asset_id + assert issue.asset_id == first_asset_id, "Filtered issue should have the requested asset_id" + + +@pytest.mark.integration() +def test_issue_view_status_checks(kili_client): + """Test IssueView status check properties (is_open, is_solved, is_cancelled, is_question).""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Projects now return ProjectView objects + project_id = projects[0].id + + # Get issues + issues = kili_client.issues.list(project_id=project_id, first=10, 
as_generator=False) + + if not issues: + pytest.skip(f"No issues available in project {project_id}") + + for issue in issues: + assert_is_view(issue, IssueView) + + # Verify exactly one status flag is true + status_flags = [issue.is_open, issue.is_solved, issue.is_cancelled] + assert sum(status_flags) == 1, "Exactly one status flag should be true" + + # Verify status flag matches actual status + if issue.status == "OPEN": + assert issue.is_open is True + assert issue.is_solved is False + assert issue.is_cancelled is False + elif issue.status == "SOLVED": + assert issue.is_open is False + assert issue.is_solved is True + assert issue.is_cancelled is False + elif issue.status == "CANCELLED": + assert issue.is_open is False + assert issue.is_solved is False + assert issue.is_cancelled is True + + # Verify type flag matches actual type + if issue.type == "QUESTION": + assert issue.is_question is True + else: + assert issue.is_question is False + + +@pytest.mark.integration() +def test_issue_view_empty_results(kili_client): + """Test that empty results are handled correctly.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Projects now return ProjectView objects + project_id = projects[0].id + + # Query with a filter that should return no results + # Using a non-existent asset ID + empty_issues = kili_client.issues.list( + project_id=project_id, asset_id="non-existent-asset-id-12345", as_generator=False + ) + + # Verify we get an empty list + assert isinstance(empty_issues, list), "Should return a list even when no results" + assert len(empty_issues) == 0, "Should return empty list for non-existent asset" + + +@pytest.mark.integration() +def test_mutation_methods_still_return_dicts(kili_client): + """Test that mutation methods (create, solve, cancel, open) still return dicts (unchanged).""" + # Get the first project to test 
with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Projects now return ProjectView objects + project_id = projects[0].id + + # Get an asset to work with + assets = list(kili_client.assets.list(project_id=project_id, first=1, as_generator=False)) + + if not assets: + pytest.skip(f"No assets available in project {project_id}") + + asset = assets[0] + + # Get or create a label for the asset + labels = list( + kili_client.labels.list( + project_id=project_id, asset_id=asset.id, first=1, as_generator=False + ) + ) + + if not labels: + pytest.skip(f"No labels available for asset {asset.id}") + + label = labels[0] + + # Test create() method - should return list of dicts + try: + create_result = kili_client.issues.create( + project_id=project_id, + label_id_array=[label.id], + text_array=["Test issue for integration test"], + ) + + # Verify result is a list + assert isinstance(create_result, list), "create() should return a list" + + if create_result: + # Verify first item is a dict + assert isinstance(create_result[0], dict), "create() should return list of dicts" + assert "id" in create_result[0], "Created issue should have 'id' key" + + # Test solve() method - should return list of dicts + issue_id = create_result[0]["id"] + solve_result = kili_client.issues.solve(issue_ids=[issue_id]) + + # Verify result is a list + assert isinstance(solve_result, list), "solve() should return a list" + assert isinstance(solve_result[0], dict), "solve() should return list of dicts" + + # Test open() method - should return list of dicts + open_result = kili_client.issues.open(issue_ids=[issue_id]) + + # Verify result is a list + assert isinstance(open_result, list), "open() should return a list" + assert isinstance(open_result[0], dict), "open() should return list of dicts" + + # Test cancel() method - should return list of dicts + cancel_result = 
kili_client.issues.cancel(issue_ids=[issue_id]) + + # Verify result is a list + assert isinstance(cancel_result, list), "cancel() should return a list" + assert isinstance(cancel_result[0], dict), "cancel() should return list of dicts" + + except Exception as e: + # If mutations are not allowed in test environment, skip the test + pytest.skip(f"Mutations not allowed or failed: {e}") + + +@pytest.mark.integration() +def test_issue_view_with_fields_parameter(kili_client): + """Test that IssueView works correctly with custom fields parameter.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Projects now return ProjectView objects + project_id = projects[0].id + + # Query with specific fields + issues = kili_client.issues.list( + project_id=project_id, + first=1, + fields=["id", "status", "type", "assetId", "createdAt"], + as_generator=False, + ) + + if not issues: + pytest.skip(f"No issues available in project {project_id}") + + issue = issues[0] + + # Verify it's still an IssueView + assert_is_view(issue, IssueView) + + # Verify requested fields are accessible + assert_view_property_access(issue, "id") + assert_view_property_access(issue, "status") + assert_view_property_access(issue, "type") + assert_view_property_access(issue, "asset_id") + assert_view_property_access(issue, "created_at") diff --git a/tests_v2/test_labels_view.py b/tests_v2/test_labels_view.py new file mode 100644 index 000000000..e6897ae22 --- /dev/null +++ b/tests_v2/test_labels_view.py @@ -0,0 +1,424 @@ +"""Integration tests for LabelView objects returned by the labels namespace. + +This test file validates that the labels.list() method correctly returns +LabelView objects instead of dictionaries, and that these objects provide +proper property access and backward compatibility. 
+ +Test Strategy: + - Verify list() returns LabelView objects in all modes (list, generator) + - Test LabelView property access for common properties + - Validate backward compatibility with dictionary interface via to_dict() + - Ensure ParsedLabel mode remains unchanged + - Test nested namespaces (predictions, inferences) +""" + +import pytest + +from kili.domain_v2.label import LabelView +from kili.utils.labels.parsing import ParsedLabel +from tests_v2 import ( + assert_is_view, + assert_view_has_dict_compatibility, + assert_view_property_access, +) + + +@pytest.mark.integration() +def test_list_returns_label_views(kili_client): + """Test that labels.list() in list mode returns LabelView objects.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get labels in list mode + labels = kili_client.labels.list(project_id=project_id, first=5, as_generator=False) + + # Verify we get a list + assert isinstance(labels, list), "labels.list() with as_generator=False should return a list" + + # Skip if no labels + if not labels: + pytest.skip(f"No labels available in project {project_id}") + + # Verify each item is a LabelView + for label in labels: + assert_is_view(label, LabelView) + + # Verify we can access basic properties + assert hasattr(label, "id") + assert hasattr(label, "label_type") + assert hasattr(label, "author") + + +@pytest.mark.integration() +def test_list_generator_returns_label_views(kili_client): + """Test that labels.list() in generator mode returns LabelView objects.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # 
Get labels in generator mode + labels_gen = kili_client.labels.list(project_id=project_id, first=5, as_generator=True) + + # Take first 5 items from generator (or fewer if less available) + labels_from_gen = [] + for i, label in enumerate(labels_gen): + if i >= 5: + break + labels_from_gen.append(label) + + # Skip if no labels + if not labels_from_gen: + pytest.skip(f"No labels available in project {project_id}") + + # Verify each yielded item is a LabelView + for label in labels_from_gen: + assert_is_view(label, LabelView) + + # Verify we can access basic properties + assert hasattr(label, "id") + assert hasattr(label, "label_type") + + +@pytest.mark.integration() +def test_label_view_properties(kili_client): + """Test that LabelView provides access to all expected properties.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get first label + labels = kili_client.labels.list(project_id=project_id, first=1, as_generator=False) + + if not labels: + pytest.skip(f"No labels available in project {project_id}") + + label = labels[0] + + # Verify LabelView type + assert_is_view(label, LabelView) + + # Test core properties exist and are accessible + assert_view_property_access(label, "id") + assert_view_property_access(label, "label_type") + assert_view_property_access(label, "author") + assert_view_property_access(label, "json_response") + assert_view_property_access(label, "created_at") + + # Test that id is not empty + assert label.id, "Label id should not be empty" + + # Test optional properties + assert_view_property_access(label, "author_email") + assert_view_property_access(label, "author_id") + assert_view_property_access(label, "updated_at") + assert_view_property_access(label, "model_name") + assert_view_property_access(label, 
"seconds_to_label") + assert_view_property_access(label, "is_latest") + assert_view_property_access(label, "consensus_mark") + assert_view_property_access(label, "honeypot_mark") + + # Test computed properties + assert_view_property_access(label, "is_prediction") + assert_view_property_access(label, "is_review") + assert_view_property_access(label, "display_name") + + # Verify json_response is a dictionary + assert isinstance(label.json_response, dict), "json_response property should return a dict" + + # Verify label_type is one of the expected values + if label.label_type: + assert label.label_type in ( + "DEFAULT", + "AUTOSAVE", + "PREDICTION", + "INFERENCE", + "REVIEW", + ), f"Unexpected label_type: {label.label_type}" + + # Verify is_prediction logic + if label.label_type in ("PREDICTION", "INFERENCE"): + assert label.is_prediction, "is_prediction should be True for PREDICTION/INFERENCE labels" + else: + assert not label.is_prediction, "is_prediction should be False for non-PREDICTION labels" + + # Verify is_review logic + if label.label_type == "REVIEW": + assert label.is_review, "is_review should be True for REVIEW labels" + else: + assert not label.is_review, "is_review should be False for non-REVIEW labels" + + +@pytest.mark.integration() +def test_label_view_dict_compatibility(kili_client): + """Test that LabelView maintains backward compatibility via to_dict().""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get first label + labels = kili_client.labels.list(project_id=project_id, first=1, as_generator=False) + + if not labels: + pytest.skip(f"No labels available in project {project_id}") + + label = labels[0] + + # Verify LabelView type + assert_is_view(label, LabelView) + + # Test dictionary compatibility + 
assert_view_has_dict_compatibility(label) + + # Get dictionary representation + label_dict = label.to_dict() + + # Verify it's a dictionary + assert isinstance(label_dict, dict), "to_dict() should return a dictionary" + + # Verify dictionary has expected keys + assert "id" in label_dict, "Dictionary should have 'id' key" + + # Verify dictionary values match property values + if "labelType" in label_dict: + assert ( + label_dict["labelType"] == label.label_type + ), "Dictionary labelType should match property" + + if "jsonResponse" in label_dict: + assert ( + label_dict["jsonResponse"] == label.json_response + ), "Dictionary jsonResponse should match property" + + if "createdAt" in label_dict: + assert ( + label_dict["createdAt"] == label.created_at + ), "Dictionary createdAt should match property" + + # Verify to_dict() returns the same reference (zero-copy) + assert label_dict is label._data, "to_dict() should return the same reference as _data" + + +@pytest.mark.integration() +def test_label_view_with_parsed_label(kili_client): + """Test that ParsedLabel mode still returns ParsedLabel objects (unchanged behavior).""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get labels in parsed_label mode + labels = kili_client.labels.list( + project_id=project_id, first=5, output_format="parsed_label", as_generator=False + ) + + # Skip if no labels + if not labels: + pytest.skip(f"No labels available in project {project_id}") + + # Verify each item is a ParsedLabel (not LabelView) + for label in labels: + assert isinstance( + label, ParsedLabel + ), f"Expected ParsedLabel with output_format='parsed_label', got {type(label).__name__}" + # ParsedLabel should NOT be a LabelView + assert not isinstance(label, LabelView), "ParsedLabel should not be wrapped in 
LabelView" + + +@pytest.mark.integration() +def test_predictions_list_returns_label_views(kili_client): + """Test that labels.predictions.list() returns LabelView objects.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get predictions in list mode + predictions = kili_client.labels.predictions.list( + project_id=project_id, first=5, as_generator=False + ) + + # Verify we get a list + assert isinstance( + predictions, list + ), "labels.predictions.list() with as_generator=False should return a list" + + # Skip if no predictions + if not predictions: + pytest.skip(f"No prediction labels available in project {project_id}") + + # Verify each item is a LabelView + for prediction in predictions: + assert_is_view(prediction, LabelView) + + # Verify we can access basic properties + assert hasattr(prediction, "id") + assert hasattr(prediction, "label_type") + + # Verify it's actually a PREDICTION label + assert ( + prediction.label_type == "PREDICTION" + ), f"Expected PREDICTION label, got {prediction.label_type}" + + +@pytest.mark.integration() +def test_inferences_list_returns_label_views(kili_client): + """Test that labels.inferences.list() returns LabelView objects.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get inferences in list mode + inferences = kili_client.labels.inferences.list( + project_id=project_id, first=5, as_generator=False + ) + + # Verify we get a list + assert isinstance( + inferences, list + ), "labels.inferences.list() with as_generator=False should return a list" + + # Skip if no inferences + if not 
inferences: + pytest.skip(f"No inference labels available in project {project_id}") + + # Verify each item is a LabelView + for inference in inferences: + assert_is_view(inference, LabelView) + + # Verify we can access basic properties + assert hasattr(inference, "id") + assert hasattr(inference, "label_type") + + # Verify it's actually an INFERENCE label + assert ( + inference.label_type == "INFERENCE" + ), f"Expected INFERENCE label, got {inference.label_type}" + + +@pytest.mark.integration() +def test_label_view_filtering(kili_client): + """Test that LabelView objects work correctly with filtering.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Get all labels + all_labels = kili_client.labels.list(project_id=project_id, first=10, as_generator=False) + + if not all_labels: + pytest.skip(f"No labels available in project {project_id}") + + # Get the first label's ID + first_label_id = all_labels[0].id + + # Query for specific label by ID + filtered_labels = kili_client.labels.list( + project_id=project_id, label_id=first_label_id, as_generator=False + ) + + # Verify we got results + assert len(filtered_labels) > 0, "Should get at least one label with specific label_id" + + # Verify each result is a LabelView + for label in filtered_labels: + assert_is_view(label, LabelView) + + # Verify it's the correct label + assert label.id == first_label_id, "Filtered label should have the requested ID" + + +@pytest.mark.integration() +def test_label_view_empty_results(kili_client): + """Test that empty results are handled correctly.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from 
ProjectView object + project_id = projects[0].id + + # Query with a filter that should return no results + # Using a non-existent label ID + empty_labels = kili_client.labels.list( + project_id=project_id, label_id="non-existent-label-id-12345", as_generator=False + ) + + # Verify we get an empty list + assert isinstance(empty_labels, list), "Should return a list even when no results" + assert len(empty_labels) == 0, "Should return empty list for non-existent label" + + +@pytest.mark.integration() +def test_label_view_with_fields_parameter(kili_client): + """Test that LabelView works correctly with custom fields parameter.""" + # Get the first project to test with + projects = list(kili_client.projects.list(first=1, as_generator=False)) + + if not projects: + pytest.skip("No projects available for testing") + + # Extract project ID from ProjectView object + project_id = projects[0].id + + # Query with specific fields + labels = kili_client.labels.list( + project_id=project_id, + first=1, + fields=["id", "labelType", "createdAt", "jsonResponse"], + as_generator=False, + ) + + if not labels: + pytest.skip(f"No labels available in project {project_id}") + + label = labels[0] + + # Verify it's still a LabelView + assert_is_view(label, LabelView) + + # Verify requested fields are accessible + assert_view_property_access(label, "id") + assert_view_property_access(label, "label_type") + assert_view_property_access(label, "created_at") + assert_view_property_access(label, "json_response") diff --git a/tests_v2/test_notifications_view.py b/tests_v2/test_notifications_view.py new file mode 100644 index 000000000..862992186 --- /dev/null +++ b/tests_v2/test_notifications_view.py @@ -0,0 +1,307 @@ +"""Integration tests for NotificationView objects returned by the notifications namespace. 
+ +This test file validates that the notifications.list() method correctly returns +NotificationView objects instead of dictionaries, and that these objects provide +proper property access and backward compatibility. + +Test Strategy: + - Verify list() returns NotificationView objects in all modes (list, generator) + - Test NotificationView property access for common properties + - Validate backward compatibility with dictionary interface via to_dict() + - Test filtering by has_been_seen and user_id + - Verify computed properties (is_unread, display_name) + - Ensure mutation methods still return dicts (unchanged) +""" + +import pytest + +from kili.domain_v2.notification import NotificationView +from tests_v2 import ( + assert_is_view, + assert_view_has_dict_compatibility, + assert_view_property_access, +) + + +@pytest.mark.integration() +def test_list_returns_notification_views(kili_client): + """Test that notifications.list() in list mode returns NotificationView objects.""" + # Get notifications in list mode + notifications = kili_client.notifications.list(first=5, as_generator=False) + + # Verify we get a list + assert isinstance( + notifications, list + ), "notifications.list() with as_generator=False should return a list" + + # Skip if no notifications + if not notifications: + pytest.skip("No notifications available for testing") + + # Verify each item is a NotificationView + for notification in notifications: + assert_is_view(notification, NotificationView) + + # Verify we can access basic properties + assert hasattr(notification, "id") + assert hasattr(notification, "message") + assert hasattr(notification, "status") + assert hasattr(notification, "user_id") + + +@pytest.mark.integration() +def test_list_generator_returns_notification_views(kili_client): + """Test that notifications.list() in generator mode returns NotificationView objects.""" + # Get notifications in generator mode + notifications_gen = kili_client.notifications.list(first=5, 
as_generator=True) + + # Take first 5 items from generator (or fewer if less available) + notifications_from_gen = [] + for i, notification in enumerate(notifications_gen): + if i >= 5: + break + notifications_from_gen.append(notification) + + # Skip if no notifications + if not notifications_from_gen: + pytest.skip("No notifications available for testing") + + # Verify each yielded item is a NotificationView + for notification in notifications_from_gen: + assert_is_view(notification, NotificationView) + + # Verify we can access basic properties + assert hasattr(notification, "id") + assert hasattr(notification, "message") + assert hasattr(notification, "status") + + +@pytest.mark.integration() +def test_notification_view_properties(kili_client): + """Test that NotificationView provides access to all expected properties.""" + # Get first notification + notifications = kili_client.notifications.list(first=1, as_generator=False) + + if not notifications: + pytest.skip("No notifications available for testing") + + notification = notifications[0] + + # Verify NotificationView type + assert_is_view(notification, NotificationView) + + # Test core properties exist and are accessible + assert_view_property_access(notification, "id") + assert_view_property_access(notification, "message") + assert_view_property_access(notification, "status") + assert_view_property_access(notification, "user_id") + assert_view_property_access(notification, "created_at") + assert_view_property_access(notification, "has_been_seen") + + # Test that id is not empty + assert notification.id, "Notification id should not be empty" + + # Test that message is not empty + assert notification.message, "Notification message should not be empty" + + # Test computed properties + assert_view_property_access(notification, "is_unread") + assert_view_property_access(notification, "display_name") + + # Test that is_unread is inverse of has_been_seen + assert notification.is_unread == ( + not 
notification.has_been_seen + ), "is_unread should be the inverse of has_been_seen" + + # Test display_name (should be truncated message or id) + assert notification.display_name, "display_name should not be empty" + if len(notification.message) <= 50: + assert notification.display_name == notification.message + else: + assert notification.display_name == notification.message[:47] + "..." + + # Test optional properties + assert_view_property_access(notification, "url") + + +@pytest.mark.integration() +def test_notification_view_dict_compatibility(kili_client): + """Test that NotificationView maintains backward compatibility via to_dict().""" + # Get first notification + notifications = kili_client.notifications.list(first=1, as_generator=False) + + if not notifications: + pytest.skip("No notifications available for testing") + + notification = notifications[0] + + # Verify NotificationView type + assert_is_view(notification, NotificationView) + + # Test dictionary compatibility + assert_view_has_dict_compatibility(notification) + + # Get dictionary representation + notification_dict = notification.to_dict() + + # Verify it's a dictionary + assert isinstance(notification_dict, dict), "to_dict() should return a dictionary" + + # Verify dictionary has expected keys + assert "id" in notification_dict, "Dictionary should have 'id' key" + + # Verify dictionary values match property values + if "message" in notification_dict: + assert ( + notification_dict["message"] == notification.message + ), "Dictionary message should match property" + + if "status" in notification_dict: + assert ( + notification_dict["status"] == notification.status + ), "Dictionary status should match property" + + if "userID" in notification_dict: + assert ( + notification_dict["userID"] == notification.user_id + ), "Dictionary userID should match property" + + # Verify to_dict() returns the same reference (zero-copy) + assert ( + notification_dict is notification._data + ), "to_dict() should return 
the same reference as _data" + + +@pytest.mark.integration() +def test_notification_view_filtering(kili_client): + """Test that NotificationView objects work correctly with filtering.""" + # Get all notifications + all_notifications = kili_client.notifications.list(first=10, as_generator=False) + + if not all_notifications: + pytest.skip("No notifications available for testing") + + # Test filtering by has_been_seen=False (unseen notifications) + unseen_notifications = kili_client.notifications.list( + has_been_seen=False, first=10, as_generator=False + ) + + # Verify results are NotificationView objects + for notification in unseen_notifications: + assert_is_view(notification, NotificationView) + # All should be unseen + assert notification.has_been_seen is False, "Filtered notifications should be unseen" + assert notification.is_unread is True, "Filtered notifications should be unread" + + # Test filtering by has_been_seen=True (seen notifications) + seen_notifications = kili_client.notifications.list( + has_been_seen=True, first=10, as_generator=False + ) + + # Verify results are NotificationView objects + for notification in seen_notifications: + assert_is_view(notification, NotificationView) + # All should be seen + assert notification.has_been_seen is True, "Filtered notifications should be seen" + assert notification.is_unread is False, "Filtered notifications should not be unread" + + +@pytest.mark.integration() +def test_notification_view_empty_results(kili_client): + """Test that empty results are handled correctly.""" + # Query with a filter that should return no results + # Using a non-existent notification ID + empty_notifications = kili_client.notifications.list( + notification_id="non-existent-notification-id-12345", as_generator=False + ) + + # Verify we get an empty list + assert isinstance(empty_notifications, list), "Should return a list even when no results" + assert len(empty_notifications) == 0, "Should return empty list for non-existent 
notification" + + +@pytest.mark.integration() +def test_mutation_methods_still_return_dicts(kili_client): + """Test that mutation methods (create, update) still return dicts (unchanged).""" + # Note: These operations typically require admin privileges + # This test may be skipped in most test environments + + try: + # Test create() method - should return dict + # We'll attempt to create a notification, but this likely requires admin access + # and may fail in test environments + + # Get current user ID (notifications need a user_id) + # This is just to demonstrate the pattern - will likely fail with permission error + create_result = kili_client.notifications.create( + message="Test notification for integration test", + status="info", + url="/test", + user_id="test-user-id", + ) + + # Verify result is a dict + assert isinstance(create_result, dict), "create() should return a dict" + + except Exception as e: + # If mutations are not allowed in test environment, skip the test + pytest.skip(f"Mutations not allowed or failed (expected for non-admin users): {e}") + + +@pytest.mark.integration() +def test_notification_view_with_fields_parameter(kili_client): + """Test that NotificationView works correctly with custom fields parameter.""" + # Query with specific fields + notifications = kili_client.notifications.list( + first=1, + fields=["id", "message", "status", "userID", "createdAt", "hasBeenSeen"], + as_generator=False, + ) + + if not notifications: + pytest.skip("No notifications available for testing") + + notification = notifications[0] + + # Verify it's still a NotificationView + assert_is_view(notification, NotificationView) + + # Verify requested fields are accessible + assert_view_property_access(notification, "id") + assert_view_property_access(notification, "message") + assert_view_property_access(notification, "status") + assert_view_property_access(notification, "user_id") + assert_view_property_access(notification, "created_at") + 
assert_view_property_access(notification, "has_been_seen") + + +@pytest.mark.integration() +def test_notification_count_method(kili_client): + """Test that notifications.count() works correctly and returns an integer.""" + # Count all notifications + total_count = kili_client.notifications.count() + + # Verify result is an integer + assert isinstance(total_count, int), "count() should return an integer" + assert total_count >= 0, "count() should return a non-negative integer" + + # Count unseen notifications + unseen_count = kili_client.notifications.count(has_been_seen=False) + + # Verify result is an integer + assert isinstance(unseen_count, int), "count() should return an integer" + assert unseen_count >= 0, "count() should return a non-negative integer" + assert unseen_count <= total_count, "Unseen count should not exceed total count" + + # Count seen notifications + seen_count = kili_client.notifications.count(has_been_seen=True) + + # Verify result is an integer + assert isinstance(seen_count, int), "count() should return an integer" + assert seen_count >= 0, "count() should return a non-negative integer" + assert seen_count <= total_count, "Seen count should not exceed total count" + + # Verify seen + unseen = total + assert ( + seen_count + unseen_count == total_count + ), "Seen count + unseen count should equal total count" diff --git a/tests_v2/test_organizations_view.py b/tests_v2/test_organizations_view.py new file mode 100644 index 000000000..9f6679c07 --- /dev/null +++ b/tests_v2/test_organizations_view.py @@ -0,0 +1,289 @@ +"""Integration tests for OrganizationView objects returned by the organizations namespace. + +This test file validates that the organizations.list() method correctly returns +OrganizationView objects instead of dictionaries, and that these objects provide +proper property access and backward compatibility. 
+ +Test Strategy: + - Verify list() returns OrganizationView objects in all modes (list, generator) + - Test OrganizationView property access for common properties + - Validate backward compatibility with dictionary interface via to_dict() + - Test filtering options + - Verify computed properties (display_name, full_address) + - Ensure count method works correctly +""" + +import pytest + +from kili.domain_v2.organization import OrganizationView +from tests_v2 import ( + assert_is_view, + assert_view_has_dict_compatibility, + assert_view_property_access, +) + + +@pytest.mark.integration() +def test_list_returns_organization_views(kili_client): + """Test that organizations.list() in list mode returns OrganizationView objects.""" + # Get organizations in list mode + organizations = kili_client.organizations.list(first=5, as_generator=False) + + # Verify we get a list + assert isinstance( + organizations, list + ), "organizations.list() with as_generator=False should return a list" + + # Skip if no organizations (unlikely, but possible in isolated test environments) + if not organizations: + pytest.skip("No organizations available for testing") + + # Verify each item is an OrganizationView + for organization in organizations: + assert_is_view(organization, OrganizationView) + + # Verify we can access basic properties + assert hasattr(organization, "id") + assert hasattr(organization, "name") + + +@pytest.mark.integration() +def test_list_generator_returns_organization_views(kili_client): + """Test that organizations.list() in generator mode returns OrganizationView objects.""" + # Get organizations in generator mode + organizations_gen = kili_client.organizations.list(first=5, as_generator=True) + + # Take first 5 items from generator (or fewer if less available) + organizations_from_gen = [] + for i, organization in enumerate(organizations_gen): + if i >= 5: + break + organizations_from_gen.append(organization) + + # Skip if no organizations + if not 
organizations_from_gen: + pytest.skip("No organizations available for testing") + + # Verify each yielded item is an OrganizationView + for organization in organizations_from_gen: + assert_is_view(organization, OrganizationView) + + # Verify we can access basic properties + assert hasattr(organization, "id") + assert hasattr(organization, "name") + + +@pytest.mark.integration() +def test_organization_view_properties(kili_client): + """Test that OrganizationView provides access to all expected properties.""" + # Get first organization + organizations = kili_client.organizations.list(first=1, as_generator=False) + + if not organizations: + pytest.skip("No organizations available for testing") + + organization = organizations[0] + + # Verify OrganizationView type + assert_is_view(organization, OrganizationView) + + # Test core properties exist and are accessible + assert_view_property_access(organization, "id") + assert_view_property_access(organization, "name") + + # Test that id is not empty + assert organization.id, "Organization id should not be empty" + + # Test that name is not empty + assert organization.name, "Organization name should not be empty" + + # Test optional address properties + assert_view_property_access(organization, "address") + assert_view_property_access(organization, "city") + assert_view_property_access(organization, "country") + assert_view_property_access(organization, "zip_code") + + # Test metric properties (may be 0 for new organizations) + assert_view_property_access(organization, "number_of_annotations") + assert_view_property_access(organization, "number_of_labeled_assets") + assert_view_property_access(organization, "number_of_hours") + + # Verify metric properties are non-negative + assert organization.number_of_annotations >= 0, "number_of_annotations should be non-negative" + assert ( + organization.number_of_labeled_assets >= 0 + ), "number_of_labeled_assets should be non-negative" + assert organization.number_of_hours >= 0.0, 
"number_of_hours should be non-negative" + + # Test computed properties + assert_view_property_access(organization, "display_name") + assert_view_property_access(organization, "full_address") + + # Test display_name (should be name or id) + assert organization.display_name, "display_name should not be empty" + assert organization.display_name == (organization.name or organization.id) + + +@pytest.mark.integration() +def test_organization_view_dict_compatibility(kili_client): + """Test that OrganizationView maintains backward compatibility via to_dict().""" + # Get first organization + organizations = kili_client.organizations.list(first=1, as_generator=False) + + if not organizations: + pytest.skip("No organizations available for testing") + + organization = organizations[0] + + # Verify OrganizationView type + assert_is_view(organization, OrganizationView) + + # Test dictionary compatibility + assert_view_has_dict_compatibility(organization) + + # Get dictionary representation + organization_dict = organization.to_dict() + + # Verify it's a dictionary + assert isinstance(organization_dict, dict), "to_dict() should return a dictionary" + + # Verify dictionary has expected keys + assert "id" in organization_dict, "Dictionary should have 'id' key" + assert "name" in organization_dict, "Dictionary should have 'name' key" + + # Verify dictionary values match property values + assert organization_dict["id"] == organization.id, "Dictionary id should match property" + assert organization_dict["name"] == organization.name, "Dictionary name should match property" + + # Verify to_dict() returns the same reference (zero-copy) + assert ( + organization_dict is organization._data + ), "to_dict() should return the same reference as _data" + + +@pytest.mark.integration() +def test_organization_view_filtering(kili_client): + """Test that OrganizationView objects work correctly with filtering.""" + # Get all organizations + all_organizations = 
kili_client.organizations.list(first=10, as_generator=False) + + if not all_organizations: + pytest.skip("No organizations available for testing") + + # Get the first organization's ID + first_org_id = all_organizations[0].id + + # Query for specific organization by ID + filtered_organizations = kili_client.organizations.list( + organization_id=first_org_id, as_generator=False + ) + + # Verify we got results + assert len(filtered_organizations) > 0, "Should get at least one organization with specific ID" + + # Verify each result is an OrganizationView + for organization in filtered_organizations: + assert_is_view(organization, OrganizationView) + + # Verify it has the correct organization ID + assert organization.id == first_org_id, "Filtered organization should have the requested ID" + + +@pytest.mark.integration() +def test_organization_view_empty_results(kili_client): + """Test that list returns a list even when no specific filters match.""" + # For organizations, we can't easily test "no results" scenarios + # since the API restricts querying other users' organizations + # Instead, just verify we always get a list back + organizations = kili_client.organizations.list( + first=0, # Request 0 items + as_generator=False, + ) + + # Verify we get a list (even if empty) + assert isinstance(organizations, list), "Should return a list even when first=0" + + +@pytest.mark.integration() +def test_organization_view_with_fields_parameter(kili_client): + """Test that OrganizationView works correctly with custom fields parameter.""" + # Query with specific fields (only valid fields from the schema) + organizations = kili_client.organizations.list( + first=1, fields=["id", "name"], as_generator=False + ) + + if not organizations: + pytest.skip("No organizations available for testing") + + organization = organizations[0] + + # Verify it's still an OrganizationView + assert_is_view(organization, OrganizationView) + + # Verify requested fields are accessible + 
assert_view_property_access(organization, "id") + assert_view_property_access(organization, "name") + + +@pytest.mark.integration() +def test_organization_count_method(kili_client): + """Test that organizations.count() works correctly and returns an integer.""" + # Count all organizations + total_count = kili_client.organizations.count() + + # Verify result is an integer + assert isinstance(total_count, int), "count() should return an integer" + assert total_count > 0, "count() should return at least one organization" + + +@pytest.mark.integration() +def test_organization_view_full_address(kili_client): + """Test the full_address computed property.""" + # Get organizations with standard fields + organizations = kili_client.organizations.list( + first=1, fields=["id", "name"], as_generator=False + ) + + if not organizations: + pytest.skip("No organizations available for testing") + + organization = organizations[0] + + # Verify OrganizationView type + assert_is_view(organization, OrganizationView) + + # Test full_address property exists (even though address fields may not be in the response) + assert_view_property_access(organization, "full_address") + + # Verify full_address is a string + assert isinstance(organization.full_address, str), "full_address should be a string" + + # Since address fields are not in the Organization schema, full_address will likely be empty + # This just tests that the property exists and returns a string + + +@pytest.mark.integration() +def test_organization_metrics_method(kili_client): + """Test that organizations.metrics() works correctly and returns OrganizationMetricsView.""" + # Get an organization first + organizations = kili_client.organizations.list(first=1, as_generator=False) + + if not organizations: + pytest.skip("No organizations available for testing") + + organization = organizations[0] + + # Get metrics for the organization + metrics = kili_client.organizations.metrics(organization_id=organization.id) + + # Verify result 
is an OrganizationMetricsView with to_dict() method + assert hasattr(metrics, "to_dict"), "metrics() should return an object with to_dict() method" + + # Convert to dict for backward compatibility checks + metrics_dict = metrics.to_dict() + assert isinstance(metrics_dict, dict), "to_dict() should return a dictionary" + + # Verify it contains expected metric fields (based on default fields) + assert ( + "numberOfAnnotations" in metrics_dict or len(metrics_dict) >= 0 + ), "metrics should contain data or be empty" diff --git a/tests_v2/test_projects_view.py b/tests_v2/test_projects_view.py new file mode 100644 index 000000000..114381915 --- /dev/null +++ b/tests_v2/test_projects_view.py @@ -0,0 +1,361 @@ +"""Integration tests for ProjectView objects returned by the projects namespace. + +This test file validates that the projects.list() method correctly returns +ProjectView objects instead of dictionaries, and that these objects provide +proper property access and backward compatibility. + +Test Strategy: + - Verify list() returns ProjectView objects in all modes (list, generator) + - Test ProjectView property access for common properties + - Validate backward compatibility with dictionary interface via to_dict() + - Ensure nested namespaces handle views appropriately + - Test workflow-related properties for V2 projects +""" + +import pytest + +from kili.domain_v2.project import ProjectRoleView, ProjectView +from tests_v2 import ( + assert_is_view, + assert_view_has_dict_compatibility, + assert_view_property_access, +) + + +@pytest.mark.integration() +def test_list_returns_project_views(kili_client): + """Test that projects.list() in list mode returns ProjectView objects.""" + # Get projects in list mode + projects = kili_client.projects.list(first=5, as_generator=False) + + # Verify we get a list + assert isinstance( + projects, list + ), "projects.list() with as_generator=False should return a list" + + # Skip if no projects + if not projects: + pytest.skip("No 
projects available for testing") + + # Verify each item is a ProjectView + for project in projects: + assert_is_view(project, ProjectView) + + # Verify we can access basic properties + assert hasattr(project, "id") + assert hasattr(project, "title") + assert hasattr(project, "display_name") + + +@pytest.mark.integration() +def test_list_generator_returns_project_views(kili_client): + """Test that projects.list() in generator mode returns ProjectView objects.""" + # Get projects in generator mode + projects_gen = kili_client.projects.list(first=5, as_generator=True) + + # Take first 5 items from generator (or fewer if less available) + projects_from_gen = [] + for i, project in enumerate(projects_gen): + if i >= 5: + break + projects_from_gen.append(project) + + # Skip if no projects + if not projects_from_gen: + pytest.skip("No projects available for testing") + + # Verify each yielded item is a ProjectView + for project in projects_from_gen: + assert_is_view(project, ProjectView) + + # Verify we can access basic properties + assert hasattr(project, "id") + assert hasattr(project, "title") + + +@pytest.mark.integration() +def test_project_view_properties(kili_client): + """Test that ProjectView provides access to all expected properties.""" + # Get first project + projects = kili_client.projects.list(first=1, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing") + + project = projects[0] + + # Verify ProjectView type + assert_is_view(project, ProjectView) + + # Test core properties exist and are accessible + assert_view_property_access(project, "id") + assert_view_property_access(project, "title") + assert_view_property_access(project, "description") + assert_view_property_access(project, "input_type") + assert_view_property_access(project, "json_interface") + assert_view_property_access(project, "display_name") + + # Test that id is not empty + assert project.id, "Project id should not be empty" + + # Test display_name logic 
(should be title if available, else id) + if project.title: + assert project.display_name == project.title + else: + assert project.display_name == project.id + + # Test optional properties + assert_view_property_access(project, "workflow_version") + assert_view_property_access(project, "steps") + assert_view_property_access(project, "roles") + assert_view_property_access(project, "number_of_assets") + assert_view_property_access(project, "number_of_remaining_assets") + assert_view_property_access(project, "number_of_reviewed_assets") + assert_view_property_access(project, "created_at") + assert_view_property_access(project, "updated_at") + assert_view_property_access(project, "archived") + assert_view_property_access(project, "starred") + + # Test computed properties + assert_view_property_access(project, "is_v2_workflow") + assert_view_property_access(project, "has_honeypot") + assert_view_property_access(project, "progress_percentage") + + # Verify steps is a list + assert isinstance(project.steps, list), "steps property should return a list" + + # Verify roles is a list + assert isinstance(project.roles, list), "roles property should return a list" + + # Verify progress_percentage is a float + assert isinstance(project.progress_percentage, float), "progress_percentage should be a float" + assert ( + 0 <= project.progress_percentage <= 100 + ), "progress_percentage should be between 0 and 100" + + +@pytest.mark.integration() +def test_project_view_dict_compatibility(kili_client): + """Test that ProjectView maintains backward compatibility via to_dict().""" + # Get first project + projects = kili_client.projects.list(first=1, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing") + + project = projects[0] + + # Verify ProjectView type + assert_is_view(project, ProjectView) + + # Test dictionary compatibility + assert_view_has_dict_compatibility(project) + + # Get dictionary representation + project_dict = project.to_dict() + + 
# Verify it's a dictionary + assert isinstance(project_dict, dict), "to_dict() should return a dictionary" + + # Verify dictionary has expected keys + assert "id" in project_dict, "Dictionary should have 'id' key" + + # Verify dictionary values match property values + if "title" in project_dict: + assert project_dict["title"] == project.title, "Dictionary title should match property" + + if "inputType" in project_dict: + assert ( + project_dict["inputType"] == project.input_type + ), "Dictionary inputType should match property" + + if "jsonInterface" in project_dict: + assert ( + project_dict["jsonInterface"] == project.json_interface + ), "Dictionary jsonInterface should match property" + + # Verify to_dict() returns the same reference (zero-copy) + assert project_dict is project._data, "to_dict() should return the same reference as _data" + + +@pytest.mark.integration() +def test_project_view_filtering(kili_client): + """Test that ProjectView objects work correctly with filtering.""" + # Get all projects + all_projects = kili_client.projects.list(first=10, as_generator=False) + + if not all_projects: + pytest.skip("No projects available for testing") + + # Get the first project's ID + first_project_id = all_projects[0].id + + # Query for specific project by ID + filtered_projects = kili_client.projects.list(project_id=first_project_id, as_generator=False) + + # Verify we got results + assert len(filtered_projects) > 0, "Should get at least one project with specific project_id" + + # Verify each result is a ProjectView + for project in filtered_projects: + assert_is_view(project, ProjectView) + + # Verify it's the correct project + assert project.id == first_project_id, "Filtered project should have the requested ID" + + +@pytest.mark.integration() +def test_project_view_workflow_properties(kili_client): + """Test ProjectView workflow-related properties for V2 projects.""" + # Get projects + projects = kili_client.projects.list(first=10, as_generator=False) + + if 
not projects: + pytest.skip("No projects available for testing") + + # Find a V2 workflow project if available + v2_project = None + for project in projects: + if project.is_v2_workflow: + v2_project = project + break + + if v2_project is None: + pytest.skip("No V2 workflow projects available for testing") + + # Verify V2 project has steps + assert_view_property_access(v2_project, "steps") + assert isinstance(v2_project.steps, list), "V2 project should have steps as a list" + + # Verify workflow_version property + assert v2_project.workflow_version == "V2", "V2 project should have workflow_version='V2'" + + # Verify is_v2_workflow computed property + assert v2_project.is_v2_workflow is True, "is_v2_workflow should be True for V2 projects" + + +@pytest.mark.integration() +def test_users_namespace_still_returns_dicts(kili_client): + """Test that projects.users.list() returns ProjectRoleView objects.""" + # Get first project + projects = kili_client.projects.list(first=1, as_generator=False) + + if not projects: + pytest.skip("No projects available for testing") + + project = projects[0] + + # Get project users + users = kili_client.projects.users.list(project_id=project.id, first=5, as_generator=False) + + # Verify we get a list + assert isinstance(users, list), "projects.users.list() should return a list" + + # Skip if no users + if not users: + pytest.skip(f"No users available in project {project.id}") + + # Verify each item is a ProjectRoleView (not a dict or ProjectView) + for user in users: + assert_is_view(user, ProjectRoleView) + assert not isinstance( + user, ProjectView + ), "projects.users.list() should NOT return ProjectView objects" + + # Verify it has the expected properties + assert hasattr(user, "id"), "ProjectRoleView should have an id property" + assert hasattr(user, "role"), "ProjectRoleView should have a role property" + assert hasattr(user, "user_email"), "ProjectRoleView should have a user_email property" + + +@pytest.mark.integration() +def 
test_project_view_empty_results(kili_client): + """Test that empty results are handled correctly.""" + # Get projects with archived filter + # Use a filter combination that may return no results + empty_projects = kili_client.projects.list( + archived=True, starred=True, first=1, as_generator=False + ) + + # Verify we get a list (may be empty) + assert isinstance(empty_projects, list), "Should return a list even when no results" + + # If we got results, verify they are ProjectView objects + for project in empty_projects: + assert_is_view(project, ProjectView) + + +@pytest.mark.integration() +def test_project_view_with_fields_parameter(kili_client): + """Test that ProjectView works correctly with custom fields parameter.""" + # Query with specific fields + projects = kili_client.projects.list( + first=1, fields=["id", "title", "inputType", "jsonInterface"], as_generator=False + ) + + if not projects: + pytest.skip("No projects available for testing") + + project = projects[0] + + # Verify it's still a ProjectView + assert_is_view(project, ProjectView) + + # Verify requested fields are accessible + assert_view_property_access(project, "id") + assert_view_property_access(project, "title") + assert_view_property_access(project, "input_type") + assert_view_property_access(project, "json_interface") + + +@pytest.mark.integration() +def test_project_view_archived_filter(kili_client): + """Test that ProjectView works correctly with archived filter.""" + # Get only active (non-archived) projects + active_projects = kili_client.projects.list(archived=False, first=5, as_generator=False) + + # Skip if no active projects + if not active_projects: + pytest.skip("No active projects available for testing") + + # Verify each project is a ProjectView and is not archived + for project in active_projects: + assert_is_view(project, ProjectView) + assert ( + project.archived is False + ), "Project should not be archived when filtered with archived=False" + + +@pytest.mark.integration() 
def test_project_view_progress_calculation(kili_client):
    """Test that ProjectView calculates progress percentage correctly.

    Fetches a single project and checks that ``progress_percentage`` is a
    float within [0, 100] and agrees with the value recomputed from
    ``number_of_assets`` and ``number_of_remaining_assets``. The test is
    skipped when no project is available.
    """
    projects = kili_client.projects.list(first=1, as_generator=False)

    if not projects:
        pytest.skip("No projects available for testing")

    project = projects[0]
    assert_is_view(project, ProjectView)

    progress = project.progress_percentage

    # The value must be a well-formed percentage before it is compared.
    assert isinstance(progress, float), "progress_percentage should be a float"
    assert 0 <= progress <= 100, f"progress_percentage should be between 0 and 100, got {progress}"

    total = project.number_of_assets
    remaining = project.number_of_remaining_assets

    if total == 0:
        # An empty project has nothing to complete; also avoids dividing by zero below.
        assert progress == 0.0, "progress should be 0 when total assets is 0"
    else:
        completed = total - remaining
        expected_progress = (completed / total) * 100
        assert (
            abs(progress - expected_progress) < 0.01
        ), f"progress calculation incorrect: expected {expected_progress}, got {progress}"


# NOTE(review): patch header preserved from the original diff; the hunk length
# below describes the original patch, not this reformatted text.
# diff --git a/tests_v2/test_tags_view.py b/tests_v2/test_tags_view.py
# new file mode 100644
# index 000000000..0365b03ad
# --- /dev/null
# +++ b/tests_v2/test_tags_view.py
# @@ -0,0 +1,98 @@
"""Integration tests for TagView objects returned by the tags namespace."""

import pytest

from kili.domain_v2.tag import TagView
from tests_v2 import assert_is_view, assert_view_has_dict_compatibility, assert_view_property_access


@pytest.mark.integration()
def test_list_returns_tag_views(kili_client):
    """Test that tags.list() returns TagView objects."""
    tags = kili_client.tags.list()

    assert isinstance(tags, list), "tags.list() should return a list"

    if not tags:
        pytest.skip("No tags available for testing")

    for tag in tags:
        assert_is_view(tag, TagView)
        assert hasattr(tag, "id")
        assert hasattr(tag, "label")


@pytest.mark.integration()
def test_tag_view_properties(kili_client):
    """Test that TagView provides access to all expected properties."""
    tags = kili_client.tags.list()

    if not tags:
        pytest.skip("No tags available for testing")

    tag = tags[0]
    assert_is_view(tag, TagView)

    assert_view_property_access(tag, "id")
    assert_view_property_access(tag, "label")
    assert_view_property_access(tag, "organization_id")
    assert_view_property_access(tag, "display_name")

    assert tag.id, "Tag id should not be empty"
    assert tag.display_name, "Tag display_name should not be empty"


@pytest.mark.integration()
def test_tag_view_dict_compatibility(kili_client):
    """Test that TagView maintains backward compatibility via to_dict()."""
    tags = kili_client.tags.list()

    if not tags:
        pytest.skip("No tags available for testing")

    tag = tags[0]
    assert_is_view(tag, TagView)
    assert_view_has_dict_compatibility(tag)

    tag_dict = tag.to_dict()
    assert isinstance(tag_dict, dict), "to_dict() should return a dictionary"
    assert "id" in tag_dict, "Dictionary should have 'id' key"
    # Identity, not equality: to_dict() is expected to hand back the backing dict itself.
    assert tag_dict is tag._data, "to_dict() should return the same reference as _data"


@pytest.mark.integration()
def test_tag_view_with_project_filter(kili_client):
    """Test listing tags for a specific project."""
    # as_generator=False already yields an indexable list (see the other
    # tests in this suite), so no extra list() copy is needed.
    projects = kili_client.projects.list(first=1, as_generator=False)

    if not projects:
        pytest.skip("No projects available for testing")

    project_id = projects[0].id
    project_tags = kili_client.tags.list(project_id=project_id)

    assert isinstance(project_tags, list), "Should return a list"

    for tag in project_tags:
        assert_is_view(tag, TagView)


@pytest.mark.integration()
def test_tag_view_empty_results(kili_client):
    """Test that tags.list() returns a list, whether or not any tags exist.

    The query is not forced to be empty; the test only guarantees that the
    return type is a list when the organization happens to have no tags.
    """
    tags = kili_client.tags.list()
    assert isinstance(tags, list), "Should return a list even when no results"


@pytest.mark.integration()
def test_tag_view_with_fields_parameter(kili_client):
    """Test that TagView works correctly with custom fields parameter."""
    tags = kili_client.tags.list(fields=["id", "label"])

    if not tags:
        pytest.skip("No tags available for testing")

    tag = tags[0]
    assert_is_view(tag, TagView)
    assert_view_property_access(tag, "id")
    assert_view_property_access(tag, "label")


# NOTE(review): patch header preserved from the original diff; the hunk length
# below describes the original patch, not this reformatted text.
# diff --git a/tests_v2/test_users_view.py b/tests_v2/test_users_view.py
# new file mode 100644
# index 000000000..088f91903
# --- /dev/null
# +++ b/tests_v2/test_users_view.py
# @@ -0,0 +1,310 @@
"""Integration tests for UserView objects returned by the users namespace.

This test file validates that the users.list() method correctly returns
UserView objects instead of dictionaries, and that these objects provide
proper property access and backward compatibility.

Test Strategy:
    - Verify list() returns UserView objects in all modes (list, generator)
    - Test UserView property access for common properties
    - Validate backward compatibility with dictionary interface via to_dict()
    - Ensure filtering capabilities work correctly
    - Verify empty result handling
    - Test mutation methods still return dicts (unchanged)
"""

import pytest

from kili.domain_v2.user import UserView
from tests_v2 import (
    assert_is_view,
    assert_view_has_dict_compatibility,
    assert_view_property_access,
)


@pytest.mark.integration()
def test_list_returns_user_views(kili_client):
    """Test that users.list() in list mode returns UserView objects."""
    users = kili_client.users.list(first=5, as_generator=False)

    assert isinstance(users, list), "users.list() with as_generator=False should return a list"

    if not users:
        pytest.skip("No users available for testing")

    for user in users:
        assert_is_view(user, UserView)

        # Basic properties must be reachable on every view.
        assert hasattr(user, "id")
        assert hasattr(user, "email")
        assert hasattr(user, "display_name")


@pytest.mark.integration()
def test_list_generator_returns_user_views(kili_client):
    """Test that users.list() in generator mode returns UserView objects."""
    from itertools import islice

    users_gen = kili_client.users.list(first=5, as_generator=True)

    # islice stops after five items without pulling a sixth from the generator.
    users_from_gen = list(islice(users_gen, 5))

    if not users_from_gen:
        pytest.skip("No users available for testing")

    for user in users_from_gen:
        assert_is_view(user, UserView)

        # Basic properties must be reachable on every yielded view.
        assert hasattr(user, "id")
        assert hasattr(user, "email")


@pytest.mark.integration()
def test_user_view_properties(kili_client):
    """Test that UserView provides access to all expected properties."""
    users = kili_client.users.list(first=1, as_generator=False)

    if not users:
        pytest.skip("No users available for testing")

    user = users[0]
    assert_is_view(user, UserView)

    # Core properties exist and are accessible.
    assert_view_property_access(user, "id")
    assert_view_property_access(user, "email")
    assert_view_property_access(user, "name")
    assert_view_property_access(user, "firstname")
    assert_view_property_access(user, "lastname")
    assert_view_property_access(user, "activated")
    assert_view_property_access(user, "organization_id")
    assert_view_property_access(user, "display_name")
    assert_view_property_access(user, "full_name")

    assert user.id, "User id should not be empty"
    assert user.email, "User email should not be empty"

    # display_name falls back through name -> email -> id.
    if user.name:
        assert user.display_name in (user.name, user.email)
    else:
        assert user.display_name in (user.email, user.id)

    if user.firstname or user.lastname:
        # Join only the parts that are present so a missing (None) part does
        # not show up as the literal text "None" in the expected value.
        # NOTE(review): assumes full_name omits missing parts -- confirm against UserView.
        expected_full = " ".join(part for part in (user.firstname, user.lastname) if part)
        assert user.full_name == expected_full
    else:
        assert user.full_name in (user.name, user.email)

    assert isinstance(user.activated, bool), "activated should be a boolean"


@pytest.mark.integration()
def test_user_view_dict_compatibility(kili_client):
    """Test that UserView objects maintain backward compatibility with dict interface."""
    users = kili_client.users.list(first=1, as_generator=False)

    if not users:
        pytest.skip("No users available for testing")

    user = users[0]
    assert_is_view(user, UserView)
    assert_view_has_dict_compatibility(user)

    user_dict = user.to_dict()
    assert isinstance(user_dict, dict)
    assert "id" in user_dict
    assert "email" in user_dict

    # Properties must mirror the values stored in the dictionary.
    assert user.id == user_dict.get("id")
    assert user.email == user_dict.get("email")
    assert user.firstname == user_dict.get("firstname")
    assert user.lastname == user_dict.get("lastname")


@pytest.mark.integration()
def test_user_view_filtering(kili_client):
    """Test that filtering capabilities work with UserView objects."""
    all_users = kili_client.users.list(first=10, as_generator=False)

    if not all_users:
        pytest.skip("No users available for testing")

    # Use the first user's email and organization as filter values.
    first_user = all_users[0]
    test_email = first_user.email
    test_org_id = first_user.organization_id

    users_by_email = kili_client.users.list(email=test_email, as_generator=False)
    assert isinstance(users_by_email, list)

    # Iterating an empty list is a no-op, so no emptiness guard is required.
    for user in users_by_email:
        assert_is_view(user, UserView)
        assert user.email == test_email

    # The organization filter is only exercised when the user has one.
    if test_org_id:
        users_by_org = kili_client.users.list(
            organization_id=test_org_id, first=5, as_generator=False
        )
        assert isinstance(users_by_org, list)

        for user in users_by_org:
            assert_is_view(user, UserView)
            assert user.organization_id == test_org_id


@pytest.mark.integration()
def test_user_view_empty_results(kili_client):
    """Test that empty results are handled correctly."""
    # Query with a filter that should return no results.
    users = kili_client.users.list(email="nonexistent@example.com", as_generator=False)

    # An empty list is expected, not None or an error.
    assert isinstance(users, list)
    assert len(users) == 0


@pytest.mark.integration()
def test_user_view_with_fields_parameter(kili_client):
    """Test that custom fields parameter works with UserView objects."""
    users = kili_client.users.list(
        first=3, fields=("id", "email", "firstname", "lastname", "activated"), as_generator=False
    )

    if not users:
        pytest.skip("No users available for testing")

    for user in users:
        assert_is_view(user, UserView)

        assert user.id
        assert user.email
        # firstname/lastname may be empty but should be accessible
        _ = user.firstname
        _ = user.lastname
        assert isinstance(user.activated, bool)


@pytest.mark.integration()
def test_mutation_methods_still_return_dicts(kili_client):
    """Check that the users namespace still exposes its mutation methods.

    Calling create/update/update_password would have side effects, so their
    return values are not exercised here. The test only verifies that the
    three methods exist on the namespace and that count() returns an int at
    least as large as the number of users listed.
    """
    users = kili_client.users.list(first=1, as_generator=False)

    if not users:
        pytest.skip("No users available for testing")

    for method_name in ("create", "update", "update_password"):
        assert hasattr(kili_client.users, method_name)

    count = kili_client.users.count()
    assert isinstance(count, int)
    assert count >= len(users)


@pytest.mark.integration()
def test_user_view_organization_role(kili_client):
    """Test organization_role property and is_admin computed property."""
    # Note: organizationRole returns a string like "ADMIN", not a dict
    users = kili_client.users.list(
        first=5, fields=("id", "email", "organizationRole"), as_generator=False
    )

    if not users:
        pytest.skip("No users available for testing")

    for user in users:
        assert_is_view(user, UserView)

        # Only accessibility is checked: the value may be None, a string or a
        # dict depending on the query, so it is read and discarded.
        _ = user.organization_role

        assert isinstance(user.is_admin, bool)


@pytest.mark.integration()
def test_user_view_timestamps(kili_client):
    """Test timestamp properties (created_at, updated_at).

    Note: lastSeenAt is not available in the User API schema.
    """
    users = kili_client.users.list(
        first=3, fields=("id", "email", "createdAt", "updatedAt"), as_generator=False
    )

    if not users:
        pytest.skip("No users available for testing")

    user = users[0]
    assert_is_view(user, UserView)

    # Timestamps may be None or ISO timestamp strings.
    created = user.created_at
    updated = user.updated_at

    if created:
        assert isinstance(created, str)
        # Basic ISO format check (YYYY-MM-DD)
        assert len(created) >= 10

    if updated:
        assert isinstance(updated, str)
        assert len(updated) >= 10

    # last_seen_at was not requested; reading it must not raise.
    last_seen = user.last_seen_at
    assert last_seen is None or isinstance(last_seen, str)