Skip to content

Commit 3057be8

Browse files
committed
Fix: mask and scale problem for non-varied data
1 parent c109e30 commit 3057be8

File tree

3 files changed

+110
-31
lines changed

3 files changed

+110
-31
lines changed

kaleidoscope/algorithms/codec.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,83 @@ def decode(
9292
@override
9393
def name(self) -> str:
9494
return "decode"
95+
96+
97+
class Encode(BlockAlgorithm):
    """
    The algorithm to encode data according to CF conventions.

    Encoding removes the add-offset, divides by the scale factor,
    clips to the valid range, and substitutes the fill value for
    missing (NaN) elements of the original data.
    """

    def __init__(self, dtype: np.dtype, m: int):
        """
        Creates a new algorithm instance.

        :param dtype: The result data type.
        :param m: The number of input data dimensions.
        """
        super().__init__(dtype, m, m)

    @override
    def chunks(self, *inputs: da.Array) -> tuple[int, ...] | None:
        return None

    @property
    @override
    def created_axes(self) -> list[int] | None:
        return None

    @property
    @override
    def dropped_axes(self) -> list[int]:
        return []

    # noinspection PyMethodMayBeStatic
    def encode(
        self,
        x: np.ndarray,
        *,
        add_offset: Any = None,
        scale_factor: Any = None,
        fill_value: Any = None,
        valid_min: Any = None,
        valid_max: Any = None,
    ) -> np.ndarray:
        """
        Encodes data.

        :param x: The data.
        :param add_offset: The add-offset.
        :param scale_factor: The scale factor.
        :param fill_value: The fill value.
        :param valid_min: The valid minimum.
        :param valid_max: The valid maximum.
        :return: The encoded data.
        """
        attrs = (add_offset, scale_factor, fill_value, valid_min, valid_max)
        if all(attr is None for attr in attrs):
            # Nothing to encode: pass the data through unchanged.
            return x
        y = x.astype(np.double)
        if add_offset is not None:
            y = y - add_offset
        if scale_factor is not None:
            y = y / scale_factor
        # Clip to the valid range. NaN propagates through minimum and
        # maximum, so missing elements are unaffected by clipping.
        if valid_max is not None:
            y = np.minimum(y, valid_max)
        if valid_min is not None:
            y = np.maximum(y, valid_min)
        if fill_value is not None:
            # Missing elements are identified in the original data,
            # since NaN survives the arithmetic above.
            y[np.isnan(x)] = fill_value
        return y

    compute_block = encode

    @property
    @override
    def name(self) -> str:
        return "encode"

kaleidoscope/operators/randomizeop.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from xarray import Dataset
1818

1919
from ..algorithms.codec import Decode
20+
from ..algorithms.codec import Encode
2021
from ..algorithms.randomize import Randomize
2122
from ..generators import DefaultGenerator
2223
from ..interface.logging import Logging
@@ -36,10 +37,21 @@ def _hash(name: str) -> int:
3637
return h
3738

3839

39-
def _decode(
40-
x: da.Array, a: dict[str:Any], dtype: np.dtype = np.double
41-
) -> da.Array:
42-
f = Decode(dtype, x.ndim)
40+
def _decode(x: da.Array, a: dict[str, Any]) -> da.Array:
    """
    Decodes data according to CF conventions.

    :param x: The encoded data.
    :param a: The variable attributes carrying the CF encoding keys
        (``add_offset``, ``scale_factor``, ``_FillValue``,
        ``valid_min``, ``valid_max``).
    :return: The decoded data.
    """
    # Fixed the attribute-mapping annotation: ``dict[str:Any]`` is a
    # slice expression, not a type; the correct form is ``dict[str, Any]``.
    # Preserve single precision; decode everything else to double.
    f = Decode(np.single if x.dtype == np.single else np.double, x.ndim)
    y = f.apply_to(
        x,
        add_offset=a.get("add_offset"),
        scale_factor=a.get("scale_factor"),
        fill_value=a.get("_FillValue"),
        valid_min=a.get("valid_min"),
        valid_max=a.get("valid_max"),
    )
    return y
51+
52+
53+
def _encode(x: da.Array, a: dict[str:Any], dtype: np.dtype) -> da.Array:
54+
f = Encode(dtype, x.ndim)
4355
y = f.apply_to(
4456
x,
4557
add_offset=a.get("add_offset", None),
@@ -79,11 +91,7 @@ def run(self, source: Dataset) -> Dataset: # noqa: D102
7991
:return: The result dataset.
8092
"""
8193
source_id = source.attrs.get(
82-
"tracking_id",
83-
source.attrs.get(
84-
"uuid",
85-
f"{uuid.uuid5(uuid.NAMESPACE_URL, self._args.source_file.stem)}",
86-
),
94+
"tracking_id", source.attrs.get("uuid", f"{self.uuid}")
8795
)
8896
target: Dataset = Dataset(
8997
data_vars=source.data_vars,
@@ -130,7 +138,10 @@ def run(self, source: Dataset) -> Dataset: # noqa: D102
130138
clip=a.get("clip", None),
131139
)
132140
target[v] = DataArray(
133-
data=z, coords=x.coords, dims=x.dims, attrs=x.attrs
141+
data=_encode(z, x.attrs, x.dtype),
142+
coords=x.coords,
143+
dims=x.dims,
144+
attrs=x.attrs,
134145
)
135146
if "actual_range" in target[v].attrs:
136147
target[v].attrs["actual_range"] = np.array(
@@ -140,7 +151,6 @@ def run(self, source: Dataset) -> Dataset: # noqa: D102
140151
],
141152
dtype=z.dtype,
142153
)
143-
target[v].attrs["dtype"] = x.dtype
144154
target[v].attrs["entropy"] = np.array(s, dtype=np.int64)
145155
if get_logger().is_enabled(Logging.DEBUG):
146156
get_logger().debug(f"entropy: {s}")
@@ -162,6 +172,7 @@ def config(self) -> dict[str : dict[str:Any]]:
162172
config = json.load(r)
163173
return config
164174

175+
# noinspection PyShadowingNames
165176
def entropy(self, name: str, uuid: str, n: int = 4) -> list[int]:
166177
"""
167178
Returns the entropy of the seed sequence used for a given variable.
@@ -179,3 +190,10 @@ def entropy(self, name: str, uuid: str, n: int = 4) -> list[int]:
179190
seed = _hash(f"{name}-{uuid}") + self._args.selector
180191
g = DefaultGenerator(Philox(seed))
181192
return [g.next() for _ in range(n)]
193+
194+
@property
def uuid(self) -> uuid.UUID:
    """
    Returns a UUID constructed from the basename of the source file.

    The result is deterministic: the same source file stem always
    yields the same UUID (version 5, URL namespace).
    """
    stem = self._args.source_file.stem
    return uuid.uuid5(uuid.NAMESPACE_URL, stem)

kaleidoscope/writer.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
from typing import Literal
1111

1212
import numpy as np
13-
from dask.array import Array
1413
from typing_extensions import override
15-
from xarray import DataArray
1614
from xarray import Dataset
1715

1816
from .interface.writing import Writing
@@ -113,10 +111,7 @@ def _encode(self, dataset: Dataset, to_zarr: bool = True):
113111
encodings: dict[str, dict[str, Any]] = {}
114112

115113
for name, array in dataset.data_vars.items():
116-
dtype = array.attrs.pop("dtype", array.dtype)
117-
attrs = array.attrs
118114
data = array.data
119-
120115
dims: list = list(array.dims)
121116
if array.ndim == 0: # not an array
122117
continue
@@ -138,7 +133,7 @@ def _encode(self, dataset: Dataset, to_zarr: bool = True):
138133
else:
139134
chunks.append(data.chunksize[i])
140135
encodings[name] = self._encode_compress(
141-
dtype, attrs, chunks, to_zarr
136+
data.dtype, chunks, to_zarr
142137
)
143138
return encodings
144139

@@ -180,28 +175,14 @@ def _shuffle(self) -> bool:
180175
"""This method does not belong to public API."""
181176
return self._config[_KEY_SHUFFLE] == "true"
182177

183-
@staticmethod
184-
def _encode_variable(
185-
name: str, dims: list[str], attrs: dict[str, Any], array: Array
186-
) -> DataArray:
187-
"""This method does not belong to public API."""
188-
return DataArray(data=array, dims=dims, name=name, attrs=attrs)
189-
190178
def _encode_compress(
191179
self,
192180
dtype: np.dtype,
193-
attrs: dict[str:Any],
194181
chunks: list[int],
195182
to_zarr: bool = True,
196183
) -> dict[str, Any]:
197184
"""This method does not belong to public API."""
198185
enc = {"dtype": dtype}
199-
if "_FillValue" in attrs:
200-
enc["_FillValue"] = attrs.pop("_FillValue")
201-
if "add_offset" in attrs:
202-
enc["add_offset"] = attrs.pop("add_offset")
203-
if "scale_factor" in attrs:
204-
enc["scale_factor"] = attrs.pop("scale_factor")
205186
if chunks:
206187
if to_zarr:
207188
enc["chunks"] = tuple(chunks)

0 commit comments

Comments
 (0)