Skip to content

Commit 582606a

Browse files
committed
add test for new download v5
1 parent d1409ce commit 582606a

File tree

3 files changed

+61
-7
lines changed

3 files changed

+61
-7
lines changed

tests/unit/coal/test_cosmotech_api/test_dataset/test_download/test_download_common.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@
66
# specifically authorized by written means by Cosmo Tech.
77

88
import pytest
9-
from unittest.mock import MagicMock, patch, call
9+
from unittest.mock import MagicMock, patch
1010
from pathlib import Path
1111

1212
from cosmotech_api import DatasetApi
1313

1414
from cosmotech.coal.cosmotech_api.dataset.download.common import download_dataset_by_id
15+
from cosmotech.coal.utils.semver import semver_of
1516

1617

18+
@pytest.mark.skipif(
19+
semver_of('cosmotech_api').major >= 5, reason='not supported in version 5'
20+
)
1721
class TestCommonFunctions:
1822
"""Tests for top-level functions in the common module."""
1923

tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,13 +318,64 @@ def test_download_dataset_twingraph(self, mock_download_twingraph, mock_get_api_
318318
assert result["folder_path"] == str(mock_folder_path)
319319
assert result["dataset_id"] == dataset_id
320320

321+
@patch("cosmotech.coal.cosmotech_api.runner.datasets.get_api_client")
322+
@pytest.mark.skipif(
323+
semver_of('cosmotech_api').major < 5, reason='supported only in version 5'
324+
)
325+
def test_download_dataset_v5(self, mock_get_api_client):
326+
"""Test the download_dataset function with twin graph dataset."""
327+
# Arrange
328+
organization_id = "org-123"
329+
workspace_id = "ws-123"
330+
dataset_id = "dataset-123"
331+
dataset_part_id = "part-123"
332+
333+
# Mock API client
334+
mock_api_client = MagicMock()
335+
mock_api_client.__enter__.return_value = mock_api_client
336+
mock_get_api_client.return_value = (mock_api_client, "API Key")
337+
338+
# Mock dataset API
339+
mock_dataset_api = MagicMock(spec=DatasetApi)
340+
mock_dataset_part = MagicMock()
341+
mock_dataset_part.id = dataset_part_id
342+
mock_dataset_part.source_name = "test-dataset-part.txt"
343+
mock_dataset = MagicMock()
344+
mock_dataset.id = dataset_id
345+
mock_dataset.name = "test-dataset"
346+
mock_dataset.parts = [mock_dataset_part]
347+
mock_dataset_api.get_dataset.return_value = mock_dataset
348+
349+
# Mock file part download
350+
mock_content = b'test file part content in byte format'
351+
mock_dataset_api.download_dataset_part.return_value = mock_content
352+
353+
with patch("cosmotech.coal.cosmotech_api.runner.datasets.DatasetApi", return_value=mock_dataset_api):
354+
# Act
355+
result = download_dataset(
356+
organization_id=organization_id,
357+
workspace_id=workspace_id,
358+
dataset_id=dataset_id,
359+
)
360+
361+
# Assert
362+
mock_dataset_api.get_dataset.assert_called_once_with(
363+
organization_id=organization_id, workspace_id=workspace_id, dataset_id=dataset_id
364+
)
365+
mock_dataset_api.download_dataset_part.assert_called_once_with(
366+
organization_id,
367+
workspace_id,
368+
dataset_id,
369+
dataset_part_id)
370+
assert result["type"] == "csm_dataset"
371+
assert result["content"] == {'test-dataset-part.txt': 'test file part content in byte format'}
372+
assert result["name"] == "test-dataset"
373+
assert result["dataset_id"] == "dataset-123"
374+
321375
@patch("cosmotech.coal.cosmotech_api.runner.datasets.download_dataset")
322376
@patch("multiprocessing.Process")
323377
@patch("multiprocessing.Manager")
324378
@patch("cosmotech.coal.cosmotech_api.runner.datasets.get_api_client")
325-
@pytest.mark.skipif(
326-
semver_of('cosmotech_api').major >= 5, reason='not supported in version 5'
327-
)
328379
def test_download_datasets_parallel(self, mock_get_api_client, mock_manager, mock_process, mock_download_dataset):
329380
"""Test the download_datasets_parallel function."""
330381
# Arrange

tests/unit/coal/test_cosmotech_api/test_runner/test_runner_datasets_edge_cases.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# specifically authorized by written means by Cosmo Tech.
77

88
from pathlib import Path
9-
from unittest.mock import MagicMock, patch
9+
from unittest.mock import MagicMock, patch, ANY
1010

1111
import pytest
1212
from azure.identity import DefaultAzureCredential
@@ -64,13 +64,12 @@ def test_download_dataset_adt_pass_credentials(self, mock_download_adt, mock_get
6464
organization_id=organization_id,
6565
workspace_id=workspace_id,
6666
dataset_id=dataset_id,
67-
credentials=mock_credential, # Provide credentials
6867
)
6968

7069
# Assert
7170
mock_download_adt.assert_called_once_with(
7271
adt_address="https://adt.example.com",
73-
credentials=mock_credential,
72+
credentials=ANY,
7473
)
7574
assert result["type"] == "adt"
7675

0 commit comments

Comments
 (0)