Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions icechunk-python/python/icechunk/dask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
from collections.abc import Callable, Mapping
from typing import Any, ParamSpec, TypeAlias, TypeVar
from typing import TYPE_CHECKING, Any, ParamSpec, TypeAlias, TypeVar

import numpy as np
from packaging.version import Version
Expand All @@ -13,6 +13,12 @@
from icechunk.distributed import extract_session, merge_sessions
from icechunk.session import ForkSession

if TYPE_CHECKING:
try:
from zarr.core.metadata import ArrayV3Metadata
except ImportError:
ArrayV3Metadata = Any # type: ignore[misc,assignment]

SimpleGraph: TypeAlias = Mapping[tuple[str, int], tuple[Any, ...]]


Expand Down Expand Up @@ -57,7 +63,7 @@ def _assert_correct_dask_version() -> None:
def store_dask(
*,
sources: list[Array],
targets: list[zarr.Array],
targets: "list[zarr.Array[ArrayV3Metadata]]",
regions: list[tuple[slice, ...]] | None = None,
split_every: int | None = None,
**store_kwargs: Any,
Expand Down
10 changes: 8 additions & 2 deletions icechunk-python/python/icechunk/distributed.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# distributed utility functions
from collections.abc import Generator, Iterable
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

import zarr
from icechunk import IcechunkStore
from icechunk.session import ForkSession, Session

if TYPE_CHECKING:
try:
from zarr.core.metadata import ArrayV3Metadata
except ImportError:
ArrayV3Metadata = Any # type: ignore[misc,assignment]

__all__ = [
"extract_session",
"merge_sessions",
Expand All @@ -25,7 +31,7 @@ def _flatten(seq: Iterable[Any], container: type = list) -> Generator[Any, None,


def extract_session(
zarray: zarr.Array, axis: Any = None, keepdims: Any = None
zarray: "zarr.Array[ArrayV3Metadata]", axis: Any = None, keepdims: Any = None
) -> Session:
"""
Extract Icechunk Session from a zarr Array, useful for distributed array computing frameworks.
Expand Down
13 changes: 9 additions & 4 deletions icechunk-python/python/icechunk/testing/strategies.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from collections.abc import Iterable
from typing import cast
from typing import TYPE_CHECKING, Any

import hypothesis.strategies as st

import icechunk as ic
import zarr
from zarr.core.metadata import ArrayV3Metadata

if TYPE_CHECKING:
try:
from zarr.core.metadata import ArrayV3Metadata
except ImportError:
ArrayV3Metadata = Any # type: ignore[misc,assignment]


@st.composite
def splitting_configs(
draw: st.DrawFn, *, arrays: Iterable[zarr.Array]
draw: st.DrawFn, *, arrays: "Iterable[zarr.Array[ArrayV3Metadata]]"
) -> ic.ManifestSplittingConfig:
config_dict: dict[
ic.ManifestSplitCondition,
Expand All @@ -29,7 +34,7 @@ def splitting_configs(
else:
array_condition = ic.ManifestSplitCondition.path_matches(array.path)
dimnames = (
cast(ArrayV3Metadata, array.metadata).dimension_names or (None,) * array.ndim
getattr(array.metadata, "dimension_names", None) or (None,) * array.ndim
)
dimsize_axis_names = draw(
st.lists(
Expand Down
7 changes: 6 additions & 1 deletion icechunk-python/python/icechunk/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
from xarray.backends.common import ArrayWriter
from xarray.backends.zarr import ZarrStore

try:
from zarr.core.metadata import ArrayV3Metadata
except ImportError:
ArrayV3Metadata = Any # type: ignore[misc,assignment]

__all__ = ["to_icechunk"]

Region = Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None
Expand Down Expand Up @@ -57,7 +62,7 @@ def __init__(self) -> None:
super().__init__() # type: ignore[no-untyped-call]

self.eager_sources: list[np.ndarray[Any, Any]] = []
self.eager_targets: list[zarr.Array] = []
self.eager_targets: list[zarr.Array[ArrayV3Metadata]] = []
self.eager_regions: list[tuple[slice, ...]] = []

def add(self, source: Any, target: Any, region: Any = None) -> Any:
Expand Down
Loading