Skip to content

Commit b2e5c25

Browse files
authored
Merge pull request #482 from The-Strategy-Unit/alter_run_all
allows run_all to use different data sources
2 parents 4688f3e + f8d2837 commit b2e5c25

File tree

7 files changed

+36
-25
lines changed

7 files changed

+36
-25
lines changed

src/nhp/docker/__main__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import argparse
44
import logging
5-
import os
65

76
from nhp.docker.config import Config
87
from nhp.docker.run import RunWithAzureStorage, RunWithLocalStorage
8+
from nhp.model.data import Local
99
from nhp.model.run import run_all
1010

1111

@@ -54,7 +54,10 @@ def main(config: Config = Config()):
5454
logging.info("app_version: %s", runner.params["app_version"])
5555

5656
saved_files, results_file = run_all(
57-
runner.params, "data", runner.progress_callback(), args.save_full_model_results
57+
runner.params,
58+
Local.create("data"),
59+
runner.progress_callback(),
60+
args.save_full_model_results,
5861
)
5962

6063
runner.finish(results_file, saved_files, args.save_full_model_results)

src/nhp/model/__main__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717

1818
from nhp.model.aae import AaEModel
19+
from nhp.model.data import Local
1920
from nhp.model.helpers import load_params
2021
from nhp.model.inpatients import InpatientsModel
2122
from nhp.model.outpatients import OutpatientsModel
@@ -65,7 +66,7 @@ def main() -> None:
6566

6667
run_all(
6768
params,
68-
args.data_path,
69+
Local.create(args.data_path),
6970
lambda _: lambda _: None,
7071
args.save_full_model_results,
7172
)

src/nhp/model/run.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def noop_progress_callback(_: Any) -> Callable[[Any], None]:
111111

112112
def run_all(
113113
params: dict,
114-
data_path: str,
114+
nhp_data: Callable[[int, str], Data],
115115
progress_callback: Callable[[Any], Callable[[Any], None]] = noop_progress_callback,
116116
save_full_model_results: bool = False,
117117
) -> Tuple[list, str]:
@@ -121,8 +121,8 @@ def run_all(
121121
122122
:param params: the parameters to use for this model run
123123
:type params: dict
124-
:param data_path: where the data is stored
125-
:type data_path: str
124+
:param nhp_data: the Data class to use for loading data
125+
:type nhp_data: Callable[[int, str], Data]
126126
:param progress_callback: a callback function for updating progress.
127127
Defaults to noop_progress_callback.
128128
:type progress_callback: Callable[[str], Callable[[Any], Any]]
@@ -134,8 +134,6 @@ def run_all(
134134
model_types = [InpatientsModel, OutpatientsModel, AaEModel]
135135
run_params = Model.generate_run_params(params)
136136

137-
nhp_data = Local.create(data_path)
138-
139137
# set the data path in the HealthStatusAdjustment class
140138
hsa = HealthStatusAdjustmentInterpolated(
141139
nhp_data(params["start_year"], params["dataset"]), params["start_year"]

tests/integration/nhp/model/test_run_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ def test_all_model_runs(params_path, data_dir):
6161
params = load_params(params_path)
6262
params["model_runs"] = 4
6363

64+
nhp_data = Local.create(data_dir)
6465
# act
65-
actual = run_all(params, data_dir)
66+
actual = run_all(params, nhp_data)
6667

6768
# assert
6869
res_path = "results/synthetic/test/20220101_000000"

tests/unit/nhp/docker/test___main__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ def test_main_local(mocker):
4646
rwls = mocker.patch("nhp.docker.__main__.RunWithLocalStorage")
4747
rwas = mocker.patch("nhp.docker.__main__.RunWithAzureStorage")
4848

49+
local_data_mock = mocker.patch("nhp.docker.__main__.Local")
50+
local_data_mock.create.return_value = "data"
51+
4952
params = {
5053
"model_runs": 256,
5154
"start_year": 2019,
@@ -71,6 +74,8 @@ def test_main_local(mocker):
7174
ru_m.assert_called_once_with(params, "data", s.progress_callback(), False)
7275
s.finish.assert_called_once_with("results.json", "list_of_results", False)
7376

77+
local_data_mock.create.assert_called_once_with("data")
78+
7479

7580
def test_main_azure(mocker):
7681
# arrange
@@ -82,6 +87,9 @@ def test_main_azure(mocker):
8287
rwls = mocker.patch("nhp.docker.__main__.RunWithLocalStorage")
8388
rwas = mocker.patch("nhp.docker.__main__.RunWithAzureStorage")
8489

90+
local_data_mock = mocker.patch("nhp.docker.__main__.Local")
91+
local_data_mock.create.return_value = "data"
92+
8593
config = Mock()
8694
config.APP_VERSION = "dev"
8795
config.DATA_VERSION = "dev"
@@ -112,6 +120,8 @@ def test_main_azure(mocker):
112120
ru_m.assert_called_once_with(params, "data", s.progress_callback(), False)
113121
s.finish.assert_called_once_with("results.json", "list_of_results", False)
114122

123+
local_data_mock.create.assert_called_once_with("data")
124+
115125

116126
def test_init(mocker):
117127
"""It should run the main method if __name__ is __main__."""

tests/unit/nhp/model/test__main__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def test_main_all_runs(mocker):
6868
args.save_full_model_results = False
6969
mocker.patch("nhp.model.__main__._parse_args", return_value=args)
7070
ldp_mock = mocker.patch("nhp.model.__main__.load_params", return_value="params")
71+
local_data_mock = mocker.patch("nhp.model.__main__.Local")
72+
local_data_mock.create.return_value = "data"
7173

7274
run_all_mock = mocker.patch("nhp.model.__main__.run_all")
7375
run_single_mock = mocker.patch("nhp.model.__main__.run_single_model_run")
@@ -83,6 +85,7 @@ def test_main_all_runs(mocker):
8385

8486
run_single_mock.assert_not_called()
8587
ldp_mock.assert_called_once_with("queue/params.json")
88+
local_data_mock.create.assert_called_once_with("data")
8689

8790

8891
def test_init(mocker):

tests/unit/nhp/model/test_run.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ def test_run_all(mocker):
8787
)
8888
gr_m = mocker.patch("nhp.model.run.generate_results_json", return_value="results_json_path")
8989
sr_m = mocker.patch("nhp.model.run.save_results_files", return_value="results_paths")
90-
nd_m = mocker.patch("nhp.model.run.Local")
9190

9291
pc_m = Mock()
9392
pc_m().return_value = "progress callback"
@@ -102,16 +101,15 @@ def test_run_all(mocker):
102101
"model_runs": 10,
103102
"create_datetime": "20230123_012345",
104103
}
104+
data_mock = Mock(return_value="nhp_data")
105105

106106
# act
107-
actual = run_all(params, "data_path", pc_m, False)
107+
actual = run_all(params, data_mock, pc_m, False)
108108

109109
# assert
110110
assert actual == ("results_paths", "results_json_path")
111111

112-
nd_m.create.assert_called_once_with("data_path")
113-
nd_c = nd_m.create()
114-
nd_c.assert_called_once_with(2020, "synthetic")
112+
data_mock.assert_called_once_with(2020, "synthetic")
115113

116114
assert pc_m.call_args_list == [
117115
call("Inpatients"),
@@ -120,13 +118,13 @@ def test_run_all(mocker):
120118
]
121119

122120
grp_m.assert_called_once_with(params)
123-
hsa_m.assert_called_once_with(nd_c(2020, "synthetic"), 2020)
121+
hsa_m.assert_called_once_with("nhp_data", 2020)
124122

125123
assert rm_m.call_args_list == [
126124
call(
127125
m,
128126
params,
129-
nd_c,
127+
data_mock,
130128
"hsa",
131129
{"variant": "variants"},
132130
pc_m(),
@@ -136,15 +134,12 @@ def test_run_all(mocker):
136134
]
137135

138136
cr_m.assert_called_once_with(["ip", "op", "aae"])
139-
# weaker form of checking, but as we intended to drop this function in the future don't expend
140-
# effort to fixing this part of the test
141-
gr_m.assert_called_once()
142-
# gr_m.assert_called_once_with(
143-
# {"default": "combined_results"},
144-
# "combined_step_counts",
145-
# params,
146-
# {"variant": "variants"},
147-
# )
137+
gr_m.assert_called_once_with(
138+
{"default": "combined_results", "step_counts": "combined_step_counts"},
139+
"combined_step_counts",
140+
params,
141+
{"variant": "variants"},
142+
)
148143
sr_m.assert_called_once_with(
149144
{"default": "combined_results", "step_counts": "combined_step_counts"}, params
150145
)

0 commit comments

Comments
 (0)