diff --git a/src/kili/domain_api/assets.py b/src/kili/domain_api/assets.py
index 3fabde665..54d1653ce 100644
--- a/src/kili/domain_api/assets.py
+++ b/src/kili/domain_api/assets.py
@@ -257,6 +257,7 @@ class AssetsNamespace(DomainNamespace):  # pylint: disable=too-many-public-metho
     - create_pdf(): Create PDF assets
     - create_text(): Create plain text assets
     - create_rich_text(): Create rich-text formatted text assets
+    - create_audio(): Create audio assets
     - delete(): Delete assets from projects
     - add_metadata(): Add metadata to assets
     - set_metadata(): Set metadata on assets
@@ -1311,6 +1312,102 @@ def create_rich_text(
             **kwargs,
         )
 
+    @overload
+    def create_audio(
+        self,
+        *,
+        project_id: str,
+        content: Union[str, dict],
+        external_id: Optional[str] = None,
+        json_metadata: Optional[dict] = None,
+        wait_until_availability: bool = True,
+        **kwargs,
+    ) -> dict[Literal["id", "asset_ids"], Union[str, List[str]]]:
+        ...
+
+    @overload
+    def create_audio(
+        self,
+        *,
+        project_id: str,
+        content_array: Union[List[str], List[dict]],
+        external_id_array: Optional[List[str]] = None,
+        json_metadata_array: Optional[List[dict]] = None,
+        disable_tqdm: Optional[bool] = None,
+        wait_until_availability: bool = True,
+        **kwargs,
+    ) -> dict[Literal["id", "asset_ids"], Union[str, List[str]]]:
+        ...
+
+    @typechecked
+    def create_audio(
+        self,
+        *,
+        project_id: str,
+        content: Optional[Union[str, dict]] = None,
+        content_array: Optional[Union[List[str], List[dict]]] = None,
+        external_id: Optional[str] = None,
+        external_id_array: Optional[List[str]] = None,
+        json_metadata: Optional[dict] = None,
+        json_metadata_array: Optional[List[dict]] = None,
+        disable_tqdm: Optional[bool] = None,
+        wait_until_availability: bool = True,
+        **kwargs,
+    ) -> dict[Literal["id", "asset_ids"], Union[str, List[str]]]:
+        """Create audio assets in a project.
+
+        Args:
+            project_id: Identifier of the project
+            content: URL or local file path to an audio file
+            content_array: List of URLs or local file paths to audio files
+            external_id: External id to identify the asset
+            external_id_array: List of external ids given to identify the assets
+            json_metadata: The metadata given to the asset
+            json_metadata_array: The metadata given to each asset
+            disable_tqdm: If True, the progress bar will be disabled
+            wait_until_availability: If True, waits until assets are fully processed
+            **kwargs: Additional arguments (e.g., is_honeypot)
+
+        Returns:
+            A dictionary with project id and list of created asset ids
+
+        Examples:
+            >>> # Create single audio asset
+            >>> result = kili.assets.create_audio(
+            ...     project_id="my_project",
+            ...     content="https://example.com/audio.mp3"
+            ... )
+
+            >>> # Create multiple audio assets
+            >>> result = kili.assets.create_audio(
+            ...     project_id="my_project",
+            ...     content_array=["https://example.com/audio1.mp3", "https://example.com/audio2.wav"]
+            ... )
+
+            >>> # Create audio with metadata
+            >>> result = kili.assets.create_audio(
+            ...     project_id="my_project",
+            ...     content="https://example.com/audio.mp3",
+            ...     json_metadata={"speaker": "John Doe"}
+            ... )
+        """
+        if content is not None:
+            content_array = cast(Union[list[str], list[dict]], [content])
+        if external_id is not None:
+            external_id_array = [external_id]
+        if json_metadata is not None:
+            json_metadata_array = [json_metadata]
+
+        return self._client.append_many_to_dataset(
+            project_id=project_id,
+            content_array=content_array,
+            external_id_array=external_id_array,
+            json_metadata_array=json_metadata_array,
+            disable_tqdm=disable_tqdm,
+            wait_until_availability=wait_until_availability,
+            **kwargs,
+        )
+
     @overload
     def delete(
         self,
diff --git a/src/kili/services/asset_import/__init__.py b/src/kili/services/asset_import/__init__.py
index c605fa92d..e10669cfb 100644
--- a/src/kili/services/asset_import/__init__.py
+++ b/src/kili/services/asset_import/__init__.py
@@ -7,6 +7,7 @@
     ImportValidationError,
 )
 
+from .audio import AudioDataImporter
 from .base import (
     BaseAbstractAssetImporter,
     LoggerParams,
@@ -24,6 +25,7 @@
     from kili.client import Kili
 
 importer_by_type: dict[str, type[BaseAbstractAssetImporter]] = {
+    "AUDIO": AudioDataImporter,
     "PDF": PdfDataImporter,
     "IMAGE": ImageDataImporter,
     "GEOSPATIAL": ImageDataImporter,
diff --git a/src/kili/services/asset_import/audio.py b/src/kili/services/asset_import/audio.py
new file mode 100644
index 000000000..931afbfd7
--- /dev/null
+++ b/src/kili/services/asset_import/audio.py
@@ -0,0 +1,63 @@
+"""Functions to import assets into an AUDIO project."""
+
+import os
+from enum import Enum
+
+from kili.core.helpers import is_url
+from kili.domain.project import InputType
+
+from .base import (
+    BaseAbstractAssetImporter,
+    BatchParams,
+    ContentBatchImporter,
+)
+from .exceptions import ImportValidationError
+from .types import AssetLike
+
+
+class AudioDataType(Enum):
+    """Audio data type."""
+
+    LOCAL_FILE = "LOCAL_FILE"
+    HOSTED_FILE = "HOSTED_FILE"
+
+
+class AudioDataImporter(BaseAbstractAssetImporter):
+    """Class for importing data into an AUDIO project."""
+
+    @staticmethod
+    def get_data_type(assets: list[AssetLike]) -> AudioDataType:
+        """Determine the type of data to upload from the service payload."""
+        content_array = [asset.get("content", "") for asset in assets]
+        has_local_file = any(os.path.exists(content) for content in content_array)  # type: ignore
+        has_hosted_file = any(is_url(content) for content in content_array)
+        if has_local_file and has_hosted_file:
+            raise ImportValidationError(
+                """
+                Cannot upload hosted data and local files at the same time.
+                Please separate the assets into 2 calls
+                """
+            )
+        if has_local_file:
+            return AudioDataType.LOCAL_FILE
+        return AudioDataType.HOSTED_FILE
+
+    def import_assets(self, assets: list[AssetLike], input_type: InputType):
+        """Import AUDIO assets into Kili."""
+        self._check_upload_is_allowed(assets)
+        data_type = self.get_data_type(assets)
+        assets = self.filter_duplicate_external_ids(assets)
+        if data_type == AudioDataType.LOCAL_FILE:
+            assets = self.filter_local_assets(assets, self.raise_error)
+            batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
+            batch_importer = ContentBatchImporter(
+                self.kili, self.project_params, batch_params, self.pbar
+            )
+        elif data_type == AudioDataType.HOSTED_FILE:
+            batch_params = BatchParams(is_hosted=True, is_asynchronous=False)
+            batch_importer = ContentBatchImporter(
+                self.kili, self.project_params, batch_params, self.pbar
+            )
+        else:
+            raise ImportValidationError
+        return self.import_assets_by_batch(assets, batch_importer)
diff --git a/tests/unit/domain_api/test_assets.py b/tests/unit/domain_api/test_assets.py
index 47c3f6ba0..1d32c3656 100644
--- a/tests/unit/domain_api/test_assets.py
+++ b/tests/unit/domain_api/test_assets.py
@@ -273,6 +273,27 @@ def test_create_image_assets(self, assets_namespace, mock_client):
             wait_until_availability=True,
         )
 
+    def test_create_audio_assets(self, assets_namespace, mock_client):
+        """Test create_audio method delegates to client."""
+        expected_result = {"id": "project_123", "asset_ids": ["asset1", "asset2"]}
+        mock_client.append_many_to_dataset.return_value = expected_result
+
+        result = assets_namespace.create_audio(
+            project_id="project_123",
+            content_array=["https://example.com/audio.mp3"],
+            external_id_array=["ext1"],
+        )
+
+        assert result == expected_result
+        mock_client.append_many_to_dataset.assert_called_once_with(
+            project_id="project_123",
+            content_array=["https://example.com/audio.mp3"],
+            external_id_array=["ext1"],
+            json_metadata_array=None,
+            disable_tqdm=None,
+            wait_until_availability=True,
+        )
+
     def test_delete_assets(self, assets_namespace, mock_client):
         """Test delete method delegates to client."""
         expected_result = {"id": "project_123"}