Skip to content

Commit f505dd7

Browse files
committed
Optimisation and bug fixes
1 parent 9b5f075 commit f505dd7

File tree

1 file changed

+44
-94
lines changed

1 file changed

+44
-94
lines changed

causal_testing/estimation/ipcw_estimator.py

Lines changed: 44 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -83,45 +83,12 @@ def __init__(
8383
max(len(self.control_strategy), len(self.treatment_strategy)) + 1
8484
) * self.timesteps_per_observation
8585
self.total_time = total_time
86-
print("PREPROCESSING")
8786
self.preprocess_data()
88-
print("PREPROCESSED")
8987

9088
def add_modelling_assumptions(self):
9189
self.modelling_assumptions.append("The variables in the data vary over time.")
9290

93-
def setup_xo_t_do(self, strategy_assigned: list, strategy_followed: list, eligible: pd.Series, time: pd.Series):
94-
"""
95-
Return a binary sequence with each bit representing whether the current
96-
index is the time point at which the individual diverted from the
97-
assigned treatment strategy (and thus should be censored).
98-
99-
:param strategy_assigned - the assigned treatment strategy
100-
:param strategy_followed - the strategy followed by the individual
101-
:param eligible - binary sequence representing the eligibility of the individual at each time step
102-
:param time - The sequence of time steps
103-
"""
104-
105-
default = {t: (-1, -1) for t in time.values}
106-
strategy_assigned = default | {t: (var, val) for t, var, val in strategy_assigned}
107-
strategy_followed = default | {t: (var, val) for t, var, val in strategy_followed}
108-
109-
strategy_assigned = sorted([(t, var, val) for t, (var, val) in strategy_assigned.items() if t in time.values])
110-
strategy_followed = sorted([(t, var, val) for t, (var, val) in strategy_followed.items() if t in time.values])
111-
112-
mask = (
113-
pd.Series(strategy_assigned, index=eligible.index) != pd.Series(strategy_followed, index=eligible.index)
114-
).astype("boolean")
115-
mask = mask | ~eligible
116-
mask.reset_index(inplace=True, drop=True)
117-
false = mask.loc[mask]
118-
if false.empty:
119-
return np.zeros(len(mask))
120-
mask = (mask * 1).tolist()
121-
cutoff = false.index[0] + 1
122-
return mask[:cutoff] + ([None] * (len(mask) - cutoff))
123-
124-
def setup_xo_t_do_2(self, individual: pd.DataFrame, strategy_assigned: list):
91+
def setup_xo_t_do(self, individual: pd.DataFrame, strategy_assigned: list):
12592
"""
12693
Return a binary sequence with each bit representing whether the current
12794
index is the time point at which the individual diverted from the
@@ -242,95 +209,75 @@ def preprocess_data(self):
242209

243210
logging.debug(" Preprocessing groups")
244211

245-
# new
246-
ctrl_time, ctrl_var, ctrl_val = self.control_strategy[0]
212+
ctrl_time_0, ctrl_var_0, ctrl_val_0 = self.control_strategy[0]
213+
ctrl_time, ctrl_var, ctrl_val = min(
214+
set(map(tuple, self.control_strategy)).difference(map(tuple, self.treatment_strategy))
215+
)
247216
control_group = (
248217
living_runs.groupby("id", sort=False)
249218
.filter(lambda gp: len(gp.loc[(gp["time"] == ctrl_time) & (gp[ctrl_var] == ctrl_val)]) > 0)
219+
.groupby("id", sort=False)
220+
.filter(lambda gp: len(gp.loc[(gp["time"] == ctrl_time_0) & (gp[ctrl_var_0] == ctrl_val_0)]) > 0)
250221
.copy()
251222
)
252223
control_group["trtrand"] = 0
253224
ctrl_xo_t_do_df = control_group.groupby("id", sort=False).apply(
254-
self.setup_xo_t_do_2, strategy_assigned=self.control_strategy
225+
self.setup_xo_t_do, strategy_assigned=self.control_strategy
255226
)
256227
control_group["xo_t_do"] = ctrl_xo_t_do_df["xo_t_do"].values
257228
control_group["old_id"] = control_group["id"]
258229
# control_group["id"] = ctrl_xo_t_do_df["id"].values
259230
control_group["id"] = [f"c-{id}" for id in control_group["id"]]
260231
assert not control_group["id"].isnull().any(), "Null control IDs"
261232

262-
trt_time, trt_var, trt_val = self.treatment_strategy[0]
233+
trt_time_0, trt_var_0, trt_val_0 = self.treatment_strategy[0]
234+
trt_time, trt_var, trt_val = min(
235+
set(map(tuple, self.treatment_strategy)).difference(map(tuple, self.control_strategy))
236+
)
263237
treatment_group = (
264238
living_runs.groupby("id", sort=False)
265239
.filter(lambda gp: len(gp.loc[(gp["time"] == trt_time) & (gp[trt_var] == trt_val)]) > 0)
240+
.groupby("id", sort=False)
241+
.filter(lambda gp: len(gp.loc[(gp["time"] == trt_time_0) & (gp[trt_var_0] == trt_val_0)]) > 0)
266242
.copy()
267243
)
268244
treatment_group["trtrand"] = 1
269245
trt_xo_t_do_df = treatment_group.groupby("id", sort=False).apply(
270-
self.setup_xo_t_do_2, strategy_assigned=self.treatment_strategy
246+
self.setup_xo_t_do, strategy_assigned=self.treatment_strategy
271247
)
272248
treatment_group["xo_t_do"] = trt_xo_t_do_df["xo_t_do"].values
273249
treatment_group["old_id"] = treatment_group["id"]
274250
# treatment_group["id"] = trt_xo_t_do_df["id"].values
275-
treatment_group["id"] = [f"c-{id}" for id in treatment_group["id"]]
251+
treatment_group["id"] = [f"t-{id}" for id in treatment_group["id"]]
276252
assert not treatment_group["id"].isnull().any(), "Null treatment IDs"
277253

254+
logger.debug(
255+
len(control_group.groupby("id")),
256+
"control individuals",
257+
len(treatment_group.groupby("id")),
258+
"treatment individuals",
259+
)
260+
278261
individuals = pd.concat([control_group, treatment_group])
279262
individuals = individuals.loc[
280-
individuals["time"]
281-
< ceil(individuals["fault_time"].iloc[0] / self.timesteps_per_observation) * self.timesteps_per_observation
282-
].copy()
283-
284-
individuals.sort_values(by=["old_id", "time"]).to_csv("/home/michael/tmp/vectorised_individuals.csv")
285-
# end new
286-
287-
# individuals = []
288-
#
289-
# for id, individual in tqdm(living_runs.groupby("id", sort=False)):
290-
# assert sum(individual["fault_t_do"]) <= 1, (
291-
# f"Error initialising fault_t_do for individual\n"
292-
# f"{individual[['id', 'time', self.status_column, 'fault_time', 'fault_t_do']]}\n"
293-
# f"with fault at {individual.fault_time.iloc[0]}"
294-
# )
295-
#
296-
# strategy_followed = [
297-
# [t, var, individual.loc[individual["time"] == t, var].values[0]]
298-
# for t, var, val in self.treatment_strategy
299-
# if t in individual["time"].values
300-
# ]
301-
#
302-
# # Control flow:
303-
# # Individuals that start off in both arms, need cloning (hence incrementing the ID within the if statement)
304-
# # Individuals that don't start off in either arm are left out
305-
# for inx, strategy_assigned in [(0, self.control_strategy), (1, self.treatment_strategy)]:
306-
# if (
307-
# len(strategy_followed) > 0
308-
# and strategy_assigned[0] == strategy_followed[0]
309-
# and individual.eligible.iloc[0]
310-
# ):
311-
# individual["old_id"] = individual["id"]
312-
# individual["id"] = new_id
313-
# new_id += 1
314-
# individual["trtrand"] = inx
315-
# individual["xo_t_do"] = self.setup_xo_t_do(
316-
# strategy_assigned, strategy_followed, individual["eligible"], individual["time"]
317-
# )
318-
# individuals.append(
319-
# individual.loc[
320-
# individual["time"]
321-
# < ceil(individual["fault_time"].iloc[0] / self.timesteps_per_observation)
322-
# * self.timesteps_per_observation
323-
# ].copy()
324-
# )
325-
# self.df = pd.concat(individuals)
326-
# self.df.sort_values(by=["id", "time"]).to_csv("/home/michael/tmp/iterated_individuals.csv")
263+
(
264+
(
265+
individuals["time"]
266+
< ceil(individuals["fault_time"] / self.timesteps_per_observation) * self.timesteps_per_observation
267+
)
268+
& (~individuals["xo_t_do"].isnull())
269+
)
270+
]
271+
272+
individuals.sort_values(by=["id", "time"]).to_csv("/home/michael/tmp/vectorised_individuals.csv")
273+
327274
if len(individuals) == 0:
328275
raise ValueError("No individuals followed either strategy.")
329276
self.df = individuals.loc[
330277
individuals["time"]
331278
< ceil(individuals["fault_time"] / self.timesteps_per_observation) * self.timesteps_per_observation
332279
].reset_index()
333-
print(len(individuals), "individuals")
280+
logger.debug(len(individuals.groupby("id")), "individuals")
334281

335282
if len(self.df.loc[self.df["trtrand"] == 0]) == 0:
336283
raise ValueError(f"No individuals began the control strategy {self.control_strategy}")
@@ -348,13 +295,15 @@ def estimate_hazard_ratio(self):
348295
preprocessed_data = self.df.copy()
349296

350297
# Use logistic regression to predict switching given baseline covariates
351-
print("Use logistic regression to predict switching given baseline covariates")
298+
logger.debug("Use logistic regression to predict switching given baseline covariates")
352299
fit_bl_switch = smf.logit(self.fit_bl_switch_formula, data=self.df).fit()
353300

354301
preprocessed_data["pxo1"] = fit_bl_switch.predict(preprocessed_data)
355302

356303
# Use logistic regression to predict switching given baseline and time-updated covariates (model S12)
357-
print("Use logistic regression to predict switching given baseline and time-updated covariates (model S12)")
304+
logger.debug(
305+
"Use logistic regression to predict switching given baseline and time-updated covariates (model S12)"
306+
)
358307
fit_bltd_switch = smf.logit(
359308
self.fit_bltd_switch_formula,
360309
data=self.df,
@@ -368,7 +317,7 @@ def estimate_hazard_ratio(self):
368317

369318
# IPCW step 3: For each individual at each time, compute the inverse probability of remaining uncensored
370319
# Estimate the probabilities of remaining 'un-switched' and hence the weights
371-
print("Estimate the probabilities of remaining 'un-switched' and hence the weights")
320+
logger.debug("Estimate the probabilities of remaining 'un-switched' and hence the weights")
372321

373322
preprocessed_data["num"] = 1 - preprocessed_data["pxo1"]
374323
preprocessed_data["denom"] = 1 - preprocessed_data["pxo2"]
@@ -397,10 +346,12 @@ def estimate_hazard_ratio(self):
397346
f"{preprocessed_data.loc[preprocessed_data['tin'] >= preprocessed_data['tout'], ['id', 'time', 'fault_time', 'tin', 'tout']]}"
398347
)
399348

349+
preprocessed_data.to_csv("/home/michael/tmp/preprocessed_data.csv")
350+
400351
# IPCW step 4: Use these weights in a weighted analysis of the outcome model
401352
# Estimate the KM graph and IPCW hazard ratio using Cox regression.
402-
print("Estimate the KM graph and IPCW hazard ratio using Cox regression.")
403-
cox_ph = CoxPHFitter()
353+
logger.debug("Estimate the KM graph and IPCW hazard ratio using Cox regression.")
354+
cox_ph = CoxPHFitter(penalizer=0.2, alpha=self.alpha)
404355
cox_ph.fit(
405356
df=preprocessed_data,
406357
duration_col="tout",
@@ -411,7 +362,6 @@ def estimate_hazard_ratio(self):
411362
formula="trtrand",
412363
entry_col="tin",
413364
)
414-
print("Estimated")
415365

416366
ci_low, ci_high = [np.exp(cox_ph.confidence_intervals_)[col] for col in cox_ph.confidence_intervals_.columns]
417367

0 commit comments

Comments
 (0)