Skip to content

Commit 2c9447e

Browse files
clean tests and make sure all test code is executed (TGSAI#685)
* clean tests * pre-commit added a space * Remove DEBUG_MODE * Fix pre-commit issues * pre-commit --------- Co-authored-by: Altay Sansal <[email protected]>
1 parent 92bccd8 commit 2c9447e

File tree

8 files changed

+49
-164
lines changed

8 files changed

+49
-164
lines changed

tests/conftest.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
from __future__ import annotations
44

55
import warnings
6-
from pathlib import Path
6+
from typing import TYPE_CHECKING
77
from urllib.request import urlretrieve
88

99
import pytest
1010

11-
DEBUG_MODE = False
11+
if TYPE_CHECKING:
12+
from pathlib import Path
1213

1314
# Suppress Dask's chunk balancing warning
1415
warnings.filterwarnings(
@@ -22,10 +23,6 @@
2223
@pytest.fixture(scope="session")
2324
def fake_segy_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
2425
"""Make a temp file for the fake SEG-Y files we are going to create."""
25-
if DEBUG_MODE:
26-
tmp_dir = Path("tmp/fake_segy")
27-
tmp_dir.mkdir(parents=True, exist_ok=True)
28-
return tmp_dir
2926
return tmp_path_factory.mktemp(r"fake_segy")
3027

3128

@@ -38,11 +35,7 @@ def segy_input_uri() -> str:
3835
@pytest.fixture(scope="session")
3936
def segy_input(segy_input_uri: str, tmp_path_factory: pytest.TempPathFactory) -> Path:
4037
"""Download teapot dome dataset for testing."""
41-
if DEBUG_MODE:
42-
tmp_dir = Path("tmp/segy")
43-
tmp_dir.mkdir(parents=True, exist_ok=True)
44-
else:
45-
tmp_dir = tmp_path_factory.mktemp("segy")
38+
tmp_dir = tmp_path_factory.mktemp("segy")
4639
tmp_file = tmp_dir / "teapot.segy"
4740
urlretrieve(segy_input_uri, tmp_file) # noqa: S310
4841
return tmp_file
@@ -51,25 +44,17 @@ def segy_input(segy_input_uri: str, tmp_path_factory: pytest.TempPathFactory) ->
5144
@pytest.fixture(scope="module")
5245
def zarr_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
5346
"""Make a temp file for the output MDIO."""
54-
if DEBUG_MODE:
55-
return Path("tmp/mdio")
5647
return tmp_path_factory.mktemp(r"mdio")
5748

5849

5950
@pytest.fixture(scope="module")
60-
def zarr_tmp2(tmp_path_factory: pytest.TempPathFactory) -> Path:
51+
def zarr_tmp2(tmp_path_factory: pytest.TempPathFactory) -> Path: # pragma: no cover - used by disabled test
6152
"""Make a temp file for the output MDIO."""
62-
if DEBUG_MODE:
63-
return Path("tmp/mdio2")
6453
return tmp_path_factory.mktemp(r"mdio2")
6554

6655

6756
@pytest.fixture(scope="session")
6857
def segy_export_tmp(tmp_path_factory: pytest.TempPathFactory) -> Path:
6958
"""Make a temp file for the round-trip IBM SEG-Y."""
70-
if DEBUG_MODE:
71-
tmp_dir = Path("tmp/segy")
72-
tmp_dir.mkdir(parents=True, exist_ok=True)
73-
else:
74-
tmp_dir = tmp_path_factory.mktemp("segy")
59+
tmp_dir = tmp_path_factory.mktemp("segy")
7560
return tmp_dir / "teapot_roundtrip.segy"

tests/integration/test_import_streamer_grid_overrides.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
@pytest.mark.skip(reason="NonBinned and HasDuplicates haven't been properly implemented yet.")
3434
@pytest.mark.parametrize("grid_override", [{"NonBinned": True}, {"HasDuplicates": True}])
3535
@pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.C])
36-
class TestImport4DNonReg:
36+
class TestImport4DNonReg: # pragma: no cover - tests is skipped
3737
"""Test for 4D segy import with grid overrides."""
3838

3939
def test_import_4d_segy( # noqa: PLR0913
@@ -161,7 +161,7 @@ def test_import_4d_segy( # noqa: PLR0913
161161
@pytest.mark.skip(reason="AutoShotWrap requires a template that is not implemented yet.")
162162
@pytest.mark.parametrize("grid_override", [{"AutoChannelWrap": True}, {"AutoShotWrap": True}, None])
163163
@pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.A, StreamerShotGeometryType.B])
164-
class TestImport6D:
164+
class TestImport6D: # pragma: no cover - tests is skipped
165165
"""Test for 6D segy import with grid overrides."""
166166

167167
def test_import_6d_segy( # noqa: PLR0913

tests/integration/test_segy_import_export_masked.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from segy.schema import HeaderField
2222
from segy.schema import SegySpec
2323
from segy.standards import get_segy_standard
24-
from tests.conftest import DEBUG_MODE
2524

2625
from mdio import mdio_to_segy
2726
from mdio.api.io import open_mdio
@@ -241,17 +240,6 @@ def mock_nd_segy(path: str, grid_conf: GridConfig, segy_factory_conf: SegyFactor
241240
headers["group_coord_y"] = (y_origin + dim_grids["shot_point"] * y_step + dim_grids[cable_key] * y_step).ravel()
242241
headers["gun"] = np.tile((1, 2, 3), num_traces // 3)
243242

244-
# for field in ["cdp_x", "source_coord_x", "group_coord_x"]:
245-
# start = 700000
246-
# step = 100
247-
# stop = start + step * (trace_numbers.size - 0)
248-
# headers[field] = np.arange(start=start, stop=stop, step=step)
249-
# for field in ["cdp_y", "source_coord_y", "group_coord_y"]:
250-
# start = 4000000
251-
# step = 100
252-
# stop = start + step * (trace_numbers.size - 0)
253-
# headers[field] = np.arange(start=start, stop=stop, step=step)
254-
255243
samples[:] = trace_numbers[..., None]
256244

257245
with fsspec.open(path, mode="wb") as fp:
@@ -289,8 +277,6 @@ def export_masked_path(tmp_path_factory: pytest.TempPathFactory, raw_headers_env
289277
raw_headers_enabled = os.getenv("MDIO__IMPORT__RAW_HEADERS") in ("1", "true", "yes", "on")
290278
path_suffix = "with_raw_headers" if raw_headers_enabled else "without_raw_headers"
291279

292-
if DEBUG_MODE:
293-
return Path(f"tmp/export_masked_{path_suffix}")
294280
return tmp_path_factory.getbasetemp() / f"export_masked_{path_suffix}"
295281

296282

tests/test_main.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def runner() -> CliRunner:
1919
# https://github.com/TGSAI/mdio-python/issues/646
2020
@pytest.mark.skip(reason="CLI hasn't been updated to work with v1 yet.")
2121
@pytest.mark.dependency
22-
def test_main_succeeds(runner: CliRunner, segy_input: Path, zarr_tmp: Path) -> None:
22+
def test_main_succeeds(
23+
runner: CliRunner, segy_input: Path, zarr_tmp: Path
24+
) -> None: # pragma: no cover - test is skipped
2325
"""It exits with a status code of zero."""
2426
cli_args = ["segy", "import", str(segy_input), str(zarr_tmp)]
2527
cli_args.extend(["--header-locations", "181,185"])
@@ -29,8 +31,11 @@ def test_main_succeeds(runner: CliRunner, segy_input: Path, zarr_tmp: Path) -> N
2931
assert result.exit_code == 0
3032

3133

34+
@pytest.mark.skip(reason="CLI hasn't been updated to work with v1 yet.")
3235
@pytest.mark.dependency(depends=["test_main_succeeds"])
33-
def test_main_cloud(runner: CliRunner, segy_input_uri: str, zarr_tmp: Path) -> None:
36+
def test_main_cloud(
37+
runner: CliRunner, segy_input_uri: str, zarr_tmp: Path
38+
) -> None: # pragma: no cover - tests is skipped
3439
"""It exits with a status code of zero."""
3540
os.environ["MDIO__IMPORT__CLOUD_NATIVE"] = "true"
3641
cli_args = ["segy", "import", segy_input_uri, str(zarr_tmp)]
@@ -42,8 +47,9 @@ def test_main_cloud(runner: CliRunner, segy_input_uri: str, zarr_tmp: Path) -> N
4247
assert result.exit_code == 0
4348

4449

50+
@pytest.mark.skip(reason="CLI hasn't been updated to work with v1 yet.")
4551
@pytest.mark.dependency(depends=["test_main_succeeds"])
46-
def test_main_info_succeeds(runner: CliRunner, zarr_tmp: Path) -> None:
52+
def test_main_info_succeeds(runner: CliRunner, zarr_tmp: Path) -> None: # pragma: no cover - tests is skipped
4753
"""It exits with a status code of zero."""
4854
cli_args = ["info"]
4955
cli_args.extend([str(zarr_tmp)])
@@ -52,8 +58,9 @@ def test_main_info_succeeds(runner: CliRunner, zarr_tmp: Path) -> None:
5258
assert result.exit_code == 0
5359

5460

61+
@pytest.mark.skip(reason="CLI hasn't been updated to work with v1 yet.")
5562
@pytest.mark.dependency(depends=["test_main_succeeds"])
56-
def test_main_copy(runner: CliRunner, zarr_tmp: Path, zarr_tmp2: Path) -> None:
63+
def test_main_copy(runner: CliRunner, zarr_tmp: Path, zarr_tmp2: Path) -> None: # pragma: no cover - tests is skipped
5764
"""It exits with a status code of zero."""
5865
cli_args = ["copy", str(zarr_tmp), str(zarr_tmp2), "-headers", "-traces"]
5966

tests/unit/conftest.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

tests/unit/v1/helpers.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Helper methods used in unit tests."""
22

3-
from pathlib import Path
4-
53
from mdio.builder.dataset_builder import MDIODatasetBuilder
64
from mdio.builder.dataset_builder import _BuilderState
75
from mdio.builder.dataset_builder import _get_named_dimension
@@ -66,7 +64,7 @@ def validate_variable(
6664
elif isinstance(container, Dataset):
6765
var_list = container.variables
6866
global_coord_list = _get_all_coordinates(container)
69-
else:
67+
else: # pragma: no cover
7068
err_msg = f"Expected MDIODatasetBuilder or Dataset, got {type(container)}"
7169
raise TypeError(err_msg)
7270

@@ -105,7 +103,7 @@ def _get_coordinate(
105103
in the global coordinate list.
106104
If the coordinate is stored as a Coordinate object, it is returned directly.
107105
"""
108-
if coordinates_or_references is None:
106+
if coordinates_or_references is None: # pragma: no cover
109107
return None
110108

111109
for c in coordinates_or_references:
@@ -115,7 +113,7 @@ def _get_coordinate(
115113
# Find the Coordinate in the global list and return it.
116114
if global_coord_list is not None:
117115
cc = next((cc for cc in global_coord_list if cc.name == name), None)
118-
if cc is None:
116+
if cc is None: # pragma: no cover
119117
msg = f"Pre-existing coordinate named {name!r} is not found"
120118
raise ValueError(msg)
121119
return cc
@@ -124,7 +122,7 @@ def _get_coordinate(
124122
# Return it.
125123
return c
126124

127-
return None
125+
return None # pragma: no cover
128126

129127

130128
def _get_all_coordinates(dataset: Dataset) -> list[Coordinate]:
@@ -138,19 +136,6 @@ def _get_all_coordinates(dataset: Dataset) -> list[Coordinate]:
138136
return list(all_coords.values())
139137

140138

141-
def output_path(file_dir: Path, file_name: str, debugging: bool = False) -> Path:
142-
"""Generate the output path for the test file-system output.
143-
144-
Note:
145-
Use debugging=True, if you need to retain the created files for debugging
146-
purposes. Otherwise, the files will be created in-memory and not saved to disk.
147-
"""
148-
if debugging:
149-
return file_dir / f"mdio-tests/{file_name}.zarr"
150-
151-
return file_dir / f"{file_name}.zarr"
152-
153-
154139
def make_seismic_poststack_3d_acceptance_dataset(dataset_name: str) -> Dataset:
155140
"""Create in-memory Seismic PostStack 3D Acceptance dataset."""
156141
ds = MDIODatasetBuilder(

tests/unit/v1/templates/test_template_registry.py

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,7 @@ def _name(self) -> str:
4646
return self.template_name
4747

4848
def _load_dataset_attributes(self) -> None:
49-
return None # Mock implementation
50-
51-
def create_dataset(self) -> str:
52-
"""Create a mock dataset.
53-
54-
Returns:
55-
str: A message indicating the dataset creation.
56-
"""
57-
return f"Mock dataset created by {self.template_name}"
49+
return None # pragma: no cover - Mock implementation
5850

5951

6052
def _assert_default_templates(template_names: list[str]) -> None:
@@ -92,12 +84,9 @@ def test_singleton_thread_safety(self) -> None:
9284
errors = []
9385

9486
def create_instance() -> None:
95-
try:
96-
instance = TemplateRegistry()
97-
instances.append(instance)
98-
time.sleep(0.001) # Small delay to increase contention
99-
except Exception as e:
100-
errors.append(e)
87+
instance = TemplateRegistry()
88+
instances.append(instance)
89+
time.sleep(0.001) # Small delay to increase contention
10190

10291
# Create multiple threads trying to create instances
10392
threads = [threading.Thread(target=create_instance) for _ in range(10)]
@@ -311,13 +300,10 @@ def test_concurrent_registration(self) -> None:
311300
errors = []
312301

313302
def register_template_worker(template_id: int) -> None:
314-
try:
315-
template = MockDatasetTemplate(f"template_{template_id}")
316-
name = registry.register(template)
317-
results.append((template_id, name))
318-
time.sleep(0.001) # Small delay
319-
except Exception as e:
320-
errors.append((template_id, e))
303+
template = MockDatasetTemplate(f"template_{template_id}")
304+
name = registry.register(template)
305+
results.append((template_id, name))
306+
time.sleep(0.001) # Small delay
321307

322308
# Create multiple threads registering different templates
323309
threads = [threading.Thread(target=register_template_worker, args=(i,)) for i in range(10)]
@@ -351,28 +337,24 @@ def test_concurrent_access_mixed_operations(self) -> None:
351337
errors = []
352338

353339
def mixed_operations_worker(worker_id: int) -> None:
354-
try:
355-
operations_results = []
356-
357-
# Get existing template
358-
if worker_id % 2 == 0:
359-
template = registry.get("initial_0")
360-
operations_results.append(("get", template.template_name))
340+
operations_results = []
361341

362-
# Register new template
363-
if worker_id % 3 == 0:
364-
new_template = MockDatasetTemplate(f"worker_{worker_id}")
365-
name = registry.register(new_template)
366-
operations_results.append(("register", name))
342+
# Get existing template
343+
if worker_id % 2 == 0:
344+
template = registry.get("initial_0")
345+
operations_results.append(("get", template.template_name))
367346

368-
# List templates
369-
templates = registry.list_all_templates()
370-
operations_results.append(("list", len(templates)))
347+
# Register new template
348+
if worker_id % 3 == 0:
349+
new_template = MockDatasetTemplate(f"worker_{worker_id}")
350+
name = registry.register(new_template)
351+
operations_results.append(("register", name))
371352

372-
results.append((worker_id, operations_results))
353+
# List templates
354+
templates = registry.list_all_templates()
355+
operations_results.append(("list", len(templates)))
373356

374-
except Exception as e:
375-
errors.append((worker_id, e))
357+
results.append((worker_id, operations_results))
376358

377359
# Run concurrent operations
378360
threads = [threading.Thread(target=mixed_operations_worker, args=(i,)) for i in range(15)]

0 commit comments

Comments
 (0)