-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtest_run_suite.py
More file actions
89 lines (80 loc) · 3.09 KB
/
test_run_suite.py
File metadata and controls
89 lines (80 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Tests whether the main function of the package runs without errors and produces the expected output."""
import os
import tempfile
from argparse import Namespace
import pytest
from drevalpy.utils import main
from drevalpy.visualization.utils import parse_results, prep_results
@pytest.mark.parametrize(
    "args",
    [
        {
            "run_id": "test_run",
            "dataset_name": "TOYv1",
            "models": ["NaiveCellLineMeanPredictor"],
            "baselines": ["NaiveDrugMeanPredictor"],
            "test_mode": ["LPO"],
            "randomization_mode": ["SVRC"],
            "randomization_type": "permutation",
            "n_trials_robustness": 2,
            "cross_study_datasets": ["GDSC1"],
            "curve_curator": False,
            "curve_curator_cores": 1,
            "measure": "LN_IC50",
            "overwrite": False,
            "optim_metric": "RMSE",
            "n_cv_splits": 2,
            "response_transformation": "None",
            "multiprocessing": False,
            "path_data": "../data",
            "model_checkpoint_dir": "TEMPORARY",
        }
    ],
)
def test_run_suite(args):
    """
    Tests run_suite.py, i.e., all functionality of the main package.

    Runs the full pipeline end-to-end into a temporary output directory,
    then parses and preps the results and asserts that the expected
    models, baselines, randomization/robustness settings, test modes,
    and CV splits are present in the evaluation tables.

    :param args: dict of arguments for the main function; converted to a
        Namespace before being passed to main
    """
    # Use TemporaryDirectory as a context manager so the output directory is
    # always removed, even when an assertion fails. (Creating it without
    # `with` relies on GC finalization, which emits a ResourceWarning and
    # may leave the directory behind.)
    with tempfile.TemporaryDirectory() as temp_dir:
        args["path_out"] = temp_dir
        args = Namespace(**args)
        main(args)
        # The pipeline must create exactly one directory, named after run_id.
        assert os.listdir(temp_dir) == [args.run_id]
        (
            evaluation_results,
            evaluation_results_per_drug,
            evaluation_results_per_cell_line,
            true_vs_pred,
        ) = parse_results(
            path_to_results=os.path.join(temp_dir, args.run_id),
            # Use the parametrized dataset name instead of a hard-coded
            # literal so the assertion tracks the test parameters.
            dataset=args.dataset_name,
        )
        (
            evaluation_results,
            evaluation_results_per_drug,
            evaluation_results_per_cell_line,
            true_vs_pred,
        ) = prep_results(
            evaluation_results,
            evaluation_results_per_drug,
            evaluation_results_per_cell_line,
            true_vs_pred,
        )
        # Expected column counts of the prepped result tables.
        assert len(evaluation_results.columns) == 22
        assert len(evaluation_results_per_drug.columns) == 15
        assert len(evaluation_results_per_cell_line.columns) == 15
        assert len(true_vs_pred.columns) == 12
        # Every requested model and baseline must appear in the results.
        assert all(model in evaluation_results.algorithm.unique() for model in args.models)
        assert all(baseline in evaluation_results.algorithm.unique() for baseline in args.baselines)
        assert "predictions" in evaluation_results.rand_setting.unique()
        # Each requested randomization mode produces a "randomize-<mode>..." setting.
        if args.randomization_mode:
            for rand_setting in args.randomization_mode:
                assert any(
                    setting.startswith(f"randomize-{rand_setting}")
                    for setting in evaluation_results.rand_setting.unique()
                )
        # Robustness trials produce a "robustness-<n>..." setting.
        if args.n_trials_robustness > 0:
            assert any(
                setting.startswith(f"robustness-{args.n_trials_robustness}")
                for setting in evaluation_results.rand_setting.unique()
            )
        assert all(test_mode in evaluation_results.LPO_LCO_LDO.unique() for test_mode in args.test_mode)
        # CV splits are 0-indexed, so the maximum split index is n_cv_splits - 1.
        assert evaluation_results.CV_split.astype(int).max() == (args.n_cv_splits - 1)
        # Sanity check: predictions should correlate reasonably with truth.
        assert evaluation_results.Pearson.astype(float).max() > 0.5