Skip to content

Commit c6d2312

Browse files
committed
Manual reconciliation of merging v1_backup into v1 branch
1 parent 141e42b commit c6d2312

29 files changed

+361
-290
lines changed

src/mdio/__main__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from importlib import metadata
88
from pathlib import Path
99

10+
from typing import Any
11+
from typing import Callable
12+
1013
import click
1114

1215

@@ -30,13 +33,13 @@ class MyCLI(click.MultiCommand):
3033
- plugin_folder: Path to the directory containing command modules.
3134
"""
3235

33-
def __init__(self, plugin_folder: Path, *args, **kwargs):
36+
def __init__(self, plugin_folder: Path, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
3437
"""Initializer function."""
3538
super().__init__(*args, **kwargs)
3639
self.plugin_folder = plugin_folder
3740
self.known_modules = KNOWN_MODULES
3841

39-
def list_commands(self, ctx: click.Context) -> list[str]:
42+
def list_commands(self, _ctx: click.Context) -> list[str]:
4043
"""List commands available under `commands` module."""
4144
rv = []
4245
for filename in self.plugin_folder.iterdir():
@@ -47,7 +50,7 @@ def list_commands(self, ctx: click.Context) -> list[str]:
4750
rv.sort()
4851
return rv
4952

50-
def get_command(self, ctx: click.Context, name: str) -> Callable | None:
53+
def get_command(self, _ctx: click.Context, name: str) -> Callable | None:
5154
"""Get command implementation from `commands` module."""
5255
try:
5356
filepath = self.plugin_folder / f"{name}.py"

src/mdio/api/accessor.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from __future__ import annotations
44

5+
from typing import TYPE_CHECKING
6+
57
import logging
68

7-
import dask.array as da
89
import numpy as np
910
import numpy.typing as npt
1011
import zarr
@@ -18,6 +19,9 @@
1819
from mdio.core.exceptions import MDIONotFoundError
1920
from mdio.exceptions import ShapeError
2021

22+
if TYPE_CHECKING:
23+
import dask.array as da
24+
from numpy.typing import NDArray
2125

2226
logger = logging.getLogger(__name__)
2327

@@ -181,7 +185,7 @@ def __init__(
181185
self._set_attributes()
182186
self._open_arrays()
183187

184-
def _validate_store(self, storage_options):
188+
def _validate_store(self, storage_options: dict[str, str] | None) -> None:
185189
"""Method to validate the provided store."""
186190
if storage_options is None:
187191
storage_options = {}
@@ -194,7 +198,7 @@ def _validate_store(self, storage_options):
194198
disk_cache=self._disk_cache,
195199
)
196200

197-
def _connect(self):
201+
def _connect(self) -> None:
198202
"""Open the zarr root."""
199203
try:
200204
if self.mode in {"r", "r+"}:
@@ -212,11 +216,11 @@ def _connect(self):
212216
)
213217
raise MDIONotFoundError(msg) from e
214218

215-
def _deserialize_grid(self):
219+
def _deserialize_grid(self) -> None:
216220
"""Deserialize grid from Zarr metadata."""
217221
self.grid = Grid.from_zarr(self.root)
218222

219-
def _set_attributes(self):
223+
def _set_attributes(self) -> None:
220224
"""Deserialize attributes from Zarr metadata."""
221225
self.trace_count = self.root.attrs["trace_count"]
222226
self.stats = {
@@ -231,7 +235,7 @@ def _set_attributes(self):
231235
self.n_dim = len(self.shape)
232236

233237
# Access pattern attributes
234-
data_array_name = "_".join(["chunked", self.access_pattern])
238+
data_array_name = f"chunked_{self.access_pattern}"
235239
self.chunks = self._data_group[data_array_name].chunks
236240
self._orig_chunks = self.chunks
237241

@@ -251,15 +255,15 @@ def _set_attributes(self):
251255
self._orig_chunks = self.chunks
252256
self.chunks = new_chunks
253257

254-
def _open_arrays(self):
258+
def _open_arrays(self) -> None:
255259
"""Open arrays with requested backend."""
256-
data_array_name = "_".join(["chunked", self.access_pattern])
257-
header_array_name = "_".join(["chunked", self.access_pattern, "trace_headers"])
260+
data_array_name = f"chunked_{self.access_pattern}"
261+
header_array_name = f"chunked_{self.access_pattern}_trace_headers"
258262

259-
trace_kwargs = dict(
260-
group_handle=self._data_group,
261-
name=data_array_name,
262-
)
263+
trace_kwargs = {
264+
"group_handle": self._data_group,
265+
"name": data_array_name,
266+
}
263267

264268
if self._backend == "dask":
265269
trace_kwargs["chunks"] = self.chunks
@@ -271,10 +275,10 @@ def _open_arrays(self):
271275
logger.info(f"Setting MDIO in-memory chunks to {dask_chunks}")
272276
self.chunks = dask_chunks
273277

274-
header_kwargs = dict(
275-
group_handle=self._metadata_group,
276-
name=header_array_name,
277-
)
278+
header_kwargs = {
279+
"group_handle": self._metadata_group,
280+
"name": header_array_name,
281+
}
278282

279283
if self._backend == "dask":
280284
header_kwargs["chunks"] = self.chunks[:-1]
@@ -406,7 +410,7 @@ def __setitem__(self, key: int | tuple, value: npt.ArrayLike) -> None:
406410

407411
def coord_to_index(
408412
self,
409-
*args,
413+
*args: int | list[int],
410414
dimensions: str | list[str] | None = None,
411415
) -> tuple[NDArray[int], ...]:
412416
"""Convert dimension coordinate to zero-based index.
@@ -437,6 +441,7 @@ def coord_to_index(
437441
to indices of that dimension
438442
439443
Raises:
444+
KeyError: if a requested dimension doesn't exist.
440445
ShapeError: if the number of queries doesn't match the requested dimensions.
441446
ValueError: if requested coordinates don't exist.
442447
@@ -490,10 +495,15 @@ def coord_to_index(
490495
if dimensions is None:
491496
dims = self.grid.dims
492497
else:
493-
dims = [self.grid.select_dim(dim_name) for dim_name in dimensions]
494-
495-
dim_indices = tuple()
496-
for mdio_dim, dim_query_coords in zip(dims, queries): # noqa: B905
498+
for query_dim in dimensions:
499+
try:
500+
dims.append(self.grid.select_dim(query_dim))
501+
except ValueError as err:
502+
msg = f"Requested dimension {query_dim} does not exist."
503+
raise KeyError(msg) from err
504+
505+
dim_indices = ()
506+
for mdio_dim, dim_query_coords in zip(dims, queries):
497507
# Make sure all coordinates exist.
498508
query_diff = np.setdiff1d(dim_query_coords, mdio_dim.coords)
499509
if len(query_diff) > 0:
@@ -510,14 +520,14 @@ def coord_to_index(
510520

511521
return dim_indices if len(dim_indices) > 1 else dim_indices[0]
512522

513-
def copy(
523+
def copy( # noqa: PLR0913
514524
self,
515525
dest_path_or_buffer: str,
516526
excludes: str = "",
517527
includes: str = "",
518528
storage_options: dict | None = None,
519529
overwrite: bool = False,
520-
):
530+
) -> None:
521531
"""Makes a copy of an MDIO file with or without all arrays.
522532
523533
Refer to mdio.api.convenience.copy for full documentation.
@@ -576,7 +586,7 @@ class MDIOReader(MDIOAccessor):
576586
`fsspec` documentation for more details.
577587
"""
578588

579-
def __init__(
589+
def __init__( # noqa: PLR0913
580590
self,
581591
mdio_path_or_buffer: str,
582592
access_pattern: str = "012",
@@ -632,7 +642,7 @@ class MDIOWriter(MDIOAccessor):
632642
`fsspec` documentation for more details.
633643
"""
634644

635-
def __init__(
645+
def __init__( # noqa: PLR0913
636646
self,
637647
mdio_path_or_buffer: str,
638648
access_pattern: str = "012",

src/mdio/api/convenience.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414

1515
if TYPE_CHECKING:
16+
from pathlib import Path
17+
from typing import Any
18+
1619
from numcodecs.abc import Codec
1720
from numpy.typing import NDArray
1821
from zarr import Array
@@ -23,10 +26,10 @@
2326

2427
def copy_mdio( # noqa: PLR0913
2528
source: MDIOReader,
26-
dest_path_or_buffer: str,
29+
dest_path_or_buffer: str | Path,
2730
excludes: str = "",
2831
includes: str = "",
29-
storage_options: dict | None = None,
32+
storage_options: dict[str, Any] | None = None,
3033
overwrite: bool = False,
3134
) -> None:
3235
"""Copy MDIO file.
@@ -135,7 +138,8 @@ def create_rechunk_plan(
135138

136139
for chunks, suffix in zip(chunks_list, suffix_list): # noqa: B905
137140
norm_chunks = [
138-
min(chunk, size) for chunk, size in zip(chunks, source.shape) # noqa: B905
141+
min(chunk, size)
142+
for chunk, size in zip(chunks, source.shape) # noqa: B905
139143
]
140144

141145
if suffix == source.access_pattern:

src/mdio/api/io_utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@
22

33
from __future__ import annotations
44

5+
from typing import TYPE_CHECKING
56
from typing import Any
67

78
import dask.array as da
89
import zarr
910
from zarr.storage import FSStore
1011

12+
if TYPE_CHECKING:
13+
from pathlib import Path
14+
1115

1216
def process_url(
13-
url: str,
14-
mode: str,
17+
url: str | Path,
18+
mode: str | Path,
1519
storage_options: dict[str, Any],
1620
memory_cache_size: int,
1721
disk_cache: bool,
@@ -82,20 +86,20 @@ def process_url(
8286
... )
8387
"""
8488
if disk_cache is True:
85-
url = "::".join(["simplecache", url])
89+
url = f"simplecache::{url}"
8690

8791
# Strip whitespaces and slashes from end of string
88-
url = url.rstrip("/ ")
92+
url_str = str(url).rstrip("/ ")
8993

9094
# Flag for checking write access
91-
check = True if mode == "w" else False
95+
# check = True if mode == "w" else False
9296

9397
# TODO: Turning off write checking now because zarr has a bug.
9498
# Get rid of this once bug is fixed.
9599
check = False
96100

97101
store = FSStore(
98-
url=url,
102+
url=url_str,
99103
check=check,
100104
create=check,
101105
mode=mode,
@@ -123,7 +127,9 @@ def open_zarr_array(group_handle: zarr.Group, name: str) -> zarr.Array:
123127
return group_handle[name]
124128

125129

126-
def open_zarr_array_dask(group_handle: zarr.Group, name: str, **kwargs) -> da.Array:
130+
def open_zarr_array_dask(
131+
group_handle: zarr.Group, name: str, **kwargs: dict[str, Any]
132+
) -> da.Array:
127133
"""Open Zarr array lazily using Dask.
128134
129135
Note: All other kwargs get passed to dask.array.from_zarr()

src/mdio/commands/copy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
help="Flag to overwrite if mdio file if it exists",
5656
show_default=True,
5757
)
58-
def copy(
58+
def copy( # noqa: PLR0913
5959
source_mdio_path: str,
6060
target_mdio_path: str,
6161
access_pattern: str = "012",

src/mdio/commands/segy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def segy_import( # noqa: PLR0913
393393
show_default=True,
394394
show_choices=True,
395395
)
396-
def segy_export(
396+
def segy_export( # noqa: PLR0913
397397
mdio_file: str,
398398
segy_path: str,
399399
access_pattern: str,

src/mdio/converters/exceptions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
class EnvironmentFormatError(Exception):
55
"""Raised when environment variable is of the wrong format."""
66

7-
def __init__(self, name, format, msg: str = ""):
7+
def __init__(self, name: str, format_: str, msg: str = "") -> None:
88
"""Initialize error."""
99
self.message = (
10-
f"Environment variable: {name} not of expected format: {format}. "
10+
f"Environment variable: {name} not of expected format: {format_}. "
1111
)
1212
self.message += f"\n{msg}" if msg else ""
1313
super().__init__(self.message)
@@ -16,7 +16,7 @@ def __init__(self, name, format, msg: str = ""):
1616
class GridTraceCountError(Exception):
1717
"""Raised when grid trace counts don't match the SEG-Y trace count."""
1818

19-
def __init__(self, grid_traces, segy_traces):
19+
def __init__(self, grid_traces: int, segy_traces: int) -> None:
2020
"""Initialize error."""
2121
self.message = (
2222
f"{grid_traces} != {segy_traces}"
@@ -32,7 +32,7 @@ def __init__(self, grid_traces, segy_traces):
3232
class GridTraceSparsityError(Exception):
3333
"""Raised when mdio grid will be sparsely populated from SEG-Y traces."""
3434

35-
def __init__(self, shape, num_traces, msg: str = ""):
35+
def __init__(self, shape: tuple[int, ...], num_traces: int, msg: str = "") -> None:
3636
"""Initialize error."""
3737
self.message = (
3838
f"Grid shape: {shape} but SEG-Y tracecount: {num_traces}. "

src/mdio/converters/mdio.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import os
66
from os import path
7+
from pathlib import Path
78
from tempfile import TemporaryDirectory
89

910
import numpy as np
@@ -28,8 +29,8 @@
2829

2930

3031
def mdio_to_segy( # noqa: C901
31-
mdio_path_or_buffer: str,
32-
output_segy_path: str,
32+
mdio_path_or_buffer: str | Path,
33+
output_segy_path: str | Path,
3334
endian: str = "big",
3435
access_pattern: str = "012",
3536
storage_options: dict = None,

0 commit comments

Comments
 (0)