Skip to content

Commit 25e632c

Browse files
authored
Merge pull request #407 from bayesflow-org/dev
Fixed weighted mean everywhere
2 parents fcdaae3 + 2bf0b53 commit 25e632c

File tree

20 files changed

+430
-55
lines changed

20 files changed

+430
-55
lines changed

bayesflow/adapters/adapter.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import MutableSequence, Sequence, Mapping
1+
from collections.abc import Callable, MutableSequence, Sequence, Mapping
22

33
import numpy as np
44

@@ -24,6 +24,7 @@
2424
NumpyTransform,
2525
OneHot,
2626
Rename,
27+
SerializableCustomTransform,
2728
Sqrt,
2829
Standardize,
2930
ToArray,
@@ -274,6 +275,88 @@ def apply(
274275
self.transforms.append(transform)
275276
return self
276277

278+
def apply_serializable(
279+
self,
280+
include: str | Sequence[str] = None,
281+
*,
282+
forward: Callable[[np.ndarray, ...], np.ndarray],
283+
inverse: Callable[[np.ndarray, ...], np.ndarray],
284+
predicate: Predicate = None,
285+
exclude: str | Sequence[str] = None,
286+
**kwargs,
287+
):
288+
"""Append a :py:class:`~transforms.SerializableCustomTransform` to the adapter.
289+
290+
Parameters
291+
----------
292+
forward : function, no lambda
293+
Registered serializable function to transform the data in the forward pass.
294+
For the adapter to be serializable, this function has to be serializable
295+
as well (see Notes). Therefore, only proper functions and no lambda
296+
functions can be used here.
297+
inverse : function, no lambda
298+
Registered serializable function to transform the data in the inverse pass.
299+
For the adapter to be serializable, this function has to be serializable
300+
as well (see Notes). Therefore, only proper functions and no lambda
301+
functions can be used here.
302+
predicate : Predicate, optional
303+
Function that indicates which variables should be transformed.
304+
include : str or Sequence of str, optional
305+
Names of variables to include in the transform.
306+
exclude : str or Sequence of str, optional
307+
Names of variables to exclude from the transform.
308+
**kwargs : dict
309+
Additional keyword arguments passed to the transform.
310+
311+
Raises
312+
------
313+
ValueError
314+
When the provided functions are not registered serializable functions.
315+
316+
Notes
317+
-----
318+
Important: The forward and inverse functions have to be registered with Keras.
319+
To do so, use the `@keras.saving.register_keras_serializable` decorator.
320+
They must also be registered (and identical) when loading the adapter
321+
at a later point in time.
322+
323+
Examples
324+
--------
325+
326+
The example below shows how to use the
327+
`keras.saving.register_keras_serializable` decorator to
328+
register functions with Keras. Note that for this simple
329+
example, one usually would use the simpler :py:meth:`apply`
330+
method.
331+
332+
>>> import keras
333+
>>>
334+
>>> @keras.saving.register_keras_serializable("custom")
335+
>>> def forward_fn(x):
336+
>>> return x**2
337+
>>>
338+
>>> @keras.saving.register_keras_serializable("custom")
339+
>>> def inverse_fn(x):
340+
>>> return x**0.5
341+
>>>
342+
>>> adapter = bf.Adapter().apply_serializable(
343+
>>> "x",
344+
>>> forward=forward_fn,
345+
>>> inverse=inverse_fn,
346+
>>> )
347+
"""
348+
transform = FilterTransform(
349+
transform_constructor=SerializableCustomTransform,
350+
predicate=predicate,
351+
include=include,
352+
exclude=exclude,
353+
forward=forward,
354+
inverse=inverse,
355+
**kwargs,
356+
)
357+
self.transforms.append(transform)
358+
return self
359+
277360
def as_set(self, keys: str | Sequence[str]):
278361
"""Append an :py:class:`~transforms.AsSet` transform to the adapter.
279362

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .one_hot import OneHot
1616
from .rename import Rename
1717
from .scale import Scale
18+
from .serializable_custom_transform import SerializableCustomTransform
1819
from .shift import Shift
1920
from .sqrt import Sqrt
2021
from .standardize import Standardize
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from collections.abc import Callable
2+
import numpy as np
3+
from keras.saving import (
4+
deserialize_keras_object as deserialize,
5+
register_keras_serializable as serializable,
6+
serialize_keras_object as serialize,
7+
get_registered_name,
8+
get_registered_object,
9+
)
10+
from .elementwise_transform import ElementwiseTransform
11+
from ...utils import filter_kwargs
12+
import inspect
13+
14+
15+
@serializable(package="bayesflow.adapters")
16+
class SerializableCustomTransform(ElementwiseTransform):
17+
"""
18+
Transforms a parameter using a pair of registered serializable forward and inverse functions.
19+
20+
Parameters
21+
----------
22+
forward : function, no lambda
23+
Registered serializable function to transform the data in the forward pass.
24+
For the adapter to be serializable, this function has to be serializable
25+
as well (see Notes). Therefore, only proper functions and no lambda
26+
functions can be used here.
27+
inverse : function, no lambda
28+
Function to transform the data in the inverse pass.
29+
For the adapter to be serializable, this function has to be serializable
30+
as well (see Notes). Therefore, only proper functions and no lambda
31+
functions can be used here.
32+
33+
Raises
34+
------
35+
ValueError
36+
When the provided functions are not registered serializable functions.
37+
38+
Notes
39+
-----
40+
Important: The forward and inverse functions have to be registered with Keras.
41+
To do so, use the `@keras.saving.register_keras_serializable` decorator.
42+
They must also be registered (and identical) when loading the adapter
43+
at a later point in time.
44+
45+
"""
46+
47+
def __init__(
48+
self,
49+
*,
50+
forward: Callable[[np.ndarray, ...], np.ndarray],
51+
inverse: Callable[[np.ndarray, ...], np.ndarray],
52+
):
53+
super().__init__()
54+
55+
self._check_serializable(forward, label="forward")
56+
self._check_serializable(inverse, label="inverse")
57+
self._forward = forward
58+
self._inverse = inverse
59+
60+
@classmethod
61+
def _check_serializable(cls, function, label=""):
62+
GENERAL_EXAMPLE_CODE = (
63+
"The example code below shows the structure of a correctly decorated function:\n\n"
64+
"```\n"
65+
"import keras\n\n"
66+
"@keras.saving.register_keras_serializable('custom')\n"
67+
f"def my_{label}(...):\n"
68+
" [your code goes here...]\n"
69+
"```\n"
70+
)
71+
if function is None:
72+
raise TypeError(
73+
f"'{label}' must be a registered serializable function, was 'NoneType'.\n{GENERAL_EXAMPLE_CODE}"
74+
)
75+
registered_name = get_registered_name(function)
76+
# check if function is a lambda function
77+
if registered_name == "<lambda>":
78+
raise ValueError(
79+
f"The provided function for '{label}' is a lambda function, "
80+
"which cannot be serialized. "
81+
"Please provide a registered serializable function by using the "
82+
"@keras.saving.register_keras_serializable decorator."
83+
f"\n{GENERAL_EXAMPLE_CODE}"
84+
)
85+
if inspect.ismethod(function):
86+
raise ValueError(
87+
f"The provided value for '{label}' is a method, not a function. "
88+
"Methods cannot be serialized separately from their classes. "
89+
"Please provide a registered serializable function instead by "
90+
"moving the functionality to a function (i.e., outside of the class) and "
91+
"using the @keras.saving.register_keras_serializable decorator."
92+
f"\n{GENERAL_EXAMPLE_CODE}"
93+
)
94+
registered_object_for_name = get_registered_object(registered_name)
95+
if registered_object_for_name is None:
96+
try:
97+
source_max_lines = 5
98+
function_source_code = inspect.getsource(function).split("\n")
99+
if len(function_source_code) > source_max_lines:
100+
function_source_code = function_source_code[:source_max_lines] + [" [...]"]
101+
102+
example_code = "For your provided function, this would look like this:\n\n"
103+
example_code += "\n".join(
104+
["```", "import keras\n", "@keras.saving.register_keras_serializable('custom')"]
105+
+ function_source_code
106+
+ ["```"]
107+
)
108+
except OSError:
109+
example_code = GENERAL_EXAMPLE_CODE
110+
raise ValueError(
111+
f"The provided function for '{label}' is not registered with Keras.\n"
112+
"Please register the function using the "
113+
"@keras.saving.register_keras_serializable decorator.\n"
114+
f"{example_code}"
115+
)
116+
if registered_object_for_name is not function:
117+
raise ValueError(
118+
f"The provided function for '{label}' does not match the function "
119+
f"registered under its name '{registered_name}'. "
120+
f"(registered function: {registered_object_for_name}, provided function: {function}). "
121+
)
122+
123+
@classmethod
124+
def from_config(cls, config: dict, custom_objects=None) -> "SerializableCustomTransform":
125+
if get_registered_object(config["forward"]["config"], custom_objects) is None:
126+
provided_function_msg = ""
127+
if config["_forward_source_code"]:
128+
provided_function_msg = (
129+
f"\nThe originally provided function was:\n\n```\n{config['_forward_source_code']}\n```"
130+
)
131+
raise TypeError(
132+
"\n\nPLEASE READ HERE:\n"
133+
"-----------------\n"
134+
"The forward function that was provided as `forward` "
135+
"is not registered with Keras, making deserialization impossible. "
136+
f"Please ensure that it is registered as '{config['forward']['config']}' and identical to the original "
137+
"function before loading your model."
138+
f"{provided_function_msg}"
139+
)
140+
if get_registered_object(config["inverse"]["config"], custom_objects) is None:
141+
provided_function_msg = ""
142+
if config["_inverse_source_code"]:
143+
provided_function_msg = (
144+
f"\nThe originally provided function was:\n\n```\n{config['_inverse_source_code']}\n```"
145+
)
146+
raise TypeError(
147+
"\n\nPLEASE READ HERE:\n"
148+
"-----------------\n"
149+
"The inverse function that was provided as `inverse` "
150+
"is not registered with Keras, making deserialization impossible. "
151+
f"Please ensure that it is registered as '{config['inverse']['config']}' and identical to the original "
152+
"function before loading your model."
153+
f"{provided_function_msg}"
154+
)
155+
forward = deserialize(config["forward"], custom_objects)
156+
inverse = deserialize(config["inverse"], custom_objects)
157+
return cls(
158+
forward=forward,
159+
inverse=inverse,
160+
)
161+
162+
def get_config(self) -> dict:
163+
forward_source_code = inverse_source_code = None
164+
try:
165+
forward_source_code = inspect.getsource(self._forward)
166+
inverse_source_code = inspect.getsource(self._inverse)
167+
except OSError:
168+
pass
169+
return {
170+
"forward": serialize(self._forward),
171+
"inverse": serialize(self._inverse),
172+
"_forward_source_code": forward_source_code,
173+
"_inverse_source_code": inverse_source_code,
174+
}
175+
176+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
177+
# filter kwargs so that other transform args like batch_size, strict, ... are not passed through
178+
kwargs = filter_kwargs(kwargs, self._forward)
179+
return self._forward(data, **kwargs)
180+
181+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
182+
kwargs = filter_kwargs(kwargs, self._inverse)
183+
return self._inverse(data, **kwargs)

bayesflow/experimental/free_form_flow/free_form_flow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
vjp,
1313
serialize_value_or_type,
1414
deserialize_value_or_type,
15+
weighted_mean,
1516
)
1617

1718
from bayesflow.networks import InferenceNetwork
@@ -240,6 +241,6 @@ def decode(z):
240241
reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
241242

242243
losses = maximum_likelihood_loss + self.beta * reconstruction_loss
243-
loss = self.aggregate(losses, sample_weight)
244+
loss = weighted_mean(losses, sample_weight)
244245

245246
return base_metrics | {"loss": loss}

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
import numpy as np
88

99
from bayesflow.types import Tensor
10-
from bayesflow.utils import find_network, keras_kwargs, serialize_value_or_type, deserialize_value_or_type, weighted_sum
10+
from bayesflow.utils import (
11+
find_network,
12+
keras_kwargs,
13+
serialize_value_or_type,
14+
deserialize_value_or_type,
15+
weighted_mean,
16+
)
1117

1218

1319
from ..inference_network import InferenceNetwork
@@ -331,6 +337,6 @@ def compute_metrics(
331337

332338
# Pseudo-huber loss, see [2], Section 3.3
333339
loss = lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber)
334-
loss = weighted_sum(loss, sample_weight)
340+
loss = weighted_mean(loss, sample_weight)
335341

336342
return base_metrics | {"loss": loss}

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
keras_kwargs,
88
serialize_value_or_type,
99
deserialize_value_or_type,
10-
weighted_sum,
10+
weighted_mean,
1111
)
1212

1313
from .actnorm import ActNorm
@@ -167,6 +167,6 @@ def compute_metrics(
167167
base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)
168168

169169
z, log_density = self(x, conditions=conditions, inverse=False, density=True)
170-
loss = weighted_sum(-log_density, sample_weight)
170+
loss = weighted_mean(-log_density, sample_weight)
171171

172172
return base_metrics | {"loss": loss}

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
optimal_transport,
1414
serialize_value_or_type,
1515
deserialize_value_or_type,
16-
weighted_sum,
16+
weighted_mean,
1717
)
1818
from ..inference_network import InferenceNetwork
1919

@@ -260,6 +260,6 @@ def compute_metrics(
260260
predicted_velocity = self.velocity(x, time=t, conditions=conditions, training=stage == "training")
261261

262262
loss = self.loss_fn(target_velocity, predicted_velocity)
263-
loss = weighted_sum(loss, sample_weight)
263+
loss = weighted_mean(loss, sample_weight)
264264

265265
return base_metrics | {"loss": loss}

bayesflow/scores/normed_difference_score.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from keras.saving import register_keras_serializable as serializable
33

44
from bayesflow.types import Shape, Tensor
5+
from bayesflow.utils import weighted_mean
56

67
from .scoring_rule import ScoringRule
78

@@ -55,7 +56,7 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
5556
"""
5657
estimates = estimates["value"]
5758
scores = keras.ops.absolute(estimates - targets) ** self.k
58-
score = self.aggregate(scores, weights)
59+
score = weighted_mean(scores, weights)
5960
return score
6061

6162
def get_config(self):

bayesflow/scores/parametric_distribution_score.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from keras.saving import register_keras_serializable as serializable
22

33
from bayesflow.types import Tensor
4+
from bayesflow.utils import weighted_mean
45

56
from .scoring_rule import ScoringRule
67

@@ -29,5 +30,5 @@ def score(self, estimates: dict[str, Tensor], targets: Tensor, weights: Tensor =
2930
:math:`S(\hat p_\phi, \theta; k) = -\log(\hat p_\phi(\theta))`
3031
"""
3132
scores = -self.log_prob(x=targets, **estimates)
32-
score = self.aggregate(scores, weights)
33+
score = weighted_mean(scores, weights)
3334
return score

0 commit comments

Comments
 (0)