diff --git a/.github/workflows/retro_copat_test.yml b/.github/workflows/retro_copat_test.yml new file mode 100644 index 00000000..f17e997a --- /dev/null +++ b/.github/workflows/retro_copat_test.yml @@ -0,0 +1,29 @@ +name: Run Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + pip install -e ".[past]" + + - name: Run tests with coverage + run: | + pytest tests/unit/coal/ --cov=cosmotech.coal --cov-report=term diff --git a/cosmotech/coal/cosmotech_api/__init__.py b/cosmotech/coal/cosmotech_api/__init__.py index a4b2ff50..c33a6a1a 100644 --- a/cosmotech/coal/cosmotech_api/__init__.py +++ b/cosmotech/coal/cosmotech_api/__init__.py @@ -17,7 +17,8 @@ ) from cosmotech.coal.utils.semver import semver_of -csm_version = semver_of('cosmotech_api') + +csm_version = semver_of("cosmotech_api") if csm_version.major < 5: # Re-export functions from the twin_data_layer module from cosmotech.coal.cosmotech_api.twin_data_layer import ( diff --git a/cosmotech/coal/cosmotech_api/dataset/download/file.py b/cosmotech/coal/cosmotech_api/dataset/download/file.py index 418459d0..ffd1961b 100644 --- a/cosmotech/coal/cosmotech_api/dataset/download/file.py +++ b/cosmotech/coal/cosmotech_api/dataset/download/file.py @@ -52,9 +52,7 @@ def process_xls(target_file) -> Dict[str, Any]: content[sheet_name].append(new_row) row_count += 1 - LOGGER.debug( - T("coal.services.dataset.sheet_processed").format(sheet_name=sheet_name, rows=row_count) - ) + LOGGER.debug(T("coal.services.dataset.sheet_processed").format(sheet_name=sheet_name, rows=row_count)) return content @@ -87,9 +85,7 @@ def process_csv(target_file) -> Dict[str, Any]: content[current_filename].append(new_row) row_count += 1 - LOGGER.debug( - T("coal.services.dataset.csv_processed").format(file_name=current_filename, rows=row_count) - ) + LOGGER.debug(T("coal.services.dataset.csv_processed").format(file_name=current_filename, rows=row_count)) return content @@ -107,9 +103,7 @@ def process_json(target_file) -> Dict[str, Any]: else: item_count = 1 - LOGGER.debug( - T("coal.services.dataset.json_processed").format(file_name=current_filename, items=item_count) - ) + LOGGER.debug(T("coal.services.dataset.json_processed").format(file_name=current_filename, items=item_count)) return content @@ -118,12 +112,10 @@ def process_txt(target_file) -> Dict[str, Any]: LOGGER.debug(T("coal.services.dataset.processing_text").format(file_name=target_file)) with open(target_file, "r") as _file: current_filename = os.path.basename(target_file) - content[current_filename] = "".join(line for line in _file) + content[current_filename] = _file.read() line_count = content[current_filename].count("\n") + 1 - LOGGER.debug( - T("coal.services.dataset.text_processed").format(file_name=current_filename, lines=line_count) - ) + LOGGER.debug(T("coal.services.dataset.text_processed").format(file_name=current_filename, lines=line_count)) return content @@ -140,6 +132,7 @@ def timed_read_file(file_name, file): else: content.update(process_txt(file)) return content + return timed_read_file(file_name, file) diff --git a/cosmotech/coal/cosmotech_api/dataset/download/twingraph.py b/cosmotech/coal/cosmotech_api/dataset/download/twingraph.py index e39cfda9..f28cb9db 100644 --- a/cosmotech/coal/cosmotech_api/dataset/download/twingraph.py +++ b/cosmotech/coal/cosmotech_api/dataset/download/twingraph.py @@ -62,7 +62,9 @@ def download_twingraph_dataset( # Query edges edges_start = time.time() LOGGER.debug(T("coal.services.dataset.twingraph_querying_edges").format(dataset_id=dataset_id)) - edges_query = cosmotech_api.DatasetTwinGraphQuery(query="MATCH(n)-[r]->(m) RETURN n as src, r as rel, m as dest") + edges_query = cosmotech_api.DatasetTwinGraphQuery( + query="MATCH(n)-[r]->(m) RETURN n as src, r as rel, m as dest" + ) edges = dataset_api.twingraph_query( organization_id=organization_id, diff --git a/cosmotech/coal/cosmotech_api/runner/datasets.py b/cosmotech/coal/cosmotech_api/runner/datasets.py index 6b19744c..71b23208 100644 --- a/cosmotech/coal/cosmotech_api/runner/datasets.py +++ b/cosmotech/coal/cosmotech_api/runner/datasets.py @@ -57,7 +57,7 @@ def download_dataset( read_files: bool = True, ) -> Dict[str, Any]: """ - retro-compatibility to cosmo-api v4 + retro-compatibility to cosmo-api v4 """ from cosmotech.coal.utils.semver import semver_of @@ -90,9 +90,9 @@ def download_dataset_v5( # Get dataset information with get_api_client()[0] as api_client: dataset_api_instance = DatasetApi(api_client) - dataset = dataset_api_instance.get_dataset(organization_id=organization_id, - workspace_id=workspace_id, - dataset_id=dataset_id) + dataset = dataset_api_instance.get_dataset( + organization_id=organization_id, workspace_id=workspace_id, dataset_id=dataset_id + ) content = dict() tmp_dataset_dir = tempfile.mkdtemp() @@ -223,9 +223,7 @@ def download_dataset_v4( } -def download_dataset_process( - _dataset_id, organization_id, workspace_id, read_files, _return_dict, _error_dict -): +def download_dataset_process(_dataset_id, organization_id, workspace_id, read_files, _return_dict, _error_dict): """ Process function for downloading a dataset in a separate process. diff --git a/cosmotech/coal/utils/api.py b/cosmotech/coal/utils/api.py deleted file mode 100644 index 757fb835..00000000 --- a/cosmotech/coal/utils/api.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (C) - 2023 - 2025 - Cosmo Tech -# This document and all information contained herein is the exclusive property - -# including all intellectual property rights pertaining thereto - of Cosmo Tech. -# Any use, reproduction, translation, broadcasting, transmission, distribution, -# etc., to any person is prohibited unless it has been previously and -# specifically authorized by written means by Cosmo Tech. - -import json -import pathlib -from typing import Optional - -import cosmotech_api -import yaml -from cosmotech_api.api.solution_api import Solution -from cosmotech_api.api.solution_api import SolutionApi -from cosmotech_api.api.workspace_api import Workspace -from cosmotech_api.api.workspace_api import WorkspaceApi -from cosmotech_api.exceptions import ServiceException - -from cosmotech.coal.cosmotech_api.connection import get_api_client -from cosmotech.coal.utils.logger import LOGGER -from cosmotech.orchestrator.utils.translate import T - - -def read_solution_file(solution_file) -> Optional[Solution]: - solution_path = pathlib.Path(solution_file) - if solution_path.suffix in [".yaml", ".yml"]: - open_function = yaml.safe_load - elif solution_path.suffix == ".json": - open_function = json.load - else: - LOGGER.error(T("coal.cosmotech_api.solution.invalid_file").format(file=solution_file)) - return None - with solution_path.open() as _sf: - solution_content = open_function(_sf) - LOGGER.info(T("coal.cosmotech_api.solution.loaded").format(path=solution_path.absolute())) - _solution = Solution( - _configuration=cosmotech_api.Configuration(), - _spec_property_naming=True, - **solution_content, - ) - LOGGER.debug( - T("coal.services.api.solution_debug").format(solution=json.dumps(_solution.to_dict(), indent=2, default=str)) - ) - return _solution - - -def get_solution(organization_id, workspace_id) -> Optional[Solution]: - LOGGER.info(T("coal.cosmotech_api.solution.api_configured")) - with get_api_client()[0] as api_client: - api_w = WorkspaceApi(api_client) - - LOGGER.info(T("coal.cosmotech_api.solution.loading_workspace")) - try: - r_data: Workspace = api_w.find_workspace_by_id(organization_id=organization_id, workspace_id=workspace_id) - except ServiceException as e: - LOGGER.error( - T("coal.cosmotech_api.workspace.not_found").format( - workspace_id=workspace_id, organization_id=organization_id - ) - ) - LOGGER.debug(e) - return None - solution_id = r_data.solution.solution_id - - api_sol = SolutionApi(api_client) - sol: Solution = api_sol.find_solution_by_id(organization_id=organization_id, solution_id=solution_id) - return sol diff --git a/cosmotech/coal/utils/decorator.py b/cosmotech/coal/utils/decorator.py index b436d153..7b25c3a4 100644 --- a/cosmotech/coal/utils/decorator.py +++ b/cosmotech/coal/utils/decorator.py @@ -19,5 +19,7 @@ def wrapper(*args, **kwargs): else: LOGGER.info(msg) return r + return wrapper + return decorator diff --git a/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_run_template.py b/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_run_template.py index 0d428cf1..d29d2144 100644 --- a/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_run_template.py +++ b/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_run_template.py @@ -7,8 +7,7 @@ import pathlib import pytest -from unittest.mock import MagicMock, patch, mock_open -from io import BytesIO +from unittest.mock import MagicMock, patch from zipfile import BadZipfile, ZipFile import cosmotech_api @@ -17,8 +16,10 @@ from cosmotech_api.exceptions import ServiceException from cosmotech.coal.cosmotech_api.run_template import load_run_template_handlers +from cosmotech.coal.utils.semver import semver_of +@pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported in version 5") class TestRunTemplateFunctions: """Tests for top-level functions in the run_template module.""" @@ -66,7 +67,6 @@ def test_load_run_template_handlers_success(self): patch("cosmotech.coal.cosmotech_api.run_template.WorkspaceApi", return_value=mock_workspace_api), patch("cosmotech.coal.cosmotech_api.run_template.SolutionApi", return_value=mock_solution_api), patch("cosmotech.coal.cosmotech_api.run_template.ZipFile", return_value=mock_zipfile_context), - patch("cosmotech.coal.cosmotech_api.run_template.BytesIO") as mock_bytesio, patch("cosmotech.coal.cosmotech_api.run_template.pathlib.Path") as mock_path_class, ): mock_path_class.return_value = mock_path diff --git a/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer.py b/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer.py index 4c10ea97..7cd67c8a 100644 --- a/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer.py +++ b/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer.py @@ -14,22 +14,29 @@ import pytest import requests -from cosmotech_api import DatasetApi, RunnerApi, DatasetTwinGraphQuery - -from cosmotech.coal.cosmotech_api.twin_data_layer import ( - get_dataset_id_from_runner, - send_files_to_tdl, - load_files_from_tdl, - CSVSourceFile, - _process_csv_file, - _get_node_properties, - _get_relationship_properties, - _execute_queries, - _write_files, - ID_COLUMN, - SOURCE_COLUMN, - TARGET_COLUMN, -) + +from cosmotech.coal.utils.semver import semver_of +from cosmotech_api import DatasetApi, RunnerApi + +if semver_of("cosmotech_api").major < 5: + from cosmotech_api import DatasetTwinGraphQuery + + from cosmotech.coal.cosmotech_api.twin_data_layer import ( + get_dataset_id_from_runner, + send_files_to_tdl, + load_files_from_tdl, + CSVSourceFile, + _process_csv_file, + _get_node_properties, + _get_relationship_properties, + _execute_queries, + _write_files, + ID_COLUMN, + SOURCE_COLUMN, + TARGET_COLUMN, + ) + +pytestmark = pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported under version 5") class TestCSVSourceFile: diff --git a/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_auth.py b/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_auth.py index 1fd83118..44e1592f 100644 --- a/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_auth.py +++ b/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_auth.py @@ -14,18 +14,26 @@ import pytest import requests -from cosmotech_api import DatasetApi, RunnerApi, DatasetTwinGraphQuery + +from cosmotech.coal.utils.semver import semver_of +from cosmotech_api import DatasetApi, RunnerApi + +if semver_of("cosmotech_api").major < 5: + from cosmotech_api import DatasetTwinGraphQuery + from cosmotech.coal.cosmotech_api.twin_data_layer import ( + send_files_to_tdl, + load_files_from_tdl, + _process_csv_file, + _get_node_properties, + _get_relationship_properties, + ) from cosmotech.orchestrator.utils.translate import T -from cosmotech.coal.cosmotech_api.twin_data_layer import ( - send_files_to_tdl, - load_files_from_tdl, - _process_csv_file, - _get_node_properties, - _get_relationship_properties, -) + +skip_under_v5 = pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported under version 5") +@skip_under_v5 class TestTwinDataLayerAuth: """Tests for authentication in the twin_data_layer module.""" diff --git a/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_coverage.py b/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_coverage.py index ad98023e..efeb95a8 100644 --- a/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_coverage.py +++ b/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_coverage.py @@ -17,18 +17,26 @@ import pytest import requests -from cosmotech_api import DatasetApi, RunnerApi, DatasetTwinGraphQuery -from cosmotech.coal.cosmotech_api.twin_data_layer import ( - get_dataset_id_from_runner, - send_files_to_tdl, - load_files_from_tdl, - _process_csv_file, - _write_files, - BATCH_SIZE_LIMIT, -) +from cosmotech.coal.utils.semver import semver_of +from cosmotech_api import DatasetApi, RunnerApi +if semver_of("cosmotech_api").major < 5: + from cosmotech_api import DatasetTwinGraphQuery + from cosmotech.coal.cosmotech_api.twin_data_layer import ( + get_dataset_id_from_runner, + send_files_to_tdl, + load_files_from_tdl, + _process_csv_file, + _write_files, + BATCH_SIZE_LIMIT, + ) + +skip_under_v5 = pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported under version 5") + + +@skip_under_v5 class TestTwinDataLayerCoverage: """Additional tests for the twin_data_layer module to improve coverage.""" diff --git a/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_edge_cases.py b/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_edge_cases.py index 5140376b..5f0899e1 100644 --- a/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_edge_cases.py +++ b/tests/unit/coal/test_cosmotech_api/test_cosmotech_api_twin_data_layer_edge_cases.py @@ -14,18 +14,25 @@ import pytest import requests -from cosmotech_api import DatasetApi, RunnerApi, DatasetTwinGraphQuery +from cosmotech.coal.utils.semver import semver_of +from cosmotech_api import DatasetApi, RunnerApi + +if semver_of("cosmotech_api").major < 5: + from cosmotech_api import DatasetTwinGraphQuery + from cosmotech.coal.cosmotech_api.twin_data_layer import ( + send_files_to_tdl, + load_files_from_tdl, + _process_csv_file, + _get_node_properties, + _get_relationship_properties, + ) from cosmotech.orchestrator.utils.translate import T -from cosmotech.coal.cosmotech_api.twin_data_layer import ( - send_files_to_tdl, - load_files_from_tdl, - _process_csv_file, - _get_node_properties, - _get_relationship_properties, -) +skip_under_v5 = pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported under version 5") + +@skip_under_v5 class TestTwinDataLayerEdgeCases: """Tests for edge cases in the twin_data_layer module.""" diff --git a/tests/unit/coal/test_cosmotech_api/test_dataset/test_download/test_download_common.py b/tests/unit/coal/test_cosmotech_api/test_dataset/test_download/test_download_common.py index 1b7b9581..50684b1c 100644 --- a/tests/unit/coal/test_cosmotech_api/test_dataset/test_download/test_download_common.py +++ b/tests/unit/coal/test_cosmotech_api/test_dataset/test_download/test_download_common.py @@ -6,14 +6,16 @@ # specifically authorized by written means by Cosmo Tech. import pytest -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch from pathlib import Path from cosmotech_api import DatasetApi from cosmotech.coal.cosmotech_api.dataset.download.common import download_dataset_by_id +from cosmotech.coal.utils.semver import semver_of +@pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported in version 5") class TestCommonFunctions: """Tests for top-level functions in the common module.""" diff --git a/tests/unit/coal/test_cosmotech_api/test_dataset/test_download/test_download_twingraph.py b/tests/unit/coal/test_cosmotech_api/test_dataset/test_download/test_download_twingraph.py index 0ce2972b..3ec3b283 100644 --- a/tests/unit/coal/test_cosmotech_api/test_dataset/test_download/test_download_twingraph.py +++ b/tests/unit/coal/test_cosmotech_api/test_dataset/test_download/test_download_twingraph.py @@ -11,14 +11,20 @@ import pytest import cosmotech_api -from cosmotech_api import DatasetApi, TwingraphApi +from cosmotech_api import DatasetApi +from cosmotech.coal.utils.semver import semver_of -from cosmotech.coal.cosmotech_api.dataset.download.twingraph import ( - download_twingraph_dataset, - download_legacy_twingraph_dataset, -) +if semver_of("cosmotech_api").major < 5: + from cosmotech_api import TwingraphApi + from cosmotech.coal.cosmotech_api.dataset.download.twingraph import ( + download_twingraph_dataset, + download_legacy_twingraph_dataset, + ) +skip_under_v5 = pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported under version 5") + +@skip_under_v5 class TestTwingraphFunctions: """Tests for top-level functions in the twingraph module.""" diff --git a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets.py b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets.py index 5e3690ec..b5173886 100644 --- a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets.py +++ b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets.py @@ -5,13 +5,10 @@ # etc., to any person is prohibited unless it has been previously and # specifically authorized by written means by Cosmo Tech. -import multiprocessing -import tempfile from pathlib import Path from unittest.mock import MagicMock, patch, call - import pytest -from azure.identity import DefaultAzureCredential + from cosmotech_api import DatasetApi from cosmotech.coal.cosmotech_api.runner.datasets import ( @@ -22,6 +19,7 @@ download_datasets, dataset_to_file, ) +from cosmotech.coal.utils.semver import semver_of class TestDatasetsFunctions: @@ -60,6 +58,7 @@ def test_get_dataset_ids_from_runner(self): @patch("cosmotech.coal.cosmotech_api.runner.datasets.get_api_client") @patch("cosmotech.coal.cosmotech_api.runner.datasets.download_adt_dataset") + @pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported in version 5") def test_download_dataset_adt(self, mock_download_adt, mock_get_api_client): """Test the download_dataset function with ADT dataset.""" # Arrange @@ -106,6 +105,7 @@ def test_download_dataset_adt(self, mock_download_adt, mock_get_api_client): @patch("cosmotech.coal.cosmotech_api.runner.datasets.get_api_client") @patch("cosmotech.coal.cosmotech_api.runner.datasets.download_legacy_twingraph_dataset") + @pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported in version 5") def test_download_dataset_legacy_twingraph(self, mock_download_legacy, mock_get_api_client): """Test the download_dataset function with legacy twin graph dataset.""" # Arrange @@ -153,6 +153,7 @@ def test_download_dataset_legacy_twingraph(self, mock_download_legacy, mock_get_ @patch("cosmotech.coal.cosmotech_api.runner.datasets.get_api_client") @patch("cosmotech.coal.cosmotech_api.runner.datasets.download_file_dataset") + @pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported in version 5") def test_download_dataset_storage(self, mock_download_file, mock_get_api_client): """Test the download_dataset function with storage dataset.""" # Arrange @@ -205,6 +206,7 @@ def test_download_dataset_storage(self, mock_download_file, mock_get_api_client) @patch("cosmotech.coal.cosmotech_api.runner.datasets.get_api_client") @patch("cosmotech.coal.cosmotech_api.runner.datasets.download_file_dataset") + @pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported in version 5") def test_download_dataset_workspace_file(self, mock_download_file, mock_get_api_client): """Test the download_dataset function with workspace file dataset.""" # Arrange @@ -260,6 +262,7 @@ def test_download_dataset_workspace_file(self, mock_download_file, mock_get_api_ @patch("cosmotech.coal.cosmotech_api.runner.datasets.get_api_client") @patch("cosmotech.coal.cosmotech_api.runner.datasets.download_twingraph_dataset") + @pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported in version 5") def test_download_dataset_twingraph(self, mock_download_twingraph, mock_get_api_client): """Test the download_dataset function with twin graph dataset.""" # Arrange @@ -305,6 +308,56 @@ def test_download_dataset_twingraph(self, mock_download_twingraph, mock_get_api_ assert result["folder_path"] == str(mock_folder_path) assert result["dataset_id"] == dataset_id + @patch("cosmotech.coal.cosmotech_api.runner.datasets.get_api_client") + @pytest.mark.skipif(semver_of("cosmotech_api").major < 5, reason="supported only in version 5") + def test_download_dataset_v5(self, mock_get_api_client): + """Test the download_dataset function with twin graph dataset.""" + # Arrange + organization_id = "org-123" + workspace_id = "ws-123" + dataset_id = "dataset-123" + dataset_part_id = "part-123" + + # Mock API client + mock_api_client = MagicMock() + mock_api_client.__enter__.return_value = mock_api_client + mock_get_api_client.return_value = (mock_api_client, "API Key") + + # Mock dataset API + mock_dataset_api = MagicMock(spec=DatasetApi) + mock_dataset_part = MagicMock() + mock_dataset_part.id = dataset_part_id + mock_dataset_part.source_name = "test-dataset-part.txt" + mock_dataset = MagicMock() + mock_dataset.id = dataset_id + mock_dataset.name = "test-dataset" + mock_dataset.parts = [mock_dataset_part] + mock_dataset_api.get_dataset.return_value = mock_dataset + + # Mock file part download + mock_content = b"test file part content in byte format" + mock_dataset_api.download_dataset_part.return_value = mock_content + + with patch("cosmotech.coal.cosmotech_api.runner.datasets.DatasetApi", return_value=mock_dataset_api): + # Act + result = download_dataset( + organization_id=organization_id, + workspace_id=workspace_id, + dataset_id=dataset_id, + ) + + # Assert + mock_dataset_api.get_dataset.assert_called_once_with( + organization_id=organization_id, workspace_id=workspace_id, dataset_id=dataset_id + ) + mock_dataset_api.download_dataset_part.assert_called_once_with( + organization_id, workspace_id, dataset_id, dataset_part_id + ) + assert result["type"] == "csm_dataset" + assert result["content"] == {"test-dataset-part.txt": "test file part content in byte format"} + assert result["name"] == "test-dataset" + assert result["dataset_id"] == "dataset-123" + @patch("cosmotech.coal.cosmotech_api.runner.datasets.download_dataset") @patch("multiprocessing.Process") @patch("multiprocessing.Manager") @@ -392,14 +445,12 @@ def test_download_datasets_sequential(self, mock_get_api_client, mock_download_d workspace_id=workspace_id, dataset_id="dataset-1", read_files=True, - credentials=None, ), call( organization_id=organization_id, workspace_id=workspace_id, dataset_id="dataset-2", read_files=True, - credentials=None, ), ] ) @@ -436,7 +487,6 @@ def test_download_datasets_parallel_mode(self, mock_sequential, mock_parallel): workspace_id=workspace_id, dataset_ids=dataset_ids, read_files=True, - credentials=None, ) mock_sequential.assert_not_called() assert len(result) == 2 @@ -472,7 +522,6 @@ def test_download_datasets_sequential_mode(self, mock_sequential, mock_parallel) workspace_id=workspace_id, dataset_ids=dataset_ids, read_files=True, - credentials=None, ) mock_parallel.assert_not_called() assert len(result) == 2 @@ -507,7 +556,6 @@ def test_download_datasets_single_dataset(self, mock_sequential, mock_parallel): workspace_id=workspace_id, dataset_ids=dataset_ids, read_files=True, - credentials=None, ) mock_parallel.assert_not_called() assert len(result) == 1 diff --git a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_additional.py b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_additional.py index 2cb93863..9fdf9a57 100644 --- a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_additional.py +++ b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_additional.py @@ -5,12 +5,10 @@ # etc., to any person is prohibited unless it has been previously and # specifically authorized by written means by Cosmo Tech. -import multiprocessing from pathlib import Path -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch import pytest -from azure.identity import DefaultAzureCredential from cosmotech_api import DatasetApi from cosmotech.coal.cosmotech_api.runner.datasets import ( @@ -19,12 +17,14 @@ download_datasets, dataset_to_file, ) +from cosmotech.coal.utils.semver import semver_of class TestDatasetsAdditional: """Additional tests for the datasets module to improve coverage.""" @patch("cosmotech.coal.cosmotech_api.runner.datasets.get_api_client") + @pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported in version 5") def test_download_dataset_no_connector(self, mock_get_api_client): """Test the download_dataset function with a dataset that has no connector.""" # Arrange @@ -189,12 +189,11 @@ def test_download_datasets_parallel_start_join(self, mock_get_api_client, mock_m mock_process.side_effect = [mock_process_instance1, mock_process_instance2] # Act - with patch("cosmotech.coal.cosmotech_api.runner.datasets.download_dataset") as mock_download_dataset: - download_datasets_parallel( - organization_id=organization_id, - workspace_id=workspace_id, - dataset_ids=dataset_ids, - ) + download_datasets_parallel( + organization_id=organization_id, + workspace_id=workspace_id, + dataset_ids=dataset_ids, + ) # Assert # Check that start and join were called for each process diff --git a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_coverage.py b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_coverage.py index ebf50467..a62c9b52 100644 --- a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_coverage.py +++ b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_coverage.py @@ -5,19 +5,13 @@ # etc., to any person is prohibited unless it has been previously and # specifically authorized by written means by Cosmo Tech. -import multiprocessing -import tempfile -from pathlib import Path -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch import pytest -from azure.identity import DefaultAzureCredential from cosmotech.coal.cosmotech_api.runner.datasets import ( - download_dataset, download_datasets_parallel, download_datasets_sequential, - download_datasets, dataset_to_file, ) @@ -106,14 +100,12 @@ def test_download_datasets_sequential_with_error(self, mock_get_api_client, mock workspace_id=workspace_id, dataset_id="dataset-1", read_files=True, - credentials=None, ) mock_download_dataset.assert_any_call( organization_id=organization_id, workspace_id=workspace_id, dataset_id="dataset-2", read_files=True, - credentials=None, ) @patch("tempfile.mkdtemp") diff --git a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_edge_cases.py b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_edge_cases.py index 05a7297d..2b11cd46 100644 --- a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_edge_cases.py +++ b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_edge_cases.py @@ -5,10 +5,8 @@ # etc., to any person is prohibited unless it has been previously and # specifically authorized by written means by Cosmo Tech. -import multiprocessing -import tempfile from pathlib import Path -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch, ANY import pytest from azure.identity import DefaultAzureCredential @@ -21,6 +19,7 @@ download_datasets, dataset_to_file, ) +from cosmotech.coal.utils.semver import semver_of class TestDatasetsEdgeCases: @@ -28,6 +27,7 @@ class TestDatasetsEdgeCases: @patch("cosmotech.coal.cosmotech_api.runner.datasets.get_api_client") @patch("cosmotech.coal.cosmotech_api.runner.datasets.download_adt_dataset") + @pytest.mark.skipif(semver_of("cosmotech_api").major >= 5, reason="not supported in version 5") def test_download_dataset_adt_pass_credentials(self, mock_download_adt, mock_get_api_client): """Test that download_dataset passes credentials to download_adt_dataset.""" # Arrange @@ -62,13 +62,12 @@ def test_download_dataset_adt_pass_credentials(self, mock_download_adt, mock_get organization_id=organization_id, workspace_id=workspace_id, dataset_id=dataset_id, - credentials=mock_credential, # Provide credentials ) # Assert mock_download_adt.assert_called_once_with( adt_address="https://adt.example.com", - credentials=mock_credential, + credentials=ANY, ) assert result["type"] == "adt" diff --git a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_final_coverage.py b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_final_coverage.py index 88113569..8b9f5e95 100644 --- a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_final_coverage.py +++ b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_final_coverage.py @@ -5,22 +5,11 @@ # etc., to any person is prohibited unless it has been previously and # specifically authorized by written means by Cosmo Tech. -import multiprocessing -import tempfile -from pathlib import Path -from unittest.mock import MagicMock, patch, call - -import pytest -from azure.identity import DefaultAzureCredential -from cosmotech_api import DatasetApi +from unittest.mock import MagicMock, patch from cosmotech.coal.cosmotech_api.runner.datasets import ( - download_dataset, - download_datasets_parallel, download_datasets_sequential, download_datasets, - dataset_to_file, - get_dataset_ids_from_runner, ) @@ -35,9 +24,6 @@ def test_download_datasets_sequential_pass_credentials(self, mock_download_datas workspace_id = "ws-123" dataset_ids = ["dataset-1", "dataset-2"] - # Mock credentials - mock_credentials = MagicMock(spec=DefaultAzureCredential) - # Mock download_dataset to return dataset info mock_download_dataset.side_effect = [ {"type": "csv", "content": {}, "name": "dataset-1"}, @@ -53,7 +39,6 @@ def test_download_datasets_sequential_pass_credentials(self, mock_download_datas organization_id=organization_id, workspace_id=workspace_id, dataset_ids=dataset_ids, - credentials=mock_credentials, ) # Assert @@ -68,7 +53,6 @@ def test_download_datasets_sequential_pass_credentials(self, mock_download_datas workspace_id=workspace_id, dataset_id=dataset_id, read_files=True, - credentials=mock_credentials, ) @patch("cosmotech.coal.cosmotech_api.runner.datasets.download_datasets_parallel") @@ -102,7 +86,6 @@ def test_download_datasets_with_parallel_true(self, mock_sequential, mock_parall workspace_id=workspace_id, dataset_ids=dataset_ids, read_files=True, - credentials=None, ) mock_sequential.assert_not_called() @@ -137,6 +120,5 @@ def test_download_datasets_with_parallel_false(self, mock_sequential, mock_paral workspace_id=workspace_id, dataset_ids=dataset_ids, read_files=True, - credentials=None, ) mock_parallel.assert_not_called() diff --git a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_process.py b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_process.py index 8a913c95..93ba12ab 100644 --- a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_process.py +++ b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_process.py @@ -5,11 +5,9 @@ # etc., to any person is prohibited unless it has been previously and # specifically authorized by written means by Cosmo Tech. -import multiprocessing -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest -from azure.identity import DefaultAzureCredential from cosmotech.coal.cosmotech_api.runner.datasets import download_dataset_process @@ -25,7 +23,6 @@ def test_download_dataset_process_success(self, mock_download_dataset): organization_id = "org-123" workspace_id = "ws-123" read_files = True - credentials = None # Create shared dictionaries return_dict = {} @@ -42,50 +39,7 @@ def test_download_dataset_process_success(self, mock_download_dataset): mock_download_dataset.return_value = mock_dataset_info # Act - download_dataset_process( - dataset_id, organization_id, workspace_id, read_files, credentials, return_dict, error_dict - ) - - # Assert - mock_download_dataset.assert_called_once_with( - organization_id=organization_id, - workspace_id=workspace_id, - dataset_id=dataset_id, - read_files=read_files, - credentials=credentials, - ) - assert dataset_id in return_dict - assert return_dict[dataset_id] == mock_dataset_info - assert len(error_dict) == 0 - - @patch("cosmotech.coal.cosmotech_api.runner.datasets.download_dataset") - def test_download_dataset_process_with_credentials(self, mock_download_dataset): - """Test the download_dataset_process function with credentials.""" - # Arrange - dataset_id = "dataset-123" - organization_id = "org-123" - workspace_id = "ws-123" - read_files = True - credentials = MagicMock(spec=DefaultAzureCredential) - - # Create shared dictionaries - return_dict = {} - error_dict = {} - - # Mock download_dataset to return dataset info - mock_dataset_info = { - "type": "adt", - "content": {"nodes": [], "edges": []}, - "name": "test-dataset", - "folder_path": "/tmp/dataset", - "dataset_id": dataset_id, - } - mock_download_dataset.return_value = mock_dataset_info - - # Act - download_dataset_process( - dataset_id, organization_id, workspace_id, read_files, credentials, return_dict, error_dict - ) + download_dataset_process(dataset_id, organization_id, workspace_id, read_files, return_dict, error_dict) # Assert mock_download_dataset.assert_called_once_with( @@ -93,7 +47,6 @@ def test_download_dataset_process_with_credentials(self, mock_download_dataset): workspace_id=workspace_id, dataset_id=dataset_id, read_files=read_files, - credentials=credentials, ) assert dataset_id in return_dict assert return_dict[dataset_id] == mock_dataset_info @@ -107,7 +60,6 @@ def test_download_dataset_process_error(self, mock_download_dataset): organization_id = "org-123" workspace_id = "ws-123" read_files = True - credentials = None # Create shared dictionaries return_dict = {} @@ -119,9 +71,7 @@ def test_download_dataset_process_error(self, mock_download_dataset): # Act & Assert with pytest.raises(ValueError) as excinfo: - download_dataset_process( - dataset_id, organization_id, workspace_id, read_files, credentials, return_dict, error_dict - ) + download_dataset_process(dataset_id, organization_id, workspace_id, read_files, return_dict, error_dict) # Verify the error was re-raised assert str(excinfo.value) == "Failed to download dataset" diff --git a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_download.py b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_download.py index 7562d523..51d46851 100644 --- a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_download.py +++ b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_download.py @@ -5,24 +5,13 @@ # etc., to any person is prohibited unless it has been previously and # specifically authorized by written means by Cosmo Tech. -import os -import pathlib -import shutil -import tempfile -from unittest.mock import MagicMock, patch, call - -import pytest -from azure.identity import DefaultAzureCredential -from cosmotech_api import RunnerApi, ScenarioApi -from cosmotech_api.exceptions import ApiException - +from unittest.mock import MagicMock, patch from cosmotech.coal.cosmotech_api.runner.download import download_runner_data class TestDownloadFunctions: """Tests for top-level functions in the download module.""" - @patch("cosmotech.coal.cosmotech_api.runner.download.get_api_client") @patch("cosmotech.coal.cosmotech_api.runner.download.get_runner_data") @patch("cosmotech.coal.cosmotech_api.runner.download.format_parameters_list") @patch("cosmotech.coal.cosmotech_api.runner.download.write_parameters") @@ -41,7 +30,6 @@ def test_download_runner_data_with_datasets( mock_write_parameters, mock_format_parameters, mock_get_runner_data, - mock_get_api_client, ): """Test the download_runner_data function with datasets.""" # Arrange @@ -51,11 +39,6 @@ def test_download_runner_data_with_datasets( parameter_folder = "/tmp/params" dataset_folder = "/tmp/datasets" - # Mock API client - mock_api_client = MagicMock() - mock_api_client.__enter__.return_value = mock_api_client - mock_get_api_client.return_value = (mock_api_client, "API Key") - # Mock runner data mock_runner_data = MagicMock() mock_runner_data.dataset_list = ["dataset-1", "dataset-2"] @@ -106,7 +89,6 @@ def test_download_runner_data_with_datasets( dataset_ids=["dataset-1", "dataset-2", "dataset-3"], read_files=False, parallel=True, - credentials=None, ) # The dataset_to_file function is called for each dataset in the dataset_list (2) and for the dataset referenced by a parameter (1) assert mock_dataset_to_file.call_count == 3 @@ -117,9 +99,8 @@ def test_download_runner_data_with_datasets( assert result["datasets"] == mock_datasets assert result["parameters"] == {"param1": "dataset-3", "param2": "value1"} - @patch("cosmotech.coal.cosmotech_api.runner.download.get_api_client") @patch("cosmotech.coal.cosmotech_api.runner.download.get_runner_data") - def test_download_runner_data_no_parameters(self, mock_get_runner_data, mock_get_api_client): + def test_download_runner_data_no_parameters(self, mock_get_runner_data): """Test the download_runner_data function with no parameters.""" # Arrange organization_id = "org-123" @@ -127,11 +108,6 @@ def test_download_runner_data_no_parameters(self, mock_get_runner_data, mock_get runner_id = "runner-123" parameter_folder = "/tmp/params" - # Mock API client - mock_api_client = MagicMock() - mock_api_client.__enter__.return_value = mock_api_client - mock_get_api_client.return_value = (mock_api_client, "API Key") - # Mock runner data with no parameters mock_runner_data = MagicMock() mock_runner_data.parameters_values = None @@ -152,7 +128,6 @@ def test_download_runner_data_no_parameters(self, mock_get_runner_data, mock_get assert result["datasets"] == {} assert result["parameters"] == {} - @patch("cosmotech.coal.cosmotech_api.runner.download.get_api_client") @patch("cosmotech.coal.cosmotech_api.runner.download.get_runner_data") @patch("cosmotech.coal.cosmotech_api.runner.download.format_parameters_list") @patch("cosmotech.coal.cosmotech_api.runner.download.write_parameters") @@ -161,7 +136,6 @@ def test_download_runner_data_no_datasets( mock_write_parameters, mock_format_parameters, mock_get_runner_data, - mock_get_api_client, ): """Test the download_runner_data function without datasets.""" # Arrange @@ -170,11 +144,6 @@ def test_download_runner_data_no_datasets( runner_id = "runner-123" parameter_folder = "/tmp/params" - # Mock API client - mock_api_client = MagicMock() - mock_api_client.__enter__.return_value = mock_api_client - mock_get_api_client.return_value = (mock_api_client, "API Key") - # Mock runner data mock_runner_data = MagicMock() mock_runner_data.dataset_list = [] diff --git a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_download_edge_cases.py b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_download_edge_cases.py index ee4e8dc7..f6c4c8f8 100644 --- a/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_download_edge_cases.py +++ b/tests/unit/coal/test_cosmotech_api/test_runner/test_runner_download_edge_cases.py @@ -5,16 +5,7 @@ # etc., to any person is prohibited unless it has been previously and # specifically authorized by written means by Cosmo Tech. -import os -import pathlib -import shutil -import tempfile -from unittest.mock import MagicMock, patch, call - -import pytest -from azure.identity import DefaultAzureCredential -from cosmotech_api import RunnerApi, ScenarioApi -from cosmotech_api.exceptions import ApiException +from unittest.mock import MagicMock, patch from cosmotech.coal.cosmotech_api.runner.download import download_runner_data @@ -22,8 +13,6 @@ class TestDownloadEdgeCases: """Tests for edge cases in the download module.""" - @patch("cosmotech.coal.cosmotech_api.runner.download.get_api_client") - @patch("cosmotech.coal.cosmotech_api.runner.download.DefaultAzureCredential") @patch("cosmotech.coal.cosmotech_api.runner.download.get_runner_data") @patch("cosmotech.coal.cosmotech_api.runner.download.format_parameters_list") @patch("cosmotech.coal.cosmotech_api.runner.download.write_parameters") @@ -36,8 +25,6 @@ def test_download_runner_data_azure_credentials( mock_write_parameters, mock_format_parameters, mock_get_runner_data, - mock_default_credential, - mock_get_api_client, ): """Test the download_runner_data function with Azure credentials.""" # Arrange @@ -46,15 +33,6 @@ def test_download_runner_data_azure_credentials( runner_id = "runner-123" parameter_folder = "/tmp/params" - # Mock API client with Azure Entra Connection - mock_api_client = MagicMock() - mock_api_client.__enter__.return_value = mock_api_client - mock_get_api_client.return_value = (mock_api_client, "Azure Entra Connection") - - # Mock DefaultAzureCredential - mock_credential = MagicMock(spec=DefaultAzureCredential) - mock_default_credential.return_value = mock_credential - # Mock runner data mock_runner_data = MagicMock() mock_runner_data.dataset_list = ["dataset-1"] @@ -88,12 +66,10 @@ def test_download_runner_data_azure_credentials( ) # Assert - mock_default_credential.assert_called_once() mock_download_datasets.assert_called_once_with( organization_id=organization_id, workspace_id=workspace_id, dataset_ids=["dataset-1"], read_files=False, parallel=True, - credentials=mock_credential, ) diff --git a/tests/unit/coal/test_utils/test_utils_api.py b/tests/unit/coal/test_utils/test_utils_api.py deleted file mode 100644 index 2f584d73..00000000 --- a/tests/unit/coal/test_utils/test_utils_api.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (C) - 2023 - 2025 - Cosmo Tech -# This document and all information contained herein is the exclusive property - -# including all intellectual property rights pertaining thereto - of Cosmo Tech. -# Any use, reproduction, translation, broadcasting, transmission, distribution, -# etc., to any person is prohibited unless it has been previously and -# specifically authorized by written means by Cosmo Tech. - -import json -import pathlib -from unittest.mock import MagicMock, patch, mock_open - -import pytest -import yaml -from cosmotech_api import Solution, Workspace -from cosmotech_api.exceptions import ServiceException - -# Mock the dependencies to avoid circular imports -import sys - -sys.modules["cosmotech.coal.cosmotech_api.connection"] = MagicMock() -sys.modules["cosmotech.coal.cosmotech_api.orchestrator"] = MagicMock() - -# Now we can import the functions -from cosmotech.coal.utils.api import read_solution_file, get_solution - - -class TestApiFunctions: - """Tests for top-level functions in the api module.""" - - @patch("pathlib.Path") - @patch("json.load") - def test_read_solution_file_json(self, mock_json_load, mock_path_class): - """Test the read_solution_file function with a JSON file.""" - # Arrange - solution_file = "solution.json" - solution_content = { - "name": "Test Solution", - "version": "1.0.0", - "parameters": [{"id": "param1", "name": "Parameter 1"}], - } - mock_json_load.return_value = solution_content - - # Mock Path instance - mock_path = MagicMock() - mock_path.suffix = ".json" - mock_path.open.return_value.__enter__.return_value = MagicMock() - mock_path_class.return_value = mock_path - - # Act - result = read_solution_file(solution_file) - - # Assert - mock_path.open.assert_called_once() - mock_json_load.assert_called_once() - assert result is not None - assert result.name == "Test Solution" - assert result.version == "1.0.0" - assert len(result.parameters) == 1 - assert result.parameters[0].id == "param1" - - @patch("pathlib.Path") - @patch("yaml.safe_load") - def test_read_solution_file_yaml(self, mock_yaml_load, mock_path_class): - """Test the read_solution_file function with a YAML file.""" - # Arrange - solution_file = "solution.yaml" - solution_content = { - "name": "Test Solution", - "version": "1.0.0", - "parameters": [{"id": "param1", "name": "Parameter 1"}], - } - mock_yaml_load.return_value = solution_content - - # Mock Path instance - mock_path = MagicMock() - mock_path.suffix = ".yaml" - mock_path.open.return_value.__enter__.return_value = MagicMock() - mock_path_class.return_value = mock_path - - # Act - result = read_solution_file(solution_file) - - # Assert - mock_path.open.assert_called_once() - mock_yaml_load.assert_called_once() - assert result is not None - assert result.name == "Test Solution" - assert result.version == "1.0.0" - assert len(result.parameters) == 1 - assert result.parameters[0].id == "param1" - - @patch("pathlib.Path") - def test_read_solution_file_invalid_extension(self, mock_path_class): - """Test the read_solution_file function with an invalid file extension.""" - # Arrange - solution_file = "solution.txt" - - # Mock Path instance - mock_path = MagicMock() - mock_path.suffix = ".txt" - mock_path_class.return_value = mock_path - - # Act - result = read_solution_file(solution_file) - - # Assert - assert result is None - - @patch("cosmotech.coal.utils.api.get_api_client") - def test_get_solution_success(self, mock_get_api_client): - """Test the get_solution function with successful API calls.""" - # Arrange - organization_id = "org-123" - workspace_id = "ws-123" - solution_id = "sol-123" - - # Mock API client - mock_api_client = MagicMock() - mock_api_client.__enter__.return_value = mock_api_client - mock_get_api_client.return_value = (mock_api_client, "API Key") - - # Mock workspace API - mock_workspace_api = MagicMock() - mock_workspace = MagicMock() - # Create a solution attribute with a solution_id - mock_solution = MagicMock() - mock_solution.solution_id = solution_id - mock_workspace.solution = mock_solution - mock_workspace_api.find_workspace_by_id.return_value = mock_workspace - - # Mock solution API - mock_solution_api = MagicMock() - mock_solution = MagicMock(spec=Solution) - mock_solution.name = "Test Solution" - mock_solution_api.find_solution_by_id.return_value = mock_solution - - with patch("cosmotech.coal.utils.api.WorkspaceApi", return_value=mock_workspace_api): - with patch("cosmotech.coal.utils.api.SolutionApi", return_value=mock_solution_api): - # Act - result = get_solution(organization_id, workspace_id) - - # Assert - mock_workspace_api.find_workspace_by_id.assert_called_once_with( - organization_id=organization_id, workspace_id=workspace_id - ) - mock_solution_api.find_solution_by_id.assert_called_once_with( - organization_id=organization_id, solution_id=solution_id - ) - assert result == mock_solution - - @patch("cosmotech.coal.utils.api.get_api_client") - def test_get_solution_workspace_not_found(self, mock_get_api_client): - """Test the get_solution function when workspace is not found.""" - # Arrange - organization_id = "org-123" - workspace_id = "ws-123" - - # Mock API client - mock_api_client = MagicMock() - mock_api_client.__enter__.return_value = mock_api_client - mock_get_api_client.return_value = (mock_api_client, "API Key") - - # Mock workspace API to raise exception - mock_workspace_api = MagicMock() - mock_workspace_api.find_workspace_by_id.side_effect = ServiceException( - status=404, reason="Not Found", body="Workspace not found" - ) - - with patch("cosmotech.coal.utils.api.WorkspaceApi", return_value=mock_workspace_api): - # Act - result = get_solution(organization_id, workspace_id) - - # Assert - mock_workspace_api.find_workspace_by_id.assert_called_once_with( - organization_id=organization_id, workspace_id=workspace_id - ) - assert result is None