Skip to content

Commit 28b3678

Browse files
authored
Fix constructor deserialization of NumericalTarget (#740)
Fixes #739 by adjusting the deserialization hooks.
2 parents 05dcec0 + ec70c29 commit 28b3678

File tree

4 files changed

+45
-23
lines changed

4 files changed

+45
-23
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## [Unreleased]
8+
### Fixed
9+
- Deserialization of `NumericalTarget` objects using the optional `constructor` field
10+
711
## [0.14.2] - 2026-01-14
812
### Added
913
- `NumericalTarget.match_*` constructors now accept a `mismatch_instead` argument. If

baybe/serialization/core.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import base64
6+
import contextlib
67
import pickle
78
from datetime import datetime, timedelta
89
from typing import TYPE_CHECKING, Any, NoReturn, TypeVar, get_type_hints
@@ -129,8 +130,13 @@ def block_deserialization_hook(_: Any, cls: type) -> NoReturn: # noqa: DOC101,
129130
def select_constructor_hook(specs: dict, cls: type[_T]) -> _T:
130131
"""Use the constructor specified in the 'constructor' field for deserialization."""
131132
# If a constructor is specified, use it
132-
specs = specs.copy()
133133
if constructor_name := specs.pop("constructor", None):
134+
# Drop potentially existing type field
135+
# (The type is already fully determined in this execution branch)
136+
specs = specs.copy()
137+
specs.pop(_TYPE_FIELD, None)
138+
139+
# Extract the constructor callable
134140
constructor = getattr(cls, constructor_name)
135141

136142
# If given a non-attrs class, simply call the constructor
@@ -139,9 +145,14 @@ def select_constructor_hook(specs: dict, cls: type[_T]) -> _T:
139145

140146
# Extract the constructor parameter types and deserialize the arguments
141147
type_hints = get_type_hints(constructor)
142-
for key, value in specs.items():
148+
for key in specs:
143149
annotation = type_hints[key]
144-
specs[key] = converter.structure(specs[key], annotation)
150+
151+
# For some types (e.g. unions), there might not be a registered structure
152+
# hook. In this case, the constructor will accept the raw value, so we
153+
# simply pass it through.
154+
with contextlib.suppress(cattrs.StructureHandlerNotFoundError):
155+
specs[key] = converter.structure(specs[key], annotation)
145156

146157
# Call the constructor with the deserialized arguments
147158
return constructor(**specs)

baybe/targets/numerical.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from baybe.exceptions import IncompatibilityError
1919
from baybe.serialization import SerialMixin, converter
20-
from baybe.serialization.core import select_constructor_hook
20+
from baybe.serialization.core import _TYPE_FIELD, select_constructor_hook
2121
from baybe.targets._deprecated import (
2222
_VALID_TRANSFORMATIONS,
2323
TargetMode,
@@ -765,12 +765,16 @@ def summary(self):
765765
)
766766

767767

768+
# Collect leftover original slotted classes processed by `attrs.define`
769+
gc.collect()
770+
768771
converter.register_unstructure_hook(
769772
NumericalTarget,
770773
cattrs.gen.make_dict_unstructure_fn(
771774
NumericalTarget, converter, _constructor_info=cattrs.override(omit=False)
772775
),
773776
)
777+
converter.register_structure_hook(NumericalTarget, select_constructor_hook)
774778

775779

776780
# >>> Deprecation >>> #
@@ -779,26 +783,11 @@ def summary(self):
779783

780784

781785
@converter.register_structure_hook
782-
def _(dct, cls) -> NumericalTarget:
786+
def _enable_legacy_target_deserialization(dct: dict[str, Any], cls) -> NumericalTarget:
783787
if "mode" in dct:
784-
return _hook(*dct)
785-
return select_constructor_hook(dct, cls)
786-
787-
788-
_hook = converter.get_structure_hook(NumericalTarget)
789-
790-
791-
@converter.register_structure_hook
792-
def _structure_legacy_target_arguments(x: dict[str, Any], _) -> NumericalTarget:
793-
"""Accept legacy target argument for backward compatibility."""
794-
x.pop("type", None)
795-
try:
796-
return _hook(x, _)
797-
except Exception:
798-
return NumericalTarget(**x) # type: ignore[return-value]
788+
dct.pop(_TYPE_FIELD, None)
789+
return NumericalTarget(**dct)
790+
return _hook(dct, cls)
799791

800792

801793
# <<< Deprecation <<< #
802-
803-
804-
gc.collect()

tests/test_deprecations.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
linear_transform,
4040
triangular_transform,
4141
)
42+
from baybe.targets.base import Target
4243
from baybe.targets.binary import BinaryTarget
4344
from baybe.transformations.basic import AffineTransformation
4445

@@ -474,3 +475,20 @@ def test_target_transformation(
474475
if deprecation is not None:
475476
assert_series_equal(deprecation.transform(series), expected)
476477
assert_series_equal(modern.transform(series), expected)
478+
479+
480+
def test_deserialization_using_constructor():
481+
"""Deserialization using the 'constructor' field works despite having other
482+
deprecation mechanisms in place.""" # noqa
483+
config = """
484+
{
485+
"type": "NumericalTarget",
486+
"name": "t_max_bounds",
487+
"constructor": "normalized_ramp",
488+
"cutoffs": [0, 100]
489+
}
490+
"""
491+
t_old = NumericalTarget("t_max_bounds", mode="MAX", bounds=(0, 100))
492+
t_new = NumericalTarget.normalized_ramp("t_max_bounds", cutoffs=(0, 100))
493+
t_new_config = Target.from_json(config)
494+
assert t_old == t_new == t_new_config

0 commit comments

Comments
 (0)