Commit 9b5f075

Vectorised preprocessing
committed
1 parent 72724eb commit 9b5f075

File tree

1 file changed (+146, -64)

causal_testing/estimation/ipcw_estimator.py

Lines changed: 146 additions & 64 deletions
@@ -1,9 +1,10 @@
 """This module contains the IPCWEstimator class, for estimating the time to a particular event"""
 
 import logging
-from math import ceil
+from numpy import ceil
 from typing import Any
 from tqdm import tqdm
+from uuid import uuid4
 
 import numpy as np
 import pandas as pd
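
The swap of math.ceil for numpy.ceil matters in the later hunks, where ceil is applied to a whole fault_time column rather than a scalar: numpy.ceil is a ufunc that maps elementwise over a Series. A minimal sketch of that elementwise rounding, using made-up numbers (fault_time and timesteps_per_observation here are stand-ins for the estimator's attributes):

```python
import numpy as np
import pandas as pd

# Stand-in values, just to illustrate the elementwise behaviour the diff relies on.
fault_time = pd.Series([3.0, 7.0, 12.0])
timesteps_per_observation = 5

# numpy.ceil rounds every element of the Series in one call, so the same
# expression works for the whole vectorised frame, not just one individual.
upper_bound = np.ceil(fault_time / timesteps_per_observation) * timesteps_per_observation
print(upper_bound.tolist())  # [5.0, 10.0, 15.0]
```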
@@ -89,7 +90,7 @@ def __init__(
     def add_modelling_assumptions(self):
         self.modelling_assumptions.append("The variables in the data vary over time.")
 
-    def setup_xo_t_do(self, strategy_assigned: list, strategy_followed: list, eligible: pd.Series):
+    def setup_xo_t_do(self, strategy_assigned: list, strategy_followed: list, eligible: pd.Series, time: pd.Series):
         """
         Return a binary sequence with each bit representing whether the current
         index is the time point at which the individual diverted from the
@@ -98,36 +99,76 @@ def setup_xo_t_do(self, strategy_assigned: list, strategy_followed: list, eligib
         :param strategy_assigned - the assigned treatment strategy
         :param strategy_followed - the strategy followed by the individual
         :param eligible - binary sequence represnting the eligibility of the individual at each time step
+        :param time - The sequence of time steps
         """
 
-        strategy_assigned = {t: (var, val) for t, var, val in strategy_assigned}
-        strategy_followed = {t: (var, val) for t, var, val in strategy_followed}
+        default = {t: (-1, -1) for t in time.values}
+        strategy_assigned = default | {t: (var, val) for t, var, val in strategy_assigned}
+        strategy_followed = default | {t: (var, val) for t, var, val in strategy_followed}
 
-        # fill in the gaps
-        for time in eligible.index:
-            if time not in strategy_assigned:
-                strategy_assigned[time] = (-1, -1)
-            if time not in strategy_followed:
-                strategy_followed[time] = (-1, -1)
+        strategy_assigned = sorted([(t, var, val) for t, (var, val) in strategy_assigned.items() if t in time.values])
+        strategy_followed = sorted([(t, var, val) for t, (var, val) in strategy_followed.items() if t in time.values])
+
+        mask = (
+            pd.Series(strategy_assigned, index=eligible.index) != pd.Series(strategy_followed, index=eligible.index)
+        ).astype("boolean")
+        mask = mask | ~eligible
+        mask.reset_index(inplace=True, drop=True)
+        false = mask.loc[mask]
+        if false.empty:
+            return np.zeros(len(mask))
+        mask = (mask * 1).tolist()
+        cutoff = false.index[0] + 1
+        return mask[:cutoff] + ([None] * (len(mask) - cutoff))
+
+    def setup_xo_t_do_2(self, individual: pd.DataFrame, strategy_assigned: list):
+        """
+        Return a binary sequence with each bit representing whether the current
+        index is the time point at which the individual diverted from the
+        assigned treatment strategy (and thus should be censored).
+
+        :param individual: DataFrame representing the individual.
+        :param strategy_assigned: The assigned treatment strategy.
+        """
+
+        default = {t: (-1, -1) for t in individual["time"].values}
+
+        strategy_assigned = default | {t: (var, val) for t, var, val in strategy_assigned}
+        strategy_followed = default | {
+            t: (var, individual.loc[individual["time"] == t, var].values[0])
+            for t, var, val in self.treatment_strategy
+            if t in individual["time"].values
+        }
 
         strategy_assigned = sorted(
-            [(t, var, val) for t, (var, val) in strategy_assigned.items() if t in eligible.index]
+            [(t, var, val) for t, (var, val) in strategy_assigned.items() if t in individual["time"].values]
         )
         strategy_followed = sorted(
-            [(t, var, val) for t, (var, val) in strategy_followed.items() if t in eligible.index]
+            [(t, var, val) for t, (var, val) in strategy_followed.items() if t in individual["time"].values]
         )
 
         mask = (
-            pd.Series(strategy_assigned, index=eligible.index) != pd.Series(strategy_followed, index=eligible.index)
+            pd.Series(strategy_assigned, index=individual.index) != pd.Series(strategy_followed, index=individual.index)
         ).astype("boolean")
-        mask = mask | ~eligible
+        mask = mask | ~individual["eligible"]
         mask.reset_index(inplace=True, drop=True)
         false = mask.loc[mask]
         if false.empty:
-            return np.zeros(len(mask))
+            return pd.DataFrame(
+                {
+                    "id": [str(uuid4())] * len(individual),
+                    "xo_t_do": np.zeros(len(mask)),
+                }
+            )
         mask = (mask * 1).tolist()
         cutoff = false.index[0] + 1
-        return mask[:cutoff] + ([None] * (len(mask) - cutoff))
+
+        return pd.DataFrame(
+            {
+                "id": [str(uuid4())] * len(individual),
+                "xo_t_do": pd.Series(mask[:cutoff] + ([None] * (len(mask) - cutoff)), index=individual.index),
+            }
+        )
 
     def setup_fault_t_do(self, individual: pd.DataFrame):
         """
@@ -199,49 +240,96 @@ def preprocess_data(self):
             (self.df["time"] % self.timesteps_per_observation == 0) & (self.df["time"] <= self.total_time)
         ]
 
-        individuals = []
-        new_id = 0
         logging.debug(" Preprocessing groups")
 
-        for id, individual in tqdm(living_runs.groupby("id", sort=False)):
-            assert sum(individual["fault_t_do"]) <= 1, (
-                f"Error initialising fault_t_do for individual\n"
-                f"{individual[['id', 'time', self.status_column, 'fault_time', 'fault_t_do']]}\n"
-                f"with fault at {individual.fault_time.iloc[0]}"
-            )
-
-            strategy_followed = [
-                [t, var, individual.loc[individual["time"] == t, var].values[0]]
-                for t, var, val in self.treatment_strategy
-                if t in individual["time"].values
-            ]
-
-            # Control flow:
-            # Individuals that start off in both arms, need cloning (hence incrementing the ID within the if statement)
-            # Individuals that don't start off in either arm are left out
-            for inx, strategy_assigned in [(0, self.control_strategy), (1, self.treatment_strategy)]:
-                if (
-                    len(strategy_followed) > 0
-                    and strategy_assigned[0] == strategy_followed[0]
-                    and individual.eligible.iloc[0]
-                ):
-                    individual["old_id"] = individual["id"]
-                    individual["id"] = new_id
-                    new_id += 1
-                    individual["trtrand"] = inx
-                    individual["xo_t_do"] = self.setup_xo_t_do(
-                        strategy_assigned, strategy_followed, individual["eligible"]
-                    )
-                    individuals.append(
-                        individual.loc[
-                            individual["time"]
-                            < ceil(individual["fault_time"].iloc[0] / self.timesteps_per_observation)
-                            * self.timesteps_per_observation
-                        ].copy()
-                    )
+        # new
+        ctrl_time, ctrl_var, ctrl_val = self.control_strategy[0]
+        control_group = (
+            living_runs.groupby("id", sort=False)
+            .filter(lambda gp: len(gp.loc[(gp["time"] == ctrl_time) & (gp[ctrl_var] == ctrl_val)]) > 0)
+            .copy()
+        )
+        control_group["trtrand"] = 0
+        ctrl_xo_t_do_df = control_group.groupby("id", sort=False).apply(
+            self.setup_xo_t_do_2, strategy_assigned=self.control_strategy
+        )
+        control_group["xo_t_do"] = ctrl_xo_t_do_df["xo_t_do"].values
+        control_group["old_id"] = control_group["id"]
+        # control_group["id"] = ctrl_xo_t_do_df["id"].values
+        control_group["id"] = [f"c-{id}" for id in control_group["id"]]
+        assert not control_group["id"].isnull().any(), "Null control IDs"
+
+        trt_time, trt_var, trt_val = self.treatment_strategy[0]
+        treatment_group = (
+            living_runs.groupby("id", sort=False)
+            .filter(lambda gp: len(gp.loc[(gp["time"] == trt_time) & (gp[trt_var] == trt_val)]) > 0)
+            .copy()
+        )
+        treatment_group["trtrand"] = 1
+        trt_xo_t_do_df = treatment_group.groupby("id", sort=False).apply(
+            self.setup_xo_t_do_2, strategy_assigned=self.treatment_strategy
+        )
+        treatment_group["xo_t_do"] = trt_xo_t_do_df["xo_t_do"].values
+        treatment_group["old_id"] = treatment_group["id"]
+        # treatment_group["id"] = trt_xo_t_do_df["id"].values
+        treatment_group["id"] = [f"c-{id}" for id in treatment_group["id"]]
+        assert not treatment_group["id"].isnull().any(), "Null treatment IDs"
+
+        individuals = pd.concat([control_group, treatment_group])
+        individuals = individuals.loc[
+            individuals["time"]
+            < ceil(individuals["fault_time"].iloc[0] / self.timesteps_per_observation) * self.timesteps_per_observation
+        ].copy()
+
+        individuals.sort_values(by=["old_id", "time"]).to_csv("/home/michael/tmp/vectorised_individuals.csv")
+        # end new
+
+        # individuals = []
+        #
+        # for id, individual in tqdm(living_runs.groupby("id", sort=False)):
+        #     assert sum(individual["fault_t_do"]) <= 1, (
+        #         f"Error initialising fault_t_do for individual\n"
+        #         f"{individual[['id', 'time', self.status_column, 'fault_time', 'fault_t_do']]}\n"
+        #         f"with fault at {individual.fault_time.iloc[0]}"
+        #     )
+        #
+        #     strategy_followed = [
+        #         [t, var, individual.loc[individual["time"] == t, var].values[0]]
+        #         for t, var, val in self.treatment_strategy
+        #         if t in individual["time"].values
+        #     ]
+        #
+        #     # Control flow:
+        #     # Individuals that start off in both arms, need cloning (hence incrementing the ID within the if statement)
+        #     # Individuals that don't start off in either arm are left out
+        #     for inx, strategy_assigned in [(0, self.control_strategy), (1, self.treatment_strategy)]:
+        #         if (
+        #             len(strategy_followed) > 0
+        #             and strategy_assigned[0] == strategy_followed[0]
+        #             and individual.eligible.iloc[0]
+        #         ):
+        #             individual["old_id"] = individual["id"]
+        #             individual["id"] = new_id
+        #             new_id += 1
+        #             individual["trtrand"] = inx
+        #             individual["xo_t_do"] = self.setup_xo_t_do(
+        #                 strategy_assigned, strategy_followed, individual["eligible"], individual["time"]
+        #             )
+        #             individuals.append(
+        #                 individual.loc[
+        #                     individual["time"]
+        #                     < ceil(individual["fault_time"].iloc[0] / self.timesteps_per_observation)
+        #                     * self.timesteps_per_observation
+        #                 ].copy()
+        #             )
+        # self.df = pd.concat(individuals)
+        # self.df.sort_values(by=["id", "time"]).to_csv("/home/michael/tmp/iterated_individuals.csv")
         if len(individuals) == 0:
             raise ValueError("No individuals followed either strategy.")
-        self.df = pd.concat(individuals)
+        self.df = individuals.loc[
+            individuals["time"]
+            < ceil(individuals["fault_time"] / self.timesteps_per_observation) * self.timesteps_per_observation
+        ].reset_index()
         print(len(individuals), "individuals")
 
         if len(self.df.loc[self.df["trtrand"] == 0]) == 0:
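
The arm selection above leans on DataFrameGroupBy.filter, which keeps or drops each individual's rows as a block depending on a predicate evaluated against the whole group, replacing the old per-individual if statement. A minimal sketch with toy data (the column name t and the strategy tuple are hypothetical):

```python
import pandas as pd

# Toy runs: individual "a" has t == 1 at time 0, individual "b" does not.
living_runs = pd.DataFrame(
    {"id": ["a", "a", "b", "b"], "time": [0, 1, 0, 1], "t": [1, 0, 0, 0]}
)

trt_time, trt_var, trt_val = 0, "t", 1  # hypothetical first step of a treatment strategy

# filter() evaluates the lambda once per individual and keeps *all* rows of the
# individuals whose first strategy step matches, dropping the others wholesale.
treatment_group = (
    living_runs.groupby("id", sort=False)
    .filter(lambda gp: len(gp.loc[(gp["time"] == trt_time) & (gp[trt_var] == trt_val)]) > 0)
    .copy()
)
print(treatment_group["id"].unique())  # ['a']
```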
@@ -279,8 +367,8 @@ def estimate_hazard_ratio(self):
         )
 
         # IPCW step 3: For each individual at each time, compute the inverse probability of remaining uncensored
-        # Estimate the probabilities of remaining un-switched and hence the weights
-        print("Estimate the probabilities of remaining un-switched and hence the weights")
+        # Estimate the probabilities of remaining 'un-switched' and hence the weights
+        print("Estimate the probabilities of remaining 'un-switched' and hence the weights")
 
         preprocessed_data["num"] = 1 - preprocessed_data["pxo1"]
         preprocessed_data["denom"] = 1 - preprocessed_data["pxo2"]
@@ -309,13 +397,6 @@ def estimate_hazard_ratio(self):
             f"{preprocessed_data.loc[preprocessed_data['tin'] >= preprocessed_data['tout'], ['id', 'time', 'fault_time', 'tin', 'tout']]}"
         )
 
-        preprocessed_data.pop("old_id")
-        assert (
-            not np.isinf(preprocessed_data[[col for col in preprocessed_data if preprocessed_data.dtypes[col] != bool]])
-            .any()
-            .any()
-        ), "Infinity not allowed."
-
         # IPCW step 4: Use these weights in a weighted analysis of the outcome model
         # Estimate the KM graph and IPCW hazard ratio using Cox regression.
         print("Estimate the KM graph and IPCW hazard ratio using Cox regression.")
@@ -330,6 +411,7 @@ def estimate_hazard_ratio(self):
             formula="trtrand",
             entry_col="tin",
         )
+        print("Estimated")
 
         ci_low, ci_high = [np.exp(cox_ph.confidence_intervals_)[col] for col in cox_ph.confidence_intervals_.columns]
 