Skip to content

Commit cf3f397

Browse files
committed
use up-to-date keras code
1 parent 9fa0743 commit cf3f397

File tree

3 files changed

+69
-75
lines changed

3 files changed

+69
-75
lines changed

bayesflow/approximators/backend_approximators/jax_approximator.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,6 @@
88

99

1010
class JAXApproximator(keras.Model):
11-
def _aggregate_logs(self, logs, step_logs):
12-
if not logs:
13-
return step_logs
14-
15-
return keras.tree.map_structure(keras.ops.add, logs, step_logs)
16-
17-
def _mean_logs(self, logs, total_steps):
18-
if total_steps == 0:
19-
return logs
20-
21-
def _div(x):
22-
return x / total_steps
23-
24-
return keras.tree.map_structure(_div, logs)
25-
2611
# noinspection PyMethodOverriding
2712
def compute_metrics(self, *args, **kwargs) -> dict[str, jax.Array]:
2813
# implemented by each respective architecture
@@ -38,6 +23,7 @@ def evaluate(
3823
steps=None,
3924
callbacks=None,
4025
return_dict=False,
26+
aggregate=True,
4127
**kwargs,
4228
):
4329
self._assert_compile_called("evaluate")
@@ -49,7 +35,8 @@ def evaluate(
4935
if use_cached_eval_dataset:
5036
epoch_iterator = self._eval_epoch_iterator
5137
else:
52-
# Create an iterator that yields batches of input/target data.
38+
# Create an iterator that yields batches of
39+
# input/target data.
5340
epoch_iterator = JAXEpochIterator(
5441
x=x,
5542
y=y,
@@ -82,12 +69,25 @@ def evaluate(
8269
total_steps = 0
8370
self.reset_metrics()
8471

72+
def _aggregate_fn(_logs, _step_logs):
73+
if not _logs:
74+
return _step_logs
75+
76+
return keras.tree.map_structure(keras.ops.add, _logs, _step_logs)
77+
78+
def _reduce_fn(_logs, _total_steps):
79+
def _div(val):
80+
return val / _total_steps
81+
82+
return keras.tree.map_structure(_div, _logs)
83+
8584
self._jax_state_synced = True
8685
with epoch_iterator.catch_stop_iteration():
8786
for step, iterator in epoch_iterator:
88-
total_steps += 1
8987
callbacks.on_test_batch_begin(step)
9088

89+
total_steps += 1
90+
9191
if self._jax_state_synced:
9292
# The state may have been synced by a callback.
9393
state = self._get_jax_state(
@@ -98,16 +98,17 @@ def evaluate(
9898
)
9999
self._jax_state_synced = False
100100

101-
# BAYESFLOW: save into step_logs instead of overwriting logs
102101
step_logs, state = self.test_function(state, iterator)
103102
(
104103
trainable_variables,
105104
non_trainable_variables,
106105
metrics_variables,
107106
) = state
108107

109-
# BAYESFLOW: aggregate the metrics across all iterations
110-
logs = self._aggregate_logs(logs, step_logs)
108+
if aggregate:
109+
logs = _aggregate_fn(logs, step_logs)
110+
else:
111+
logs = step_logs
111112

112113
# Setting _jax_state enables callbacks to force a state sync
113114
# if they need to.
@@ -120,22 +121,18 @@ def evaluate(
120121
}
121122

122123
# Dispatch callbacks. This takes care of async dispatch.
123-
callbacks.on_test_batch_end(step, logs)
124+
callbacks.on_test_batch_end(step, step_logs)
124125

125126
if self.stop_evaluating:
126127
break
127128

128-
# BAYESFLOW: average the metrics across all iterations
129-
logs = self._mean_logs(logs, total_steps)
129+
if aggregate:
130+
logs = _reduce_fn(logs, total_steps)
130131

131132
# Reattach state back to model (if not already done by a callback).
132133
self.jax_state_sync()
133134

134-
# The jax spmd_mode is needed for multi-process context, since the
135-
# metrics values are replicated, and we don't want to do an all
136-
# gather, and only need the local copy of the value.
137-
with jax.spmd_mode("allow_all"):
138-
logs = self._get_metrics_result_or_logs(logs)
135+
logs = self._get_metrics_result_or_logs(logs)
139136
callbacks.on_test_end(logs)
140137
self._jax_state = None
141138
if not use_cached_eval_dataset:

bayesflow/approximators/backend_approximators/tensorflow_approximator.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,6 @@
88

99

1010
class TensorFlowApproximator(keras.Model):
11-
def _aggregate_logs(self, logs, step_logs):
12-
if not logs:
13-
return step_logs
14-
15-
return keras.tree.map_structure(keras.ops.add, logs, step_logs)
16-
17-
def _mean_logs(self, logs, total_steps):
18-
if total_steps == 0:
19-
return logs
20-
21-
def _div(x):
22-
return x / total_steps
23-
24-
return keras.tree.map_structure(_div, logs)
25-
2611
# noinspection PyMethodOverriding
2712
def compute_metrics(self, *args, **kwargs) -> dict[str, tf.Tensor]:
2813
# implemented by each respective architecture
@@ -38,6 +23,7 @@ def evaluate(
3823
steps=None,
3924
callbacks=None,
4025
return_dict=False,
26+
aggregate=False,
4127
**kwargs,
4228
):
4329
self._assert_compile_called("evaluate")
@@ -81,24 +67,37 @@ def evaluate(
8167
logs = {}
8268
total_steps = 0
8369
self.reset_metrics()
70+
71+
def _aggregate_fn(_logs, _step_logs):
72+
if not _logs:
73+
return _step_logs
74+
75+
return keras.tree.map_structure(keras.ops.add, _logs, _step_logs)
76+
77+
def _reduce_fn(_logs, _total_steps):
78+
def _div(val):
79+
return val / _total_steps
80+
81+
return keras.tree.map_structure(_div, _logs)
82+
8483
with epoch_iterator.catch_stop_iteration():
8584
for step, iterator in epoch_iterator:
86-
total_steps += 1
87-
8885
callbacks.on_test_batch_begin(step)
86+
total_steps += 1
8987

90-
# BAYESFLOW: save into step_logs instead of overwriting logs
9188
step_logs = self.test_function(iterator)
9289

93-
# BAYESFLOW: aggregate the metrics across all iterations
94-
logs = self._aggregate_logs(logs, step_logs)
90+
if aggregate:
91+
logs = _aggregate_fn(logs, step_logs)
92+
else:
93+
logs = step_logs
9594

96-
callbacks.on_test_batch_end(step, logs)
95+
callbacks.on_test_batch_end(step, step_logs)
9796
if self.stop_evaluating:
9897
break
9998

100-
# BAYESFLOW: average the metrics across all iterations
101-
logs = self._mean_logs(logs, total_steps)
99+
if aggregate:
100+
logs = _reduce_fn(logs, total_steps)
102101

103102
logs = self._get_metrics_result_or_logs(logs)
104103
callbacks.on_test_end(logs)

bayesflow/approximators/backend_approximators/torch_approximator.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,6 @@
88

99

1010
class TorchApproximator(keras.Model):
11-
def _aggregate_logs(self, logs, step_logs):
12-
if not logs:
13-
return step_logs
14-
15-
return keras.tree.map_structure(keras.ops.add, logs, step_logs)
16-
17-
def _mean_logs(self, logs, total_steps):
18-
if total_steps == 0:
19-
return logs
20-
21-
def _div(x):
22-
return x / total_steps
23-
24-
return keras.tree.map_structure(_div, logs)
25-
2611
# noinspection PyMethodOverriding
2712
def compute_metrics(self, *args, **kwargs) -> dict[str, torch.Tensor]:
2813
# implemented by each respective architecture
@@ -38,6 +23,7 @@ def evaluate(
3823
steps=None,
3924
callbacks=None,
4025
return_dict=False,
26+
aggregate=False,
4127
**kwargs,
4228
):
4329
# TODO: respect compiled trainable state
@@ -82,23 +68,35 @@ def evaluate(
8268
logs = {}
8369
total_steps = 0
8470
self.reset_metrics()
85-
for step, data in epoch_iterator:
86-
total_steps += 1
8771

88-
callbacks.on_test_batch_begin(step)
72+
def _aggregate_fn(_logs, _step_logs):
73+
if not _logs:
74+
return _step_logs
75+
76+
return keras.tree.map_structure(keras.ops.add, _logs, _step_logs)
77+
78+
def _reduce_fn(_logs, _total_steps):
79+
def _div(val):
80+
return val / _total_steps
81+
82+
return keras.tree.map_structure(_div, _logs)
8983

90-
# BAYESFLOW: save into step_logs instead of overwriting logs
84+
for step, data in epoch_iterator:
85+
callbacks.on_test_batch_begin(step)
86+
total_steps += 1
9187
step_logs = self.test_function(data)
9288

93-
# BAYESFLOW: aggregate the metrics across all iterations
94-
logs = self._aggregate_logs(logs, step_logs)
89+
if aggregate:
90+
logs = _aggregate_fn(logs, step_logs)
91+
else:
92+
logs = step_logs
9593

96-
callbacks.on_test_batch_end(step, logs)
94+
callbacks.on_test_batch_end(step, step_logs)
9795
if self.stop_evaluating:
9896
break
9997

100-
# BAYESFLOW: average the metrics across all iterations
101-
logs = self._mean_logs(logs, total_steps)
98+
if aggregate:
99+
logs = _reduce_fn(logs, total_steps)
102100

103101
logs = self._get_metrics_result_or_logs(logs)
104102
callbacks.on_test_end(logs)

0 commit comments

Comments
 (0)