|
9 | 9 |
|
10 | 10 |
|
11 | 11 | @pytest.fixture |
12 | | -def example_result(example_result_dict): |
13 | | - return PredictorEvaluationResult.build(example_result_dict) |
| 12 | +def example_cv_result(example_cv_result_dict): |
| 13 | + return PredictorEvaluationResult.build(example_cv_result_dict) |
14 | 14 |
|
15 | 15 |
|
16 | | -def test_indexing(example_result): |
17 | | - assert example_result.responses == {"saltiness", "salt?"} |
18 | | - assert example_result.metrics == {RMSE(), PVA(), F1()} |
19 | | - assert set(example_result["salt?"]) == {repr(F1()), repr(PVA())} |
20 | | - assert set(example_result) == {"salt?", "saltiness"} |
| 16 | +@pytest.fixture |
| 17 | +def example_holdout_result(example_holdout_result_dict): |
| 18 | + return PredictorEvaluationResult.build(example_holdout_result_dict) |
| 19 | + |
| 20 | + |
| 21 | +def test_indexing(example_cv_result, example_holdout_result): |
| 22 | + assert example_cv_result.responses == {"saltiness", "salt?"} |
| 23 | + assert example_holdout_result.responses == {"sweetness"} |
| 24 | + assert example_cv_result.metrics == {RMSE(), PVA(), F1()} |
| 25 | + assert example_holdout_result.metrics == {RMSE()} |
| 26 | + assert set(example_cv_result["salt?"]) == {repr(F1()), repr(PVA())} |
| 27 | + assert set(example_cv_result) == {"salt?", "saltiness"} |
| 28 | + assert set(example_holdout_result["sweetness"]) == {repr(RMSE())} |
| 29 | + assert set(example_holdout_result) == {"sweetness"} |
| 30 | + |
21 | 31 |
|
| 32 | +def test_cv_serde(example_cv_result, example_cv_result_dict): |
| 33 | + round_trip = PredictorEvaluationResult.build(json.loads(json.dumps(example_cv_result_dict))) |
| 34 | + assert example_cv_result.evaluator == round_trip.evaluator |
22 | 35 |
|
23 | | -def test_serde(example_result, example_result_dict): |
24 | | - round_trip = PredictorEvaluationResult.build(json.loads(json.dumps(example_result_dict))) |
25 | | - assert example_result.evaluator == round_trip.evaluator |
26 | 36 |
|
| 37 | +def test_holdout_serde(example_holdout_result, example_holdout_result_dict): |
| 38 | + round_trip = PredictorEvaluationResult.build(json.loads(json.dumps(example_holdout_result_dict))) |
| 39 | + assert example_holdout_result.evaluator == round_trip.evaluator |
27 | 40 |
|
28 | | -def test_evaluator(example_result, example_evaluator_dict): |
29 | | - args = example_evaluator_dict |
| 41 | +def test_evaluator(example_cv_result, example_cv_evaluator_dict): |
| 42 | + args = example_cv_evaluator_dict |
30 | 43 | del args["type"] |
31 | 44 | expected = CrossValidationEvaluator(**args) |
32 | | - assert example_result.evaluator == expected |
33 | | - assert example_result.evaluator != 0 # make sure eq does something for mismatched classes |
| 45 | + assert example_cv_result.evaluator == expected |
| 46 | + assert example_cv_result.evaluator != 0 # make sure eq does something for mismatched classes |
34 | 47 |
|
35 | 48 |
|
36 | | -def test_check_rmse(example_result, example_rmse_metrics): |
37 | | - assert example_result["saltiness"]["rmse"].mean == example_rmse_metrics["mean"] |
38 | | - assert example_result["saltiness"][RMSE()].standard_error == example_rmse_metrics["standard_error"] |
| 49 | +def test_check_rmse(example_cv_result, example_rmse_metrics): |
| 50 | + assert example_cv_result["saltiness"]["rmse"].mean == example_rmse_metrics["mean"] |
| 51 | + assert example_cv_result["saltiness"][RMSE()].standard_error == example_rmse_metrics["standard_error"] |
39 | 52 | # check eq method does something |
40 | | - assert example_result["saltiness"][RMSE()] != 0 |
| 53 | + assert example_cv_result["saltiness"][RMSE()] != 0 |
41 | 54 | with pytest.raises(TypeError): |
42 | | - foo = example_result["saltiness"][0] |
| 55 | + foo = example_cv_result["saltiness"][0] |
43 | 56 |
|
44 | 57 |
|
45 | | -def test_real_pva(example_result, example_real_pva_metrics): |
| 58 | +def test_real_pva(example_cv_result, example_real_pva_metrics): |
46 | 59 | args = example_real_pva_metrics["value"][0] |
47 | 60 | expected = PredictedVsActualRealPoint.build(args) |
48 | | - assert example_result["saltiness"]["predicted_vs_actual"][0].predicted == expected.predicted |
49 | | - assert next(iter(example_result["saltiness"]["predicted_vs_actual"])).actual == expected.actual |
| 61 | + assert example_cv_result["saltiness"]["predicted_vs_actual"][0].predicted == expected.predicted |
| 62 | + assert next(iter(example_cv_result["saltiness"]["predicted_vs_actual"])).actual == expected.actual |
50 | 63 |
|
51 | 64 |
|
52 | | -def test_categorical_pva(example_result, example_categorical_pva_metrics): |
| 65 | +def test_categorical_pva(example_cv_result, example_categorical_pva_metrics): |
53 | 66 | args = example_categorical_pva_metrics["value"][0] |
54 | 67 | expected = PredictedVsActualCategoricalPoint.build(args) |
55 | | - assert example_result["salt?"]["predicted_vs_actual"][0].predicted == expected.predicted |
56 | | - assert next(iter(example_result["salt?"]["predicted_vs_actual"])).actual == expected.actual |
| 68 | + assert example_cv_result["salt?"]["predicted_vs_actual"][0].predicted == expected.predicted |
| 69 | + assert next(iter(example_cv_result["salt?"]["predicted_vs_actual"])).actual == expected.actual |