Skip to content

Commit 4e4e101

Browse files
committed
Fix: chunking of coordinate dimensions
1 parent e99efae commit 4e4e101

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

kaleidoscope/writer.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Any
1010
from typing import Literal
1111

12+
import dask.array as da
1213
import numpy as np
1314
from xarray import Dataset
1415

@@ -44,6 +45,18 @@
4445
"""
4546

4647

48+
def chunksize(data: "da.Array | np.ndarray", i: int = 0) -> int:
    """
    Returns the chunk size of given data along a given dimension.

    @param data The data, either chunked (exposing a `chunksize`
    attribute, like a Dask array) or not chunked (like a NumPy array).
    @param i The dimension enumerator.
    @return The chunk size along the given dimension or the shape along
    the given dimension, if the data are not chunked.
    """
    # Data without a `chunksize` attribute are treated as a single chunk
    # that spans the whole dimension.
    return data.chunksize[i] if hasattr(data, "chunksize") else data.shape[i]
58+
59+
4760
class Writer(Writing):
4861
"""! The target dataset writer."""
4962

@@ -108,28 +121,29 @@ def _encode(self, dataset: Dataset, to_zarr: bool = True):
108121
"""This method does not belong to public API."""
109122
encodings: dict[str, dict[str, Any]] = {}
110123

111-
for name, array in dataset.data_vars.items():
124+
for name, array in dataset.variables.items():
112125
data = array.data
113126
dims: list = list(array.dims)
114127
if array.ndim == 0: # not an array
115128
continue
116129
if name in dims: # a coordinate dimension
117-
continue
118-
chunks: list[int] = []
119-
for i, dim in enumerate(dims):
120-
if dim in self._chunks:
121-
chunk_size = self._chunks[dim]
122-
assert isinstance(chunk_size, int), (
123-
f"Invalid chunk size specified for "
124-
f"dimension '{dim}'"
125-
)
126-
if chunk_size == -1:
127-
chunk_size = data.shape[i]
128-
if chunk_size == 0:
129-
chunk_size = data.chunksize[i]
130-
chunks.append(chunk_size)
131-
else:
132-
chunks.append(data.chunksize[i])
130+
chunks: list[int] = [chunksize(data)]
131+
else:
132+
chunks: list[int] = []
133+
for i, dim in enumerate(dims):
134+
if dim in self._chunks:
135+
chunk_size = self._chunks[dim]
136+
assert isinstance(chunk_size, int), (
137+
f"Invalid chunk size specified for "
138+
f"dimension '{dim}'"
139+
)
140+
if chunk_size == -1:
141+
chunk_size = data.shape[i]
142+
if chunk_size == 0:
143+
chunk_size = chunksize(data, i)
144+
chunks.append(chunk_size)
145+
else:
146+
chunks.append(chunksize(data, i))
133147
encodings[name] = self._encode_compress(
134148
data.dtype, chunks, to_zarr
135149
)

0 commit comments

Comments
 (0)