Skip to content

Commit 9081625

Browse files
authored
Merge pull request #57 from The-Strategy-Unit/add_results_tracking
editing trial results to align with target trial output. Closes #50.
2 parents 999aded + 7ed784a commit 9081625

File tree

3 files changed

+149
-90
lines changed

3 files changed

+149
-90
lines changed

renal_capacity_model/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ class Config:
1212
def __init__(self, config_dict={}):
1313
self.trace = config_dict.get("trace", False)
1414
self.number_of_runs = config_dict.get("number_of_runs", 10)
15-
self.sim_duration = config_dict.get("sim_duration", 1000)
15+
self.sim_duration = config_dict.get("sim_duration", int(2*365)) # in days, but should be a multiple of 365 i.e. years
1616
self.random_seed = config_dict.get("random_seed", 0)
1717
self.arrival_rate = config_dict.get("arrival_rate", 1)
18+
self.snapshot_interval = config_dict.get("snapshot_interval", int(365)) # how often to take a snapshot of the results_df
1819

1920
# distributions for calculating interarrival times
2021
self.age_dist = config_dict.get(

renal_capacity_model/model.py

Lines changed: 79 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,11 @@ def __init__(self, run_number, rng, config):
3131
self.inter_arrival_times = get_interarrival_times(self.config)
3232
self.patients_in_system = {k: 0 for k in self.inter_arrival_times.keys()}
3333
self.results_df = self._setup_results_df()
34+
self.snapshot_results_df = self._setup_snapshot_df()
35+
self.snapshot_interval = self.config.snapshot_interval # how often to take a snapshot of the results_df
3436

3537
def _setup_results_df(self):
38+
3639
"""Sets up DataFrame for recording model results
3740
3841
Returns:
@@ -45,21 +48,63 @@ def _setup_results_df(self):
4548
"entry_time",
4649
"diverted_to_con_care",
4750
"suitable_for_transplant",
48-
"live_transplant_count", ## should this instead count number of live transplants per patient?
49-
"cadaver_transplant_count", ## should this instead count number of cadaver transplants per patient?
51+
"live_transplant_count",
52+
"cadaver_transplant_count",
5053
"pre_emptive_transplant",
5154
"transplant_count",
52-
"ichd_dialysis_count", ## this is what we'll use to track the number in ichd over time
55+
"ichd_dialysis_count",
5356
"hhd_dialysis_count",
5457
"pd_dialysis_count",
5558
"time_of_death",
59+
"death_from_con_care",
60+
"death_from_ichd",
61+
"death_from_hhd",
62+
"death_from_pd",
63+
"death_post_live_transplant",
64+
"death_post_cadaver_transplant",
5665
]
5766
)
5867
results_df["patient ID"] = [1]
5968
results_df.set_index("patient ID", inplace=True)
6069

6170
return results_df
6271

72+
def _setup_snapshot_df(self):
73+
74+
"""Sets up DataFrame for recording snapshot model results
75+
76+
Returns:
77+
pd.DataFrame: Empty DataFrame for recording model results
78+
"""
79+
snapshot_results_df = pd.DataFrame(
80+
columns=[
81+
"snapshot_time",
82+
"age_group",
83+
"referral_type",
84+
"entry_time",
85+
"diverted_to_con_care",
86+
"suitable_for_transplant",
87+
"live_transplant_count",
88+
"cadaver_transplant_count",
89+
"pre_emptive_transplant",
90+
"transplant_count",
91+
"ichd_dialysis_count",
92+
"hhd_dialysis_count",
93+
"pd_dialysis_count",
94+
"time_of_death",
95+
"death_from_con_care",
96+
"death_from_ichd",
97+
"death_from_hhd",
98+
"death_from_pd",
99+
"death_post_live_transplant",
100+
"death_post_cadaver_transplant",
101+
]
102+
)
103+
snapshot_results_df["patient ID"] = [1]
104+
snapshot_results_df.set_index("patient ID", inplace=True)
105+
106+
return snapshot_results_df
107+
63108
def generator_patient_arrivals(self, patient_type):
64109
"""Generator function for arriving patients
65110
@@ -73,11 +118,21 @@ def generator_patient_arrivals(self, patient_type):
73118
self.patient_counter += 1
74119

75120
p = Patient(self.patient_counter, patient_type)
76-
start_time_in_system_patient = self.rng.exponential(
77-
1 / self.inter_arrival_times[patient_type]
78-
) # self.env.now
79-
p.last_dialysis_modality = "none"
80-
p.transplant_count = 0
121+
122+
123+
if self.patient_counter <= 12:
124+
start_time_in_system_patient = self.rng.exponential(
125+
1 / self.inter_arrival_times[patient_type]
126+
)
127+
yield self.env.timeout(start_time_in_system_patient)
128+
else:
129+
start_time_in_system_patient = self.env.now
130+
131+
self.patients_in_system[patient_type] += 1
132+
133+
p.last_dialysis_modality = "none"
134+
p.transplant_count = 0
135+
81136
self.results_df.loc[p.id, "entry_time"] = start_time_in_system_patient
82137
self.results_df.loc[p.id, "age_group"] = int(p.age_group)
83138
self.results_df.loc[p.id, "referral_type"] = p.referral_type
@@ -90,19 +145,19 @@ def generator_patient_arrivals(self, patient_type):
90145

91146
if self.rng.uniform(0, 1) > self.config.con_care_dist[p.age_group]:
92147
# If the patient is not diverted to conservative care they start KRT
93-
self.patients_in_system[patient_type] += 1
94148
self.env.process(self.start_krt(p))
95149
else:
96150
# these patients are diverted to conservative care. We don't need a process here as all these patients do is wait a while before leaving the system
97151
self.results_df.loc[p.id, "diverted_to_con_care"] = True
98-
yield self.env.timeout(start_time_in_system_patient)
99152
sampled_con_care_time = (
100153
self.config.ttd_con_care_scale
101154
* self.rng.weibull(a=self.config.ttd_con_care_shape, size=1)
102155
)
103156
yield self.env.timeout(sampled_con_care_time)
104157
self.results_df.loc[p.id, "time_of_death"] = self.env.now
105158
self.patients_in_system[patient_type] -= 1
159+
self.results_df.loc[p.id, "diverted_to_con_care"] = False # as they've left conservative care
160+
self.results_df.loc[p.id, "death_from_con_care"] = True
106161
if self.config.trace:
107162
print(
108163
f"Patient {p.id} of age group {p.age_group} diverted to conservative care and left the system after {sampled_con_care_time} time units."
@@ -302,6 +357,7 @@ def start_transplant(self, patient):
302357
self.results_df.loc[patient.id, "live_transplant_count"] -= 1
303358
self.patients_in_system[patient.patient_type] -= 1
304359
self.results_df.loc[patient.id, "time_of_death"] = self.env.now
360+
self.results_df.loc[patient.id, "death_post_live_transplant"] = True
305361
if self.config.trace:
306362
print(
307363
f"Patient {patient.id} of age group {patient.age_group} died after live transplant at time {self.env.now}."
@@ -337,6 +393,7 @@ def start_transplant(self, patient):
337393
self.results_df.loc[patient.id, "cadaver_transplant_count"] -= 1
338394
self.patients_in_system[patient.patient_type] -= 1
339395
self.results_df.loc[patient.id, "time_of_death"] = self.env.now
396+
self.results_df.loc[patient.id, "death_post_cadaver_transplant"] = True
340397
if self.config.trace:
341398
print(
342399
f"Patient {patient.id} of age group {patient.age_group} died after cadaver transplant at time {self.env.now}."
@@ -455,6 +512,7 @@ def start_ichd(self, patient):
455512
self.patients_in_system[patient.patient_type] -= 1
456513
self.results_df.loc[patient.id, "ichd_dialysis_count"] -= 1
457514
self.results_df.loc[patient.id, "time_of_death"] = self.env.now
515+
self.results_df.loc[patient.id, "death_from_ichd"] = True
458516
if self.config.trace:
459517
print(
460518
f"Patient {patient.id} of age group {patient.age_group} died and left the system at time {self.env.now}."
@@ -466,6 +524,7 @@ def start_ichd(self, patient):
466524
patient.time_on_ichd_dialysis = sampled_ichd_time
467525
self.patients_in_system[patient.patient_type] -= 1
468526
self.results_df.loc[patient.id, "ichd_dialysis_count"] -= 1
527+
self.results_df.loc[patient.id, "death_from_ichd"] = True
469528
self.results_df.loc[patient.id, "time_of_death"] = self.env.now
470529
if self.config.trace:
471530
print(
@@ -546,26 +605,31 @@ def start_pd(self, patient):
546605
yield self.env.timeout(5)
547606
patient.last_dialysis_modality = "pd"
548607

549-
def calculate_run_results(self):
550-
# TODO: what do we want to count?
551-
pass
608+
def snapshot_results(self):
    """Generator process that periodically copies results_df into snapshot_results_df

    Each pass appends a copy of the current ``results_df`` (stamped with
    the current ``env.now`` in a ``snapshot_time`` column) to
    ``snapshot_results_df``, then yields a timeout of ``snapshot_interval``.

    Yields:
        The event returned by ``self.env.timeout(self.snapshot_interval)``.

    Raises:
        ValueError: if ``snapshot_interval`` is not positive. A zero
            interval would otherwise loop forever without simulated time
            advancing.
    """
    if self.snapshot_interval <= 0:
        raise ValueError(
            f"snapshot_interval must be positive, got {self.snapshot_interval}"
        )

    while True:
        # Read the clock once so the stamp and the trace message agree.
        snapshot_time = self.env.now
        stamped_results = self.results_df.assign(snapshot_time=snapshot_time)
        self.snapshot_results_df = pd.concat(
            [self.snapshot_results_df, stamped_results]
        )
        if self.config.trace:
            print(f"Taking results snapshot of the results_df at time {snapshot_time}")
        yield self.env.timeout(self.snapshot_interval)
552614

553615
def run(self):
    """Runs the model

    Registers one arrival process per patient type and one snapshot
    process with the environment, runs the environment until
    ``config.sim_duration``, then prints the results if tracing is on.
    """
    # One arrival generator for each patient type we have an IAT for
    for patient_type in self.inter_arrival_times:
        arrivals = self.generator_patient_arrivals(patient_type)
        self.env.process(arrivals)

    # Periodic recorder of results_df
    self.env.process(self.snapshot_results())

    self.env.run(until=self.config.sim_duration)

    # Show results (optional - set in config)
    if not self.config.trace:
        return
    print(f"Run Number {self.run_number}")
    print(self.patients_in_system)
    print(self.results_df)
    print(self.snapshot_results_df)
569633

570634

571635
if __name__ == "__main__":

renal_capacity_model/trial.py

Lines changed: 68 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -21,90 +21,84 @@ def __init__(self, config):
2121

2222
def print_trial_results(self):
    """Prints the mean trial results as a measure-by-snapshot-time table

    Column names in ``df_trial_results`` are expected to look like
    ``<measure>_<snapshot time>``. The across-run mean of each column is
    reshaped so rows are measures and columns are snapshot times. The
    table is printed, followed by its column-to-column differences.
    """
    print("Trial Results")
    output_means = self.df_trial_results.mean().to_frame()

    # Split each column name once, from the right: "<measure>_<time>".
    name_parts = output_means.index.str.rsplit("_", n=1)
    time_suffixes = list(name_parts.str[-1])
    try:
        # Numeric labels make pivot order the columns chronologically.
        # As strings, "1095" sorts before "365" and diff(axis=1) would
        # subtract non-adjacent snapshots.
        time_labels = pd.to_numeric(time_suffixes)
    except (ValueError, TypeError):
        # Non-numeric suffixes: keep the plain string labels.
        time_labels = time_suffixes

    output_means["Time"] = time_labels
    output_means.index = name_parts.str[0]
    reshaped_trial_results = output_means.pivot(columns="Time", values=0)
    print(reshaped_trial_results)
    # Change between consecutive snapshots. Could be used for plotting
    # mortality over time instead of cumulative mortality.
    print(reshaped_trial_results.diff(axis=1))
2530

2631
def setup_trial_results(self):
    """Creates the DataFrame that collects one row of results per run

    Returns:
        pd.DataFrame: frame with no result columns yet, indexed by
        "Run Number" and seeded with a row for run 0.
    """
    trial_frame = pd.DataFrame()
    trial_frame["Run Number"] = [0]
    return trial_frame.set_index("Run Number")
36+
37+
def process_model_results(self, model, run):
    """Summarises a model's end-of-run results_df into the trial results row

    Args:
        model: object exposing ``results_df``, the per-patient results frame.
        run: index label of the row in ``df_trial_results`` to fill.
    """
    # (trial results column, results_df column, Series reducer)
    summary_spec = (
        ("total_entries", "entry_time", "count"),
        ("prevalence_con_care", "diverted_to_con_care", "sum"),
        ("prevalence_ichd", "ichd_dialysis_count", "sum"),
        ("prevalence_hhd", "hhd_dialysis_count", "sum"),
        ("prevalence_pd", "pd_dialysis_count", "sum"),
        ("prevalence_live_Tx", "live_transplant_count", "sum"),
        ("prevalence_cadaver_Tx", "cadaver_transplant_count", "sum"),
        ("total_deaths", "time_of_death", "count"),
        ("mortality_con_care", "death_from_con_care", "sum"),
        ("mortality_ichd", "death_from_ichd", "sum"),
        ("mortality_hhd", "death_from_hhd", "sum"),
        ("mortality_pd", "death_from_pd", "sum"),
        ("mortality_live_Tx", "death_post_live_transplant", "sum"),
        ("mortality_cadaver_Tx", "death_post_cadaver_transplant", "sum"),
    )
    for target_column, source_column, reducer in summary_spec:
        source = model.results_df[source_column]
        self.df_trial_results.loc[run, target_column] = getattr(source, reducer)()
54+
55+
def process_snapshot_results(self, model, run):
    """Summarises a model's snapshots into per-snapshot-time trial columns

    Rows of ``model.snapshot_results_df`` are grouped by ``snapshot_time``
    so prevalence and mortality can be followed over time. For every
    snapshot time ``t`` the aggregated values are written to
    ``df_trial_results`` as ``<measure>_<t>`` columns on row ``run``.

    Args:
        model: object exposing ``snapshot_results_df``.
        run: index label of the row in ``df_trial_results`` to fill.
    """
    # (snapshot column, aggregation, trial results column prefix).
    # Every mortality flag is summed. The two post-transplant flags used
    # "count", which would treat any non-null value (including False) as
    # a death and disagreed with the other mortality columns.
    summary_spec = (
        ("entry_time", "count", "total_entries"),
        ("diverted_to_con_care", "sum", "prevalence_con_care"),
        ("ichd_dialysis_count", "sum", "prevalence_ichd"),
        ("hhd_dialysis_count", "sum", "prevalence_hhd"),
        ("pd_dialysis_count", "sum", "prevalence_pd"),
        ("live_transplant_count", "sum", "prevalence_live_Tx"),
        ("cadaver_transplant_count", "sum", "prevalence_cadaver_Tx"),
        ("time_of_death", "count", "total_deaths"),
        ("death_from_con_care", "sum", "mortality_con_care"),
        ("death_from_ichd", "sum", "mortality_ichd"),
        ("death_from_hhd", "sum", "mortality_hhd"),
        ("death_from_pd", "sum", "mortality_pd"),
        ("death_post_live_transplant", "sum", "mortality_live_Tx"),
        ("death_post_cadaver_transplant", "sum", "mortality_cadaver_Tx"),
    )

    results_grouped_by_time = model.snapshot_results_df.groupby("snapshot_time").agg(
        {source: how for source, how, _ in summary_spec}
    )

    for snapshot_time in results_grouped_by_time.index:
        for source, _, prefix in summary_spec:
            self.df_trial_results.loc[run, f"{prefix}_{snapshot_time}"] = (
                results_grouped_by_time.loc[snapshot_time, source]
            )
3195

3296
def run_trial(self):
    """Runs every configured model replication, then prints the trial summary

    For each run a new ``Model`` is built from the shared ``rng`` and
    ``config`` and run. One more snapshot labelled with
    ``config.sim_duration`` is appended so the end-of-run state is part
    of the snapshots, and the snapshots are folded into
    ``df_trial_results``.
    """
    for run_number in range(self.config.number_of_runs):
        model = Model(run_number, self.rng, self.config)
        model.run()

        closing_snapshot = model.results_df.assign(
            snapshot_time=model.config.sim_duration
        )
        model.snapshot_results_df = pd.concat(
            [model.snapshot_results_df, closing_snapshot]
        )
        self.process_snapshot_results(model, run_number)

    self.print_trial_results()

0 commit comments

Comments
 (0)