Skip to content

Commit 72724eb

Browse files
committed
Seems to be working
1 parent 190b239 commit 72724eb

File tree

1 file changed

+23
-27
lines changed

1 file changed

+23
-27
lines changed

causal_testing/estimation/ipcw_estimator.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from math import ceil
55
from typing import Any
6+
from tqdm import tqdm
67

78
import numpy as np
89
import pandas as pd
@@ -74,14 +75,16 @@ def __init__(
7475
self.fit_bl_switch_formula = fit_bl_switch_formula
7576
self.fit_bltd_switch_formula = fit_bltd_switch_formula
7677
self.eligibility = eligibility
77-
self.df = df
78+
self.df = df.sort_values(["id", "time"])
7879

7980
if total_time is None:
8081
total_time = (
8182
max(len(self.control_strategy), len(self.treatment_strategy)) + 1
8283
) * self.timesteps_per_observation
8384
self.total_time = total_time
85+
print("PREPROCESSING")
8486
self.preprocess_data()
87+
print("PREPROCESSED")
8588

8689
def add_modelling_assumptions(self):
8790
self.modelling_assumptions.append("The variables in the data vary over time.")
@@ -200,7 +203,7 @@ def preprocess_data(self):
200203
new_id = 0
201204
logging.debug(" Preprocessing groups")
202205

203-
for id, individual in living_runs.groupby("id", sort=False):
206+
for id, individual in tqdm(living_runs.groupby("id", sort=False)):
204207
assert sum(individual["fault_t_do"]) <= 1, (
205208
f"Error initialising fault_t_do for individual\n"
206209
f"{individual[['id', 'time', self.status_column, 'fault_time', 'fault_t_do']]}\n"
@@ -213,31 +216,10 @@ def preprocess_data(self):
213216
if t in individual["time"].values
214217
]
215218

216-
# print("CONTROL STRATEGY")
217-
# print(self.control_strategy)
218-
#
219-
# print("TREATMENT STRATEGY")
220-
# print(self.treatment_strategy)
221-
#
222-
# print()
223-
224219
# Control flow:
225220
# Individuals that start off in both arms, need cloning (hence incrementing the ID within the if statement)
226221
# Individuals that don't start off in either arm are left out
227222
for inx, strategy_assigned in [(0, self.control_strategy), (1, self.treatment_strategy)]:
228-
# print("STRATEGY", inx)
229-
# print("strategy_assigned")
230-
# print(strategy_assigned)
231-
# print("strategy_followed")
232-
# print(strategy_followed)
233-
# print(
234-
# "OK?",
235-
# (
236-
# len(strategy_followed) > 0,
237-
# strategy_assigned[0] == strategy_followed[0],
238-
# individual.eligible.iloc[0],
239-
# ),
240-
# )
241223
if (
242224
len(strategy_followed) > 0
243225
and strategy_assigned[0] == strategy_followed[0]
@@ -250,7 +232,6 @@ def preprocess_data(self):
250232
individual["xo_t_do"] = self.setup_xo_t_do(
251233
strategy_assigned, strategy_followed, individual["eligible"]
252234
)
253-
# individuals.append(individual.loc[individual["time"] <= individual["fault_time"]].copy())
254235
individuals.append(
255236
individual.loc[
256237
individual["time"]
@@ -261,12 +242,12 @@ def preprocess_data(self):
261242
if len(individuals) == 0:
262243
raise ValueError("No individuals followed either strategy.")
263244
self.df = pd.concat(individuals)
264-
self.df.to_csv("/tmp/test.csv")
245+
print(len(individuals), "individuals")
265246

266247
if len(self.df.loc[self.df["trtrand"] == 0]) == 0:
267-
raise ValueError(f"No individuals followed the control strategy {self.control_strategy}")
248+
raise ValueError(f"No individuals began the control strategy {self.control_strategy}")
268249
if len(self.df.loc[self.df["trtrand"] == 1]) == 0:
269-
raise ValueError(f"No individuals followed the treatment strategy {self.treatment_strategy}")
250+
raise ValueError(f"No individuals began the treatment strategy {self.treatment_strategy}")
270251

271252
def estimate_hazard_ratio(self):
272253
"""
@@ -279,20 +260,27 @@ def estimate_hazard_ratio(self):
279260
preprocessed_data = self.df.copy()
280261

281262
# Use logistic regression to predict switching given baseline covariates
263+
print("Use logistic regression to predict switching given baseline covariates")
282264
fit_bl_switch = smf.logit(self.fit_bl_switch_formula, data=self.df).fit()
283265

284266
preprocessed_data["pxo1"] = fit_bl_switch.predict(preprocessed_data)
285267

286268
# Use logistic regression to predict switching given baseline and time-updated covariates (model S12)
269+
print("Use logistic regression to predict switching given baseline and time-updated covariates (model S12)")
287270
fit_bltd_switch = smf.logit(
288271
self.fit_bltd_switch_formula,
289272
data=self.df,
290273
).fit()
291274

292275
preprocessed_data["pxo2"] = fit_bltd_switch.predict(preprocessed_data)
276+
if (preprocessed_data["pxo2"] == 1).any():
277+
raise ValueError(
278+
"Probability of switching given baseline and time-varying confounders (pxo2) cannot be one."
279+
)
293280

294281
# IPCW step 3: For each individual at each time, compute the inverse probability of remaining uncensored
295282
# Estimate the probabilities of remaining ‘un-switched’ and hence the weights
283+
print("Estimate the probabilities of remaining ‘un-switched’ and hence the weights")
296284

297285
preprocessed_data["num"] = 1 - preprocessed_data["pxo1"]
298286
preprocessed_data["denom"] = 1 - preprocessed_data["pxo2"]
@@ -321,8 +309,16 @@ def estimate_hazard_ratio(self):
321309
f"{preprocessed_data.loc[preprocessed_data['tin'] >= preprocessed_data['tout'], ['id', 'time', 'fault_time', 'tin', 'tout']]}"
322310
)
323311

312+
preprocessed_data.pop("old_id")
313+
assert (
314+
not np.isinf(preprocessed_data[[col for col in preprocessed_data if preprocessed_data.dtypes[col] != bool]])
315+
.any()
316+
.any()
317+
), "Infinity not allowed."
318+
324319
# IPCW step 4: Use these weights in a weighted analysis of the outcome model
325320
# Estimate the KM graph and IPCW hazard ratio using Cox regression.
321+
print("Estimate the KM graph and IPCW hazard ratio using Cox regression.")
326322
cox_ph = CoxPHFitter()
327323
cox_ph.fit(
328324
df=preprocessed_data,

0 commit comments

Comments
 (0)