|
9 | 9 | from typing import Any |
10 | 10 | from typing import Literal |
11 | 11 |
|
| 12 | +import dask.array as da |
12 | 13 | import numpy as np |
13 | 14 | from xarray import Dataset |
14 | 15 |
|
|
44 | 45 | """ |
45 | 46 |
|
46 | 47 |
|
| 48 | +def chunksize(data: da.Array | np.ndarray, i: int = 0) -> int: |
| 49 | + """ |
| 50 | +    Returns the chunk size of the given data along a given dimension. |
| 51 | +
|
| 52 | +    @param data The data. |
| 53 | +    @param i The dimension index. |
| 54 | +    @return The chunk size along the given dimension, or the size of |
| 55 | +    that dimension if the data are not chunked. |
| 56 | + """ |
| 57 | + return data.chunksize[i] if hasattr(data, "chunksize") else data.shape[i] |
| 58 | + |
| 59 | + |
47 | 60 | class Writer(Writing): |
48 | 61 | """! The target dataset writer.""" |
49 | 62 |
|
@@ -108,28 +121,29 @@ def _encode(self, dataset: Dataset, to_zarr: bool = True): |
108 | 121 | """This method does not belong to public API.""" |
109 | 122 | encodings: dict[str, dict[str, Any]] = {} |
110 | 123 |
|
111 | | - for name, array in dataset.data_vars.items(): |
| 124 | + for name, array in dataset.variables.items(): |
112 | 125 | data = array.data |
113 | 126 | dims: list = list(array.dims) |
114 | 127 | if array.ndim == 0: # not an array |
115 | 128 | continue |
116 | 129 | if name in dims: # a coordinate dimension |
117 | | - continue |
118 | | - chunks: list[int] = [] |
119 | | - for i, dim in enumerate(dims): |
120 | | - if dim in self._chunks: |
121 | | - chunk_size = self._chunks[dim] |
122 | | - assert isinstance(chunk_size, int), ( |
123 | | - f"Invalid chunk size specified for " |
124 | | - f"dimension '{dim}'" |
125 | | - ) |
126 | | - if chunk_size == -1: |
127 | | - chunk_size = data.shape[i] |
128 | | - if chunk_size == 0: |
129 | | - chunk_size = data.chunksize[i] |
130 | | - chunks.append(chunk_size) |
131 | | - else: |
132 | | - chunks.append(data.chunksize[i]) |
| 130 | + chunks: list[int] = [chunksize(data)] |
| 131 | + else: |
| 132 | + chunks: list[int] = [] |
| 133 | + for i, dim in enumerate(dims): |
| 134 | + if dim in self._chunks: |
| 135 | + chunk_size = self._chunks[dim] |
| 136 | + assert isinstance(chunk_size, int), ( |
| 137 | + f"Invalid chunk size specified for " |
| 138 | + f"dimension '{dim}'" |
| 139 | + ) |
| 140 | + if chunk_size == -1: |
| 141 | + chunk_size = data.shape[i] |
| 142 | + if chunk_size == 0: |
| 143 | + chunk_size = chunksize(data, i) |
| 144 | + chunks.append(chunk_size) |
| 145 | + else: |
| 146 | + chunks.append(chunksize(data, i)) |
133 | 147 | encodings[name] = self._encode_compress( |
134 | 148 | data.dtype, chunks, to_zarr |
135 | 149 | ) |
|
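A minimal usage sketch (not part of the commit, assuming dask and numpy are installed) of the new chunksize() helper introduced above: chunked dask arrays report their chunk size along the requested axis, while plain numpy arrays fall back to the corresponding shape. In _encode(), a configured chunk size of -1 resolves to the full dimension length and 0 falls back to the data's existing chunking.

import dask.array as da
import numpy as np

chunked = da.zeros((100, 50), chunks=(25, 50))   # dask array with 25-row chunks
plain = np.zeros((100, 50))                      # plain numpy array, no chunking

assert chunksize(chunked, 0) == 25   # chunk size along axis 0
assert chunksize(chunked, 1) == 50   # chunk size along axis 1
assert chunksize(plain, 0) == 100    # numpy array: falls back to shape[0]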