Skip to content

Commit 4688f3e

Browse files
authored
Merge pull request #478 from The-Strategy-Unit/add_delivery_episode_in_spell_aggregation
2 parents 8c0b630 + 2595d30 commit 4688f3e

File tree

11 files changed

+190
-136
lines changed

11 files changed

+190
-136
lines changed

src/nhp/model/aae.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -158,28 +158,18 @@ def process_results(data: pd.DataFrame) -> pd.DataFrame:
158158
)
159159
return data
160160

161-
def aggregate(self, model_iteration: ModelIteration) -> tuple[pd.DataFrame, list[list[str]]]:
162-
"""Aggregate the model results.
161+
def specific_aggregations(self, model_results: pd.DataFrame) -> dict[str, pd.Series]:
162+
"""Create other aggregations specific to the model type.
163163
164-
Can also be used to aggregate the baseline data by passing in a `ModelIteration` with
165-
the `model_run` argument set `-1`.
166-
167-
:param model_iteration: an instance of the `ModelIteration` class
168-
:type model_iteration: model.model_iteration.ModelIteration
169-
170-
:returns: a tuple containing the model results, and a list of lists which contain the
171-
aggregations to perform
172-
:rtype: tuple[pd.DataFrame, list[list[str]]]
164+
:param model_results: the results of a model run
165+
:type model_results: pd.DataFrame
166+
:return: dictionary containing the specific aggregations
167+
:rtype: dict[str, pd.Series]
173168
"""
174-
model_results = self.process_results(model_iteration.get_model_results())
175-
176-
return (
177-
model_results,
178-
[
179-
["acuity"],
180-
["attendance_category"],
181-
],
182-
)
169+
return {
170+
"acuity": self.get_agg(model_results, "acuity"),
171+
"attendance_category": self.get_agg(model_results, "attendance_category"),
172+
}
183173

184174
def calculate_avoided_activity(
185175
self, data: pd.DataFrame, data_resampled: pd.DataFrame

src/nhp/model/inpatients.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def process_results(data: pd.DataFrame) -> pd.DataFrame:
207207
"tretspef",
208208
"tretspef_grouped",
209209
"los_group",
210+
"maternity_delivery_in_spell",
210211
],
211212
dropna=False,
212213
)[["admissions", "beddays", "procedures"]]
@@ -241,29 +242,22 @@ def process_results(data: pd.DataFrame) -> pd.DataFrame:
241242

242243
return data
243244

244-
def aggregate(self, model_iteration: ModelIteration) -> tuple[pd.DataFrame, list[list[str]]]:
245-
"""Aggregate the model results.
245+
def specific_aggregations(self, model_results: pd.DataFrame) -> dict[str, pd.Series]:
246+
"""Create other aggregations specific to the model type.
246247
247-
Can also be used to aggregate the baseline data by passing in a `ModelIteration` with
248-
the `model_run` argument set `-1`.
249-
250-
:param model_iteration: an instance of the `ModelIteration` class
251-
:type model_iteration: model.model_iteration.ModelIteration
252-
253-
:returns: a tuple containing the model results, and a list of lists which contain the
254-
aggregations to perform
255-
:rtype: tuple[pd.DataFrame, list[list[str]]]
248+
:param model_results: the results of a model run
249+
:type model_results: pd.DataFrame
250+
:return: dictionary containing the specific aggregations
251+
:rtype: dict[str, pd.Series]
256252
"""
257-
model_results = self.process_results(model_iteration.get_model_results())
258-
259-
return (
260-
model_results,
261-
[
262-
["sex", "tretspef_grouped"],
263-
["tretspef"],
264-
["tretspef", "los_group"],
265-
],
266-
)
253+
return {
254+
"sex+tretspef_grouped": self.get_agg(model_results, "sex", "tretspef_grouped"),
255+
"tretspef": self.get_agg(model_results, "tretspef"),
256+
"tretspef+los_group": self.get_agg(model_results, "tretspef", "los_group"),
257+
"delivery_episode_in_spell": self.get_agg(
258+
model_results[model_results["maternity_delivery_in_spell"]]
259+
),
260+
}
267261

268262
def calculate_avoided_activity(
269263
self, data: pd.DataFrame, data_resampled: pd.DataFrame

src/nhp/model/model.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,46 @@ def apply_resampling(self, row_samples: np.ndarray, data: pd.DataFrame) -> pd.Da
427427
:rtype: pd.DataFrame
428428
"""
429429
raise NotImplementedError()
430+
431+
def aggregate(self, model_iteration: ModelIteration) -> dict[str, pd.Series]:
432+
"""Aggregate the model results.
433+
434+
Can also be used to aggregate the baseline data by passing in a `ModelIteration` with
435+
the `model_run` argument set `-1`.
436+
437+
:param model_iteration: an instance of the `ModelIteration` class
438+
:type model_iteration: model.model_iteration.ModelIteration
439+
440+
:returns: a tuple containing the model results, and a list of lists which contain the
441+
aggregations to perform
442+
:rtype: tuple[pd.DataFrame, list[list[str]]]
443+
"""
444+
model_results = self.process_results(model_iteration.get_model_results())
445+
446+
base_aggregations = {
447+
"default": self.get_agg(model_results),
448+
"sex+age_group": self.get_agg(model_results, "sex", "age_group"),
449+
"age": self.get_agg(model_results, "age"),
450+
}
451+
452+
return {**base_aggregations, **self.specific_aggregations(model_results)}
453+
454+
def process_results(self, data: pd.DataFrame) -> pd.DataFrame:
455+
"""Processes the data into a format suitable for aggregation in results files.
456+
457+
:param data: Data to be processed. Format should be similar to Model.data
458+
:type data: pd.DataFrame
459+
:return: Processed results
460+
:rtype: pd.DataFrame
461+
"""
462+
raise NotImplementedError()
463+
464+
def specific_aggregations(self, model_results: pd.DataFrame) -> dict[str, pd.Series]:
465+
"""Create other aggregations specific to the model type.
466+
467+
:param model_results: the results of a model run
468+
:type model_results: pd.DataFrame
469+
:return: dictionary containing the specific aggregations
470+
:rtype: dict[str, pd.Series]
471+
"""
472+
raise NotImplementedError()

src/nhp/model/model_iteration.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -165,18 +165,15 @@ def get_aggregate_results(self) -> ModelRunResult:
165165
:returns: a tuple containing a dictionary of results, and the step counts
166166
:rtype: tuple[dict[str, pd.Series], pd.Series | None]:
167167
"""
168-
model_results, aggregations = self.model.aggregate(self)
169-
170-
aggs = {
171-
"default" if not v else "+".join(v): self.model.get_agg(model_results, *v)
172-
for v in [[], ["sex", "age_group"], ["age"], *aggregations]
173-
}
168+
aggregations = self.model.aggregate(self)
174169

175170
if not self.avoided_activity.empty:
176171
avoided_activity_agg = self.model.process_results(self.avoided_activity)
177-
aggs["avoided_activity"] = self.model.get_agg(avoided_activity_agg, "sex", "age_group")
172+
aggregations["avoided_activity"] = self.model.get_agg(
173+
avoided_activity_agg, "sex", "age_group"
174+
)
178175

179-
return aggs, self.get_step_counts()
176+
return aggregations, self.get_step_counts()
180177

181178
def get_step_counts(self) -> pd.Series | None:
182179
"""Get the step counts of a model run."""

src/nhp/model/outpatients.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -228,28 +228,18 @@ def process_results(data: pd.DataFrame) -> pd.DataFrame:
228228
)
229229
return data
230230

231-
def aggregate(self, model_iteration: ModelIteration) -> tuple[pd.DataFrame, list[list[str]]]:
232-
"""Aggregate the model results.
231+
def specific_aggregations(self, model_results: pd.DataFrame) -> dict[str, pd.Series]:
232+
"""Create other aggregations specific to the model type.
233233
234-
Can also be used to aggregate the baseline data by passing in a `ModelIteration` with
235-
the `model_run` argument set `-1`.
236-
237-
:param model_iteration: an instance of the `ModelIteration` class
238-
:type model_iteration: model.model_iteration.ModelIteration
239-
240-
:returns: a tuple containing the model results, and a list of lists which contain the
241-
aggregations to perform
242-
:rtype: tuple[pd.DataFrame, list[list[str]]]
234+
:param model_results: the results of a model run
235+
:type model_results: pd.DataFrame
236+
:return: dictionary containing the specific aggregations
237+
:rtype: dict[str, pd.Series]
243238
"""
244-
model_results = self.process_results(model_iteration.get_model_results())
245-
246-
return (
247-
model_results,
248-
[
249-
["sex", "tretspef_grouped"],
250-
["tretspef"],
251-
],
252-
)
239+
return {
240+
"sex+tretspef_grouped": self.get_agg(model_results, "sex", "tretspef_grouped"),
241+
"tretspef": self.get_agg(model_results, "tretspef"),
242+
}
253243

254244
def save_results(self, model_iteration: ModelIteration, path_fn: Callable[[str], str]) -> None:
255245
"""Save the results of running the model.

tests/integration/nhp/model/test_run_model.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,7 @@
1313
[
1414
(
1515
InpatientsModel,
16-
{
17-
"sex+tretspef_grouped",
18-
"tretspef",
19-
"tretspef+los_group",
20-
},
16+
{"sex+tretspef_grouped", "tretspef", "tretspef+los_group", "delivery_episode_in_spell"},
2117
),
2218
(
2319
OutpatientsModel,
@@ -79,6 +75,7 @@ def test_all_model_runs(params_path, data_dir):
7975
"attendance_category",
8076
"avoided_activity",
8177
"default",
78+
"delivery_episode_in_spell",
8279
"sex+age_group",
8380
"sex+tretspef_grouped",
8481
"tretspef",

tests/unit/nhp/model/test_aae.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -157,30 +157,25 @@ def test_efficiencies(mock_model):
157157
assert actual == ("data", None)
158158

159159

160-
def test_aggregate(mock_model):
160+
def test_specific_aggregations(mocker, mock_model):
161161
"""Test that it aggregates the results correctly."""
162-
163162
# arrange
164-
def create_agg_stub(model_results, cols=None):
165-
name = "+".join(cols) if cols else "default"
166-
return {name: model_results.to_dict(orient="list")}
163+
m = mocker.patch("nhp.model.AaEModel.get_agg", return_value="agg_data")
167164

168165
mdl = mock_model
169-
mdl._create_agg = Mock(wraps=create_agg_stub)
170-
mdl.process_results = Mock(return_value="processed_data")
171-
172-
mr_mock = Mock()
173-
mr_mock.get_model_results.return_value = "model_results"
174166

175167
# act
176-
actual_mr, actual_aggs = mdl.aggregate(mr_mock)
168+
actual = mdl.specific_aggregations("results") # type: ignore
177169

178170
# assert
179-
mdl.process_results.assert_called_once_with("model_results")
180-
assert actual_mr == "processed_data"
181-
assert actual_aggs == [
182-
["acuity"],
183-
["attendance_category"],
171+
assert actual == {
172+
"acuity": "agg_data",
173+
"attendance_category": "agg_data",
174+
}
175+
176+
assert m.call_args_list == [
177+
call("results", "acuity"),
178+
call("results", "attendance_category"),
184179
]
185180

186181

tests/unit/nhp/model/test_inpatients.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import pandas as pd
77
import pytest
8+
from pandas.testing import assert_frame_equal
89

910
from nhp.model.inpatients import InpatientsModel
1011

@@ -239,6 +240,7 @@ def test_process_results(mock_model):
239240
"rn": [1] * 12,
240241
"has_procedure": [0, 1] * 6,
241242
"speldur": list(range(12)),
243+
"maternity_delivery_in_spell": [True, False] * 6,
242244
}
243245
)
244246
df["pod"] = "ip_" + df["group"] + "_admission"
@@ -431,6 +433,21 @@ def test_process_results(mock_model):
431433
"8-14 days",
432434
np.nan,
433435
],
436+
"maternity_delivery_in_spell": [
437+
True,
438+
True,
439+
False,
440+
False,
441+
False,
442+
True,
443+
False,
444+
False,
445+
False,
446+
True,
447+
True,
448+
False,
449+
]
450+
* 2,
434451
"measure": [
435452
"admissions",
436453
"beddays",
@@ -491,32 +508,31 @@ def test_process_results(mock_model):
491508
pd.testing.assert_frame_equal(actual, expected)
492509

493510

494-
def test_aggregate(mock_model):
511+
def test_specific_aggregations(mocker, mock_model):
495512
"""Test that it aggregates the results correctly."""
496-
497513
# arrange
498-
def create_agg_stub(model_results, cols=None):
499-
name = "+".join(cols) if cols else "default"
500-
return {name: model_results.to_dict(orient="list")}
514+
m = mocker.patch("nhp.model.InpatientsModel.get_agg", return_value="agg_data")
501515

502516
mdl = mock_model
503-
mdl._create_agg = Mock(wraps=create_agg_stub)
504-
mdl.process_results = Mock(return_value="processed_data")
505517

506-
mr_mock = Mock()
507-
mr_mock.get_model_results.return_value = "nhp.model_data"
518+
mock_data = pd.DataFrame({"maternity_delivery_in_spell": [True, False], "value": [1, 2]})
508519

509520
# act
510-
actual_mr, actual_aggs = mdl.aggregate(mr_mock)
521+
actual = mdl.specific_aggregations(mock_data)
511522

512523
# assert
524+
assert actual == {
525+
"sex+tretspef_grouped": "agg_data",
526+
"tretspef": "agg_data",
527+
"tretspef+los_group": "agg_data",
528+
"delivery_episode_in_spell": "agg_data",
529+
}
513530

514-
mdl.process_results.assert_called_once_with("nhp.model_data")
515-
assert actual_mr == "processed_data"
516-
assert actual_aggs == [
517-
["sex", "tretspef_grouped"],
518-
["tretspef"],
519-
["tretspef", "los_group"],
531+
assert [(len(i[0][0]), *i[0][1:]) for i in m.call_args_list] == [
532+
(2, "sex", "tretspef_grouped"),
533+
(2, "tretspef"),
534+
(2, "tretspef", "los_group"),
535+
(1,),
520536
]
521537

522538

tests/unit/nhp/model/test_model.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,3 +775,49 @@ def test_apply_resampling(mock_model):
775775
# act & assert
776776
with pytest.raises(NotImplementedError):
777777
mock_model.apply_resampling(None, None)
778+
779+
780+
def test_aggregate(mock_model):
781+
# arrange
782+
mdl = mock_model
783+
mdl.process_results = Mock(return_value="processed_results")
784+
mdl.get_agg = Mock(return_value="agg")
785+
mdl.specific_aggregations = Mock(return_value={"1": "agg", "2": "agg"})
786+
787+
mi_mock = Mock()
788+
mi_mock.get_model_results.return_value = "results"
789+
790+
# act
791+
actual = mdl.aggregate(mi_mock)
792+
793+
# assert
794+
mi_mock.get_model_results.assert_called()
795+
mdl.process_results.assert_called_once_with("results")
796+
assert mdl.get_agg.call_args_list == [
797+
call("processed_results"),
798+
call("processed_results", "sex", "age_group"),
799+
call("processed_results", "age"),
800+
]
801+
mdl.specific_aggregations.assert_called_once_with("processed_results")
802+
803+
assert actual == {
804+
"default": "agg",
805+
"sex+age_group": "agg",
806+
"age": "agg",
807+
"1": "agg",
808+
"2": "agg",
809+
}
810+
811+
812+
def test_process_results(mock_model):
813+
# arrange
814+
# act & assert
815+
with pytest.raises(NotImplementedError):
816+
mock_model.process_results(None)
817+
818+
819+
def test_specific_aggregations(mock_model):
820+
# arrange
821+
# act & assert
822+
with pytest.raises(NotImplementedError):
823+
mock_model.specific_aggregations(None)

0 commit comments

Comments
 (0)