|
| 1 | +import re |
| 2 | + |
| 3 | +import io |
| 4 | +from contextlib import redirect_stdout |
| 5 | + |
| 6 | + |
| 7 | +def test_loss_progress(continuous_approximator_ensemble, train_dataset_for_ensemble, validation_dataset): |
| 8 | + continuous_approximator_ensemble.compile(optimizer="AdamW") |
| 9 | + num_epochs = 3 |
| 10 | + |
| 11 | + # Capture ostream and train model |
| 12 | + with io.StringIO() as stream: |
| 13 | + with redirect_stdout(stream): |
| 14 | + continuous_approximator_ensemble.fit( |
| 15 | + dataset=train_dataset_for_ensemble, validation_data=validation_dataset, epochs=num_epochs |
| 16 | + ) |
| 17 | + |
| 18 | + output = stream.getvalue() |
| 19 | + |
| 20 | + print(output) |
| 21 | + |
| 22 | + # check that there is a progress bar |
| 23 | + assert "━" in output, "no progress bar" |
| 24 | + |
| 25 | + # check that the loss is shown |
| 26 | + assert "loss" in output |
| 27 | + assert re.search(r"\bloss: \d+\.\d+", output) is not None, "training loss not correctly shown" |
| 28 | + |
| 29 | + # check that validation loss is shown |
| 30 | + assert "val_loss" in output |
| 31 | + assert re.search(r"\bval_loss: \d+\.\d+", output) is not None, "validation loss not correctly shown" |
| 32 | + |
| 33 | + # check that the shown loss is not nan or zero |
| 34 | + assert re.search(r"\bnan\b", output) is None, "found nan in output" |
| 35 | + assert re.search(r"\bloss: 0\.0000e\+00\b", output) is None, "found zero loss in output" |
0 commit comments