22
33from __future__ import annotations
44
5+ from typing import TYPE_CHECKING
6+
57import logging
68
7- import dask .array as da
89import numpy as np
910import numpy .typing as npt
1011import zarr
1819from mdio .core .exceptions import MDIONotFoundError
1920from mdio .exceptions import ShapeError
2021
22+ if TYPE_CHECKING :
23+ import dask .array as da
24+ from numpy .typing import NDArray
2125
2226logger = logging .getLogger (__name__ )
2327
@@ -181,7 +185,7 @@ def __init__(
181185 self ._set_attributes ()
182186 self ._open_arrays ()
183187
184- def _validate_store (self , storage_options ) :
188+ def _validate_store (self , storage_options : dict [ str , str ] | None ) -> None :
185189 """Method to validate the provided store."""
186190 if storage_options is None :
187191 storage_options = {}
@@ -194,7 +198,7 @@ def _validate_store(self, storage_options):
194198 disk_cache = self ._disk_cache ,
195199 )
196200
197- def _connect (self ):
201+ def _connect (self ) -> None :
198202 """Open the zarr root."""
199203 try :
200204 if self .mode in {"r" , "r+" }:
@@ -212,11 +216,11 @@ def _connect(self):
212216 )
213217 raise MDIONotFoundError (msg ) from e
214218
215- def _deserialize_grid (self ):
219+ def _deserialize_grid (self ) -> None :
216220 """Deserialize grid from Zarr metadata."""
217221 self .grid = Grid .from_zarr (self .root )
218222
219- def _set_attributes (self ):
223+ def _set_attributes (self ) -> None :
220224 """Deserialize attributes from Zarr metadata."""
221225 self .trace_count = self .root .attrs ["trace_count" ]
222226 self .stats = {
@@ -231,7 +235,7 @@ def _set_attributes(self):
231235 self .n_dim = len (self .shape )
232236
233237 # Access pattern attributes
234- data_array_name = "_" . join ([ "chunked" , self .access_pattern ])
238+ data_array_name = f"chunked_ { self .access_pattern } "
235239 self .chunks = self ._data_group [data_array_name ].chunks
236240 self ._orig_chunks = self .chunks
237241
@@ -251,15 +255,15 @@ def _set_attributes(self):
251255 self ._orig_chunks = self .chunks
252256 self .chunks = new_chunks
253257
254- def _open_arrays (self ):
258+ def _open_arrays (self ) -> None :
255259 """Open arrays with requested backend."""
256- data_array_name = "_" . join ([ "chunked" , self .access_pattern ])
257- header_array_name = "_" . join ([ "chunked" , self .access_pattern , "trace_headers" ])
260+ data_array_name = f"chunked_ { self .access_pattern } "
261+ header_array_name = f"chunked_ { self .access_pattern } _trace_headers"
258262
259- trace_kwargs = dict (
260- group_handle = self ._data_group ,
261- name = data_array_name ,
262- )
263+ trace_kwargs = {
264+ " group_handle" : self ._data_group ,
265+ " name" : data_array_name ,
266+ }
263267
264268 if self ._backend == "dask" :
265269 trace_kwargs ["chunks" ] = self .chunks
@@ -271,10 +275,10 @@ def _open_arrays(self):
271275 logger .info (f"Setting MDIO in-memory chunks to { dask_chunks } " )
272276 self .chunks = dask_chunks
273277
274- header_kwargs = dict (
275- group_handle = self ._metadata_group ,
276- name = header_array_name ,
277- )
278+ header_kwargs = {
279+ " group_handle" : self ._metadata_group ,
280+ " name" : header_array_name ,
281+ }
278282
279283 if self ._backend == "dask" :
280284 header_kwargs ["chunks" ] = self .chunks [:- 1 ]
@@ -406,7 +410,7 @@ def __setitem__(self, key: int | tuple, value: npt.ArrayLike) -> None:
406410
407411 def coord_to_index (
408412 self ,
409- * args ,
413+ * args : int | list [ int ] ,
410414 dimensions : str | list [str ] | None = None ,
411415 ) -> tuple [NDArray [int ], ...]:
412416 """Convert dimension coordinate to zero-based index.
@@ -437,6 +441,7 @@ def coord_to_index(
437441 to indicies of that dimension
438442
439443 Raises:
444+ KeyError: if a requested dimension doesn't exist.
440445 ShapeError: if number of queries don't match requested dimensions.
441446 ValueError: if requested coordinates don't exist.
442447
@@ -490,10 +495,15 @@ def coord_to_index(
490495 if dimensions is None :
491496 dims = self .grid .dims
492497 else :
493- dims = [self .grid .select_dim (dim_name ) for dim_name in dimensions ]
494-
495- dim_indices = tuple ()
496- for mdio_dim , dim_query_coords in zip (dims , queries ): # noqa: B905
498+ for query_dim in dimensions :
499+ try :
500+ dims .append (self .grid .select_dim (query_dim ))
501+ except ValueError as err :
502+ msg = f"Requested dimension { query_dim } does not exist."
503+ raise KeyError (msg ) from err
504+
505+ dim_indices = ()
506+ for mdio_dim , dim_query_coords in zip (dims , queries ):
497507 # Make sure all coordinates exist.
498508 query_diff = np .setdiff1d (dim_query_coords , mdio_dim .coords )
499509 if len (query_diff ) > 0 :
@@ -510,14 +520,14 @@ def coord_to_index(
510520
511521 return dim_indices if len (dim_indices ) > 1 else dim_indices [0 ]
512522
513- def copy (
523+ def copy ( # noqa: PLR0913
514524 self ,
515525 dest_path_or_buffer : str ,
516526 excludes : str = "" ,
517527 includes : str = "" ,
518528 storage_options : dict | None = None ,
519529 overwrite : bool = False ,
520- ):
530+ ) -> None :
521531 """Makes a copy of an MDIO file with or without all arrays.
522532
523533 Refer to mdio.api.convenience.copy for full documentation.
@@ -576,7 +586,7 @@ class MDIOReader(MDIOAccessor):
576586 `fsspec` documentation for more details.
577587 """
578588
579- def __init__ (
589+ def __init__ ( # noqa: PLR0913
580590 self ,
581591 mdio_path_or_buffer : str ,
582592 access_pattern : str = "012" ,
@@ -632,7 +642,7 @@ class MDIOWriter(MDIOAccessor):
632642 `fsspec` documentation for more details.
633643 """
634644
635- def __init__ (
645+ def __init__ ( # noqa: PLR0913
636646 self ,
637647 mdio_path_or_buffer : str ,
638648 access_pattern : str = "012" ,
0 commit comments