diff --git a/icechunk-python/python/icechunk/dask.py b/icechunk-python/python/icechunk/dask.py index b4f24916c..3a802acbb 100644 --- a/icechunk-python/python/icechunk/dask.py +++ b/icechunk-python/python/icechunk/dask.py @@ -1,6 +1,6 @@ import functools from collections.abc import Callable, Mapping -from typing import Any, ParamSpec, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Any, ParamSpec, TypeAlias, TypeVar import numpy as np from packaging.version import Version @@ -13,6 +13,12 @@ from icechunk.distributed import extract_session, merge_sessions from icechunk.session import ForkSession +if TYPE_CHECKING: + try: + from zarr.core.metadata import ArrayV3Metadata + except ImportError: + ArrayV3Metadata = Any # type: ignore[misc,assignment] + SimpleGraph: TypeAlias = Mapping[tuple[str, int], tuple[Any, ...]] @@ -57,7 +63,7 @@ def _assert_correct_dask_version() -> None: def store_dask( *, sources: list[Array], - targets: list[zarr.Array], + targets: "list[zarr.Array[ArrayV3Metadata]]", regions: list[tuple[slice, ...]] | None = None, split_every: int | None = None, **store_kwargs: Any, diff --git a/icechunk-python/python/icechunk/distributed.py b/icechunk-python/python/icechunk/distributed.py index c1a2608e8..7613a103b 100644 --- a/icechunk-python/python/icechunk/distributed.py +++ b/icechunk-python/python/icechunk/distributed.py @@ -1,11 +1,17 @@ # distributed utility functions from collections.abc import Generator, Iterable -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import zarr from icechunk import IcechunkStore from icechunk.session import ForkSession, Session +if TYPE_CHECKING: + try: + from zarr.core.metadata import ArrayV3Metadata + except ImportError: + ArrayV3Metadata = Any # type: ignore[misc,assignment] + __all__ = [ "extract_session", "merge_sessions", @@ -25,7 +31,7 @@ def _flatten(seq: Iterable[Any], container: type = list) -> Generator[Any, None, def extract_session( - zarray: zarr.Array, axis: Any = None, keepdims: Any = None + zarray: "zarr.Array[ArrayV3Metadata]", axis: Any = None, keepdims: Any = None ) -> Session: """ Extract Icechunk Session from a zarr Array, useful for distributed array computing frameworks. diff --git a/icechunk-python/python/icechunk/testing/strategies.py b/icechunk-python/python/icechunk/testing/strategies.py index 435c374d2..adc865500 100644 --- a/icechunk-python/python/icechunk/testing/strategies.py +++ b/icechunk-python/python/icechunk/testing/strategies.py @@ -1,16 +1,21 @@ from collections.abc import Iterable -from typing import cast +from typing import TYPE_CHECKING, Any import hypothesis.strategies as st import icechunk as ic import zarr -from zarr.core.metadata import ArrayV3Metadata + +if TYPE_CHECKING: + try: + from zarr.core.metadata import ArrayV3Metadata + except ImportError: + ArrayV3Metadata = Any # type: ignore[misc,assignment] @st.composite def splitting_configs( - draw: st.DrawFn, *, arrays: Iterable[zarr.Array] + draw: st.DrawFn, *, arrays: "Iterable[zarr.Array[ArrayV3Metadata]]" ) -> ic.ManifestSplittingConfig: config_dict: dict[ ic.ManifestSplitCondition, @@ -29,7 +34,7 @@ def splitting_configs( else: array_condition = ic.ManifestSplitCondition.path_matches(array.path) dimnames = ( - cast(ArrayV3Metadata, array.metadata).dimension_names or (None,) * array.ndim + getattr(array.metadata, "dimension_names", None) or (None,) * array.ndim ) dimsize_axis_names = draw( st.lists( diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py index 6fd53e8ed..dbf39bfdd 100644 --- a/icechunk-python/python/icechunk/xarray.py +++ b/icechunk-python/python/icechunk/xarray.py @@ -15,6 +15,11 @@ from xarray.backends.common import ArrayWriter from xarray.backends.zarr import ZarrStore +try: + from zarr.core.metadata import ArrayV3Metadata +except ImportError: + ArrayV3Metadata = Any # type: ignore[misc,assignment] + __all__ = ["to_icechunk"] Region = Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None @@ -57,7 +62,7 @@ def __init__(self) -> None: super().__init__() # type: ignore[no-untyped-call] self.eager_sources: list[np.ndarray[Any, Any]] = [] - self.eager_targets: list[zarr.Array] = [] + self.eager_targets: list[zarr.Array[ArrayV3Metadata]] = [] self.eager_regions: list[tuple[slice, ...]] = [] def add(self, source: Any, target: Any, region: Any = None) -> Any: