Skip to content

Commit 1ca2a76

Browse files
committed
Fit test for ensembles
1 parent 0693ec5 commit 1ca2a76

File tree

1 file changed

+35
-0
lines changed
  • tests/test_approximators/test_approximator_ensemble

1 file changed

+35
-0
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import re
2+
3+
import io
4+
from contextlib import redirect_stdout
5+
6+
7+
def test_loss_progress(continuous_approximator_ensemble, train_dataset_for_ensemble, validation_dataset):
8+
continuous_approximator_ensemble.compile(optimizer="AdamW")
9+
num_epochs = 3
10+
11+
# Capture ostream and train model
12+
with io.StringIO() as stream:
13+
with redirect_stdout(stream):
14+
continuous_approximator_ensemble.fit(
15+
dataset=train_dataset_for_ensemble, validation_data=validation_dataset, epochs=num_epochs
16+
)
17+
18+
output = stream.getvalue()
19+
20+
print(output)
21+
22+
# check that there is a progress bar
23+
assert "━" in output, "no progress bar"
24+
25+
# check that the loss is shown
26+
assert "loss" in output
27+
assert re.search(r"\bloss: \d+\.\d+", output) is not None, "training loss not correctly shown"
28+
29+
# check that validation loss is shown
30+
assert "val_loss" in output
31+
assert re.search(r"\bval_loss: \d+\.\d+", output) is not None, "validation loss not correctly shown"
32+
33+
# check that the shown loss is not nan or zero
34+
assert re.search(r"\bnan\b", output) is None, "found nan in output"
35+
assert re.search(r"\bloss: 0\.0000e\+00\b", output) is None, "found zero loss in output"

0 commit comments

Comments
 (0)