Commit 9fa0743

aggregate logs in evaluate
1 parent 4781e2e commit 9fa0743
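
This commit overrides evaluate() in the JAX, TensorFlow, and torch backend approximators so that per-batch metric logs are summed across evaluation steps (_aggregate_logs) and averaged at the end (_mean_logs), rather than each step's logs overwriting the previous ones. The sketch below illustrates that aggregation in isolation. It assumes Keras 3 with any backend installed; toy_test_function and the batch values are made-up stand-ins, and only aggregate_logs/mean_logs mirror the helpers in the diff.

    # Minimal, self-contained sketch of the aggregation the three evaluate()
    # overrides share. toy_test_function and the batches are hypothetical.
    import keras
    import numpy as np

    def toy_test_function(batch):
        # Stand-in for self.test_function: returns per-batch metric logs,
        # possibly nested, as compute_metrics implementations may produce.
        return {"loss": np.mean(batch), "metrics": {"mse": np.mean(batch**2)}}

    def aggregate_logs(logs, step_logs):
        # First step: adopt step_logs as-is; afterwards: leafwise sum.
        # Note that map_structure requires logs and step_logs to share
        # the same nested structure across steps.
        if not logs:
            return step_logs
        return keras.tree.map_structure(keras.ops.add, logs, step_logs)

    def mean_logs(logs, total_steps):
        # Guard against an empty evaluation run, then divide every leaf.
        if total_steps == 0:
            return logs
        return keras.tree.map_structure(lambda x: x / total_steps, logs)

    logs, total_steps = {}, 0
    for batch in [np.array([1.0, 3.0]), np.array([5.0, 7.0])]:
        total_steps += 1
        logs = aggregate_logs(logs, toy_test_function(batch))

    print(mean_logs(logs, total_steps))  # loss == 4.0, mse == 21.0 (per-step means)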

File tree

3 files changed, +329 -0 lines changed

bayesflow/approximators/backend_approximators/jax_approximator.py

Lines changed: 135 additions & 0 deletions
@@ -3,13 +3,148 @@

from bayesflow.utils import filter_kwargs

from keras.src.backend.jax.trainer import JAXEpochIterator
from keras.src import callbacks as callbacks_module


class JAXApproximator(keras.Model):
    def _aggregate_logs(self, logs, step_logs):
        if not logs:
            return step_logs

        return keras.tree.map_structure(keras.ops.add, logs, step_logs)

    def _mean_logs(self, logs, total_steps):
        if total_steps == 0:
            return logs

        def _div(x):
            return x / total_steps

        return keras.tree.map_structure(_div, logs)

    # noinspection PyMethodOverriding
    def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
        # implemented by each respective architecture
        raise NotImplementedError

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        self._assert_compile_called("evaluate")
        # TODO: respect compiled trainable state
        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
        if kwargs:
            raise ValueError(f"Arguments not recognized: {kwargs}")

        if use_cached_eval_dataset:
            epoch_iterator = self._eval_epoch_iterator
        else:
            # Create an iterator that yields batches of input/target data.
            epoch_iterator = JAXEpochIterator(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps,
                shuffle=False,
                steps_per_execution=self.steps_per_execution,
            )

        self._symbolic_build(iterator=epoch_iterator)
        epoch_iterator.reset()

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=1,
                steps=epoch_iterator.num_batches,
                model=self,
            )
        self._record_training_state_sharding_spec()

        self.make_test_function()
        self.stop_evaluating = False
        callbacks.on_test_begin()
        logs = {}
        total_steps = 0
        self.reset_metrics()

        self._jax_state_synced = True
        with epoch_iterator.catch_stop_iteration():
            for step, iterator in epoch_iterator:
                total_steps += 1
                callbacks.on_test_batch_begin(step)

                if self._jax_state_synced:
                    # The state may have been synced by a callback.
                    state = self._get_jax_state(
                        trainable_variables=True,
                        non_trainable_variables=True,
                        metrics_variables=True,
                        purge_model_variables=True,
                    )
                    self._jax_state_synced = False

                # BAYESFLOW: save into step_logs instead of overwriting logs
                step_logs, state = self.test_function(state, iterator)
                (
                    trainable_variables,
                    non_trainable_variables,
                    metrics_variables,
                ) = state

                # BAYESFLOW: aggregate the metrics across all iterations
                logs = self._aggregate_logs(logs, step_logs)

                # Setting _jax_state enables callbacks to force a state sync
                # if they need to.
                self._jax_state = {
                    # I wouldn't recommend modifying non-trainable model state
                    # during evaluate(), but it's allowed.
                    "trainable_variables": trainable_variables,
                    "non_trainable_variables": non_trainable_variables,
                    "metrics_variables": metrics_variables,
                }

                # Dispatch callbacks. This takes care of async dispatch.
                callbacks.on_test_batch_end(step, logs)

                if self.stop_evaluating:
                    break

        # BAYESFLOW: average the metrics across all iterations
        logs = self._mean_logs(logs, total_steps)

        # Reattach state back to model (if not already done by a callback).
        self.jax_state_sync()

        # The jax spmd_mode is needed in a multi-process context: the metrics
        # values are replicated, and we only need the local copy of the value
        # rather than an all-gather.
        with jax.spmd_mode("allow_all"):
            logs = self._get_metrics_result_or_logs(logs)
        callbacks.on_test_end(logs)
        self._jax_state = None
        if not use_cached_eval_dataset:
            # Only clear sharding if evaluate is not called from `fit`.
            self._clear_jax_state_sharding()
        if return_dict:
            return logs
        return self._flatten_metrics_in_order(logs)

    def stateless_compute_metrics(
        self,
        trainable_variables: any,
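
A note on the JAX variant above: the Keras JAX backend's test function is stateless, so model state is passed in and threaded back out next to the per-step logs, which is why the loop unpacks step_logs, state = self.test_function(state, iterator) before aggregating. The following toy loop sketches that contract in plain Python; every name in it (toy_test_function, state, the batches) is a hypothetical stand-in, and only the shape of the loop mirrors the diff.

    # Schematic sketch of a stateless evaluation step: state in,
    # (step_logs, state) out, with logs accumulated separately.
    def toy_test_function(state, batch):
        # A real test_function computes metrics from model state; here we
        # just count samples to keep the sketch self-contained.
        step_logs = {"batch_size": float(len(batch))}
        return step_logs, {"seen": state["seen"] + len(batch)}

    state = {"seen": 0}
    logs, total_steps = {}, 0
    for batch in ([1, 2, 3], [4, 5]):
        total_steps += 1
        step_logs, state = toy_test_function(state, batch)
        # Sum per-step logs instead of overwriting them (as _aggregate_logs does).
        logs = {k: logs.get(k, 0.0) + v for k, v in step_logs.items()}

    # Average at the end (as _mean_logs does): mean batch size is (3 + 2) / 2.
    logs = {k: v / total_steps for k, v in logs.items()}
    print(logs, state)  # {'batch_size': 2.5} {'seen': 5}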

bayesflow/approximators/backend_approximators/tensorflow_approximator.py

Lines changed: 97 additions & 0 deletions
@@ -3,13 +3,110 @@

from bayesflow.utils import filter_kwargs

from keras.src.backend.tensorflow.trainer import TFEpochIterator
from keras.src import callbacks as callbacks_module


class TensorFlowApproximator(keras.Model):
    def _aggregate_logs(self, logs, step_logs):
        if not logs:
            return step_logs

        return keras.tree.map_structure(keras.ops.add, logs, step_logs)

    def _mean_logs(self, logs, total_steps):
        if total_steps == 0:
            return logs

        def _div(x):
            return x / total_steps

        return keras.tree.map_structure(_div, logs)

    # noinspection PyMethodOverriding
    def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
        # implemented by each respective architecture
        raise NotImplementedError

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        self._assert_compile_called("evaluate")
        # TODO: respect compiled trainable state
        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
        if kwargs:
            raise ValueError(f"Arguments not recognized: {kwargs}")

        if use_cached_eval_dataset:
            epoch_iterator = self._eval_epoch_iterator
        else:
            # Create an iterator that yields batches of input/target data.
            epoch_iterator = TFEpochIterator(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps,
                shuffle=False,
                distribute_strategy=self.distribute_strategy,
                steps_per_execution=self.steps_per_execution,
            )

        self._maybe_symbolic_build(iterator=epoch_iterator)
        epoch_iterator.reset()

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=1,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        self.make_test_function()
        self.stop_evaluating = False
        callbacks.on_test_begin()
        logs = {}
        total_steps = 0
        self.reset_metrics()
        with epoch_iterator.catch_stop_iteration():
            for step, iterator in epoch_iterator:
                total_steps += 1

                callbacks.on_test_batch_begin(step)

                # BAYESFLOW: save into step_logs instead of overwriting logs
                step_logs = self.test_function(iterator)

                # BAYESFLOW: aggregate the metrics across all iterations
                logs = self._aggregate_logs(logs, step_logs)

                callbacks.on_test_batch_end(step, logs)
                if self.stop_evaluating:
                    break

        # BAYESFLOW: average the metrics across all iterations
        logs = self._mean_logs(logs, total_steps)

        logs = self._get_metrics_result_or_logs(logs)
        callbacks.on_test_end(logs)

        if return_dict:
            return logs
        return self._flatten_metrics_in_order(logs)

    def test_step(self, data: dict[str, any]) -> dict[str, tf.Tensor]:
        kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
        return self.compute_metrics(**kwargs)
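
One property of the aggregation worth noting (an observation about the code above, not something stated in the commit): _mean_logs divides by the number of steps, so each batch contributes equally to the reported value regardless of how many samples it contains. A short worked example with made-up numbers:

    # Unweighted per-step mean, as computed by _mean_logs above.
    # Batch values and sizes are made up for illustration.
    batch_means = [0.2, 0.2, 0.8]   # per-batch metric values
    batch_sizes = [100, 100, 10]    # e.g. a small final remainder batch

    per_step_mean = sum(batch_means) / len(batch_means)
    per_sample_mean = sum(m * n for m, n in zip(batch_means, batch_sizes)) / sum(batch_sizes)

    print(per_step_mean)    # 0.4   -> what the overridden evaluate() reports
    print(per_sample_mean)  # ~0.229 -> the mean weighted by batch size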

bayesflow/approximators/backend_approximators/torch_approximator.py

Lines changed: 97 additions & 0 deletions
@@ -3,13 +3,110 @@

from bayesflow.utils import filter_kwargs

from keras.src.backend.torch.trainer import TorchEpochIterator
from keras.src import callbacks as callbacks_module


class TorchApproximator(keras.Model):
    def _aggregate_logs(self, logs, step_logs):
        if not logs:
            return step_logs

        return keras.tree.map_structure(keras.ops.add, logs, step_logs)

    def _mean_logs(self, logs, total_steps):
        if total_steps == 0:
            return logs

        def _div(x):
            return x / total_steps

        return keras.tree.map_structure(_div, logs)

    # noinspection PyMethodOverriding
    def compute_metrics(self, *args, **kwargs) -> dict[str, torch.Tensor]:
        # implemented by each respective architecture
        raise NotImplementedError

    def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        sample_weight=None,
        steps=None,
        callbacks=None,
        return_dict=False,
        **kwargs,
    ):
        # TODO: respect compiled trainable state
        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
        if kwargs:
            raise ValueError(f"Arguments not recognized: {kwargs}")

        if use_cached_eval_dataset:
            epoch_iterator = self._eval_epoch_iterator
        else:
            # Create an iterator that yields batches of input/target data.
            epoch_iterator = TorchEpochIterator(
                x=x,
                y=y,
                sample_weight=sample_weight,
                batch_size=batch_size,
                steps_per_epoch=steps,
                shuffle=False,
                steps_per_execution=self.steps_per_execution,
            )

        self._symbolic_build(iterator=epoch_iterator)
        epoch_iterator.reset()

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=1,
                steps=epoch_iterator.num_batches,
                model=self,
            )

        # Switch the torch Module back to testing mode.
        self.eval()

        self.make_test_function()
        self.stop_evaluating = False
        callbacks.on_test_begin()
        logs = {}
        total_steps = 0
        self.reset_metrics()
        for step, data in epoch_iterator:
            total_steps += 1

            callbacks.on_test_batch_begin(step)

            # BAYESFLOW: save into step_logs instead of overwriting logs
            step_logs = self.test_function(data)

            # BAYESFLOW: aggregate the metrics across all iterations
            logs = self._aggregate_logs(logs, step_logs)

            callbacks.on_test_batch_end(step, logs)
            if self.stop_evaluating:
                break

        # BAYESFLOW: average the metrics across all iterations
        logs = self._mean_logs(logs, total_steps)

        logs = self._get_metrics_result_or_logs(logs)
        callbacks.on_test_end(logs)

        if return_dict:
            return logs
        return self._flatten_metrics_in_order(logs)

    def test_step(self, data: dict[str, any]) -> dict[str, torch.Tensor]:
        kwargs = filter_kwargs(data | {"stage": "validation"}, self.compute_metrics)
        return self.compute_metrics(**kwargs)
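
The torch variant also calls self.eval() before the loop, switching the underlying torch Module to evaluation mode. A minimal illustration of what that toggle does, using a toy Dropout module that is not part of this commit:

    # Evaluation mode changes the behavior of layers such as Dropout.
    import torch

    module = torch.nn.Dropout(p=0.5)
    module.eval()        # what self.eval() does for the whole model
    x = torch.ones(4)
    print(module(x))     # identity in eval mode: tensor([1., 1., 1., 1.])
    module.train()
    print(module(x))     # elements randomly zeroed (and rescaled) in train mode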
