Skip to content

Commit 5f38e86

Browse files
committed
compare metric in assert layer/model equal
1 parent 66f8ca7 commit 5f38e86

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

tests/utils/assertions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ def assert_models_equal(model1: keras.Model, model2: keras.Model):
1313
else:
1414
assert_layers_equal(layer1, layer2)
1515

16+
assert len(model1.metrics) == len(model2.metrics)
17+
for metric1, metric2 in zip(model1.metrics, model2.metrics):
18+
assert type(metric1) is type(metric2)
19+
assert metric1.name == metric2.name
20+
1621

1722
def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer):
1823
msg = f"Layers {layer1.name} and {layer2.name} have different types."
@@ -40,3 +45,8 @@ def assert_layers_equal(layer1: keras.Layer, layer2: keras.Layer):
4045
# this is turned off for now, see https://github.com/bayesflow-org/bayesflow/issues/412
4146
msg = f"Layers {layer1.name} and {layer2.name} have a different name."
4247
# assert layer1.name == layer2.name, msg
48+
49+
assert len(layer1.metrics) == len(layer2.metrics), f"metrics do not match: {layer1.metrics}!={layer2.metrics}"
50+
for metric1, metric2 in zip(layer1.metrics, layer2.metrics):
51+
assert type(metric1) is type(metric2)
52+
assert metric1.name == metric2.name

0 commit comments

Comments
 (0)