@@ -1,4 +1,5 @@
 import bayesflow as bf
+import pytest
 
 
 def num_variables(x: dict):
@@ -47,3 +48,26 @@ def test_root_mean_squared_error(random_estimates, random_targets): |
     assert out["values"].shape == (num_variables(random_estimates),)
     assert out["metric_name"] == "NRMSE"
     assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
+
+
+def test_expected_calibration_error(pred_models, true_models, model_names):
+    out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, model_names=model_names)
+    assert list(out.keys()) == ["values", "metric_name", "model_names"]
+    assert out["values"].shape == (pred_models.shape[-1],)
+    assert out["metric_name"] == "Expected Calibration Error"
+    assert out["model_names"] == [r"$\mathcal{M}_0$", r"$\mathcal{M}_1$", r"$\mathcal{M}_2$"]
+
+    # returns probs?
+    out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, return_probs=True)
+    assert list(out.keys()) == ["values", "metric_name", "model_names", "probs_true", "probs_pred"]
+    assert len(out["probs_true"]) == pred_models.shape[-1]
+    assert len(out["probs_pred"]) == pred_models.shape[-1]
+    # default: auto model names
+    assert out["model_names"] == ["M_0", "M_1", "M_2"]
+
+    # handles incorrect input?
+    with pytest.raises(Exception):
+        out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, model_names=["a"])
+
+    with pytest.raises(Exception):
+        out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose)