Commit 190b239

Maybe fixed the estimation bug (one test will fail)
1 parent 7613e6b commit 190b239


causal_testing/estimation/ipcw_estimator.py

Lines changed: 107 additions & 31 deletions
@@ -13,6 +13,8 @@
 
 logger = logging.getLogger(__name__)
 
+debug_id = "data-50/batch_run_16/00221634_10.csv"
+
 
 class IPCWEstimator(Estimator):
     """
@@ -73,10 +75,12 @@ def __init__(
         self.fit_bltd_switch_formula = fit_bltd_switch_formula
         self.eligibility = eligibility
         self.df = df
+
         if total_time is None:
-            self.total_time = (
+            total_time = (
                 max(len(self.control_strategy), len(self.treatment_strategy)) + 1
             ) * self.timesteps_per_observation
+        self.total_time = total_time
         self.preprocess_data()
 
     def add_modelling_assumptions(self):
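
When no explicit window is given, the default total_time covers one observation beyond the longer of the two strategies. A quick numerical reading of that default, using hypothetical strategy lengths and an observation interval of 5:

# Hypothetical values, for illustration only: a 3-step control strategy,
# a 4-step treatment strategy, and observations every 5 time units.
control_strategy = [(5, "t", 0), (10, "t", 0), (15, "t", 0)]
treatment_strategy = [(5, "t", 1), (10, "t", 1), (15, "t", 1), (20, "t", 1)]
timesteps_per_observation = 5

# Default window: one observation beyond the longer strategy.
total_time = (max(len(control_strategy), len(treatment_strategy)) + 1) * timesteps_per_observation
print(total_time)  # (4 + 1) * 5 == 25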
@@ -92,8 +96,23 @@ def setup_xo_t_do(self, strategy_assigned: list, strategy_followed: list, eligib
         :param strategy_followed - the strategy followed by the individual
         :param eligible - binary sequence represnting the eligibility of the individual at each time step
         """
-        strategy_assigned = [1] + strategy_assigned + [1]
-        strategy_followed = [1] + strategy_followed + [1]
+
+        strategy_assigned = {t: (var, val) for t, var, val in strategy_assigned}
+        strategy_followed = {t: (var, val) for t, var, val in strategy_followed}
+
+        # fill in the gaps
+        for time in eligible.index:
+            if time not in strategy_assigned:
+                strategy_assigned[time] = (-1, -1)
+            if time not in strategy_followed:
+                strategy_followed[time] = (-1, -1)
+
+        strategy_assigned = sorted(
+            [(t, var, val) for t, (var, val) in strategy_assigned.items() if t in eligible.index]
+        )
+        strategy_followed = sorted(
+            [(t, var, val) for t, (var, val) in strategy_followed.items() if t in eligible.index]
+        )
 
         mask = (
             pd.Series(strategy_assigned, index=eligible.index) != pd.Series(strategy_followed, index=eligible.index)
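
The rewritten setup_xo_t_do keys both strategies by time, pads any observation time the strategy does not mention with a (-1, -1) placeholder, and sorts the result back onto the eligibility index before the element-wise comparison. A self-contained sketch of that padding-and-sorting step, with made-up times:

import pandas as pd

# Toy inputs (illustrative only): a strategy as (time, variable, value) triples
# and an eligibility series indexed by observation time.
strategy_assigned = [(5, "t", 1), (15, "t", 1)]
eligible = pd.Series([True, True, True, True], index=[0, 5, 10, 15])

assigned = {t: (var, val) for t, var, val in strategy_assigned}
for time in eligible.index:
    if time not in assigned:
        assigned[time] = (-1, -1)  # placeholder where the strategy says nothing

aligned = sorted((t, var, val) for t, (var, val) in assigned.items() if t in eligible.index)
print(aligned)  # [(0, -1, -1), (5, 't', 1), (10, -1, -1), (15, 't', 1)]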
@@ -110,25 +129,25 @@ def setup_xo_t_do(self, strategy_assigned: list, strategy_followed: list, eligib
     def setup_fault_t_do(self, individual: pd.DataFrame):
         """
         Return a binary sequence with each bit representing whether the current
-        index is the time point at which the event of interest (i.e. a fault)
-        occurred.
+        index is the time point at which the event of interest (i.e. a fault) occurred.
+
+        N.B. This is rounded _up_ to the nearest multiple of `self.timesteps_per_observation`.
+        That is, if the fault occurs at time 22, and `self.timesteps_per_observation == 5`, then
+        `fault_t_do` will be 25.
         """
+
         fault = individual[~individual[self.status_column]]
-        fault_t_do = pd.Series(np.zeros(len(individual)), index=individual.index)
+        individual["fault_t_do"] = 0
 
         if not fault.empty:
-            fault_time = individual["time"].loc[fault.index[0]]
-            # Ceiling to nearest observation point
-            fault_time = ceil(fault_time / self.timesteps_per_observation) * self.timesteps_per_observation
-            # Set the correct observation point to be the fault time of doing (fault_t_do)
-            observations = individual.loc[
-                (individual["time"] % self.timesteps_per_observation == 0) & (individual["time"] < fault_time)
-            ]
-            if not observations.empty:
-                fault_t_do.loc[observations.index[0]] = 1
-        assert sum(fault_t_do) <= 1, f"Multiple fault times for\n{individual}"
+            time_fault_observed = (
+                max(0, ceil(fault["time"].min() / self.timesteps_per_observation) - 1)
+            ) * self.timesteps_per_observation
+            individual.loc[individual["time"] == time_fault_observed, "fault_t_do"] = 1
+
+        assert sum(individual["fault_t_do"]) <= 1, f"Multiple fault times for\n{individual}"
 
-        return pd.DataFrame({"fault_t_do": fault_t_do})
+        return pd.DataFrame({"fault_t_do": individual["fault_t_do"]})
 
     def setup_fault_time(self, individual: pd.DataFrame, perturbation: float = -0.001):
         """
@@ -138,24 +157,40 @@ def setup_fault_time(self, individual: pd.DataFrame, perturbation: float = -0.00
         fault_time = (
             individual["time"].loc[fault.index[0]]
             if not fault.empty
-            else (individual["time"].max() + self.timesteps_per_observation)
+            else (self.total_time + self.timesteps_per_observation)
+        )
+        return pd.DataFrame(
+            {
+                "fault_time": np.repeat(fault_time + perturbation, len(individual)),
+            }
         )
-        return pd.DataFrame({"fault_time": np.repeat(fault_time + perturbation, len(individual))})
 
     def preprocess_data(self):
         """
         Set up the treatment-specific columns in the data that are needed to estimate the hazard ratio.
         """
+
         self.df["trtrand"] = None  # treatment/control arm
         self.df["xo_t_do"] = None  # did the individual deviate from the treatment of interest here?
         self.df["eligible"] = self.df.eval(self.eligibility) if self.eligibility is not None else True
 
         # when did a fault occur?
-        self.df["fault_time"] = self.df.groupby("id")[[self.status_column, "time"]].apply(self.setup_fault_time).values
-        self.df["fault_t_do"] = (
-            self.df.groupby("id")[["id", "time", self.status_column]].apply(self.setup_fault_t_do).values
+        fault_time_df = self.df.groupby("id", sort=False)[[self.status_column, "time", "id"]].apply(
+            self.setup_fault_time
+        )
+
+        assert len(fault_time_df) == len(self.df), "Fault times error"
+        self.df["fault_time"] = fault_time_df["fault_time"].values
+
+        assert (
+            self.df.groupby("id", sort=False).apply(lambda x: len(set(x["fault_time"])) == 1).all()
+        ), f"Each individual must have a unique fault time."
+
+        fault_t_do_df = self.df.groupby("id", sort=False)[["id", "time", self.status_column]].apply(
+            self.setup_fault_t_do
         )
-        assert not pd.isnull(self.df["fault_time"]).any()
+        assert len(fault_t_do_df) == len(self.df), "Fault t_do error"
+        self.df["fault_t_do"] = fault_t_do_df["fault_t_do"].values
 
         living_runs = self.df.query("fault_time > 0").loc[
             (self.df["time"] % self.timesteps_per_observation == 0) & (self.df["time"] <= self.total_time)
@@ -164,34 +199,74 @@ def preprocess_data(self):
         individuals = []
         new_id = 0
         logging.debug(" Preprocessing groups")
-        for _, individual in living_runs.groupby("id"):
+
+        for id, individual in living_runs.groupby("id", sort=False):
             assert sum(individual["fault_t_do"]) <= 1, (
                 f"Error initialising fault_t_do for individual\n"
-                f"{individual[['id', 'time', 'fault_time', 'fault_t_do']]}\n"
-                "with fault at {individual.fault_time.iloc[0]}"
+                f"{individual[['id', 'time', self.status_column, 'fault_time', 'fault_t_do']]}\n"
+                f"with fault at {individual.fault_time.iloc[0]}"
             )
 
             strategy_followed = [
                 [t, var, individual.loc[individual["time"] == t, var].values[0]]
                 for t, var, val in self.treatment_strategy
+                if t in individual["time"].values
             ]
 
+            # print("CONTROL STRATEGY")
+            # print(self.control_strategy)
+            #
+            # print("TREATMENT STRATEGY")
+            # print(self.treatment_strategy)
+            #
+            # print()
+
             # Control flow:
             # Individuals that start off in both arms, need cloning (hence incrementing the ID within the if statement)
             # Individuals that don't start off in either arm are left out
             for inx, strategy_assigned in [(0, self.control_strategy), (1, self.treatment_strategy)]:
-                if strategy_assigned[0] == strategy_followed[0] and individual.eligible.iloc[0]:
+                # print("STRATEGY", inx)
+                # print("strategy_assigned")
+                # print(strategy_assigned)
+                # print("strategy_followed")
+                # print(strategy_followed)
+                # print(
+                #     "OK?",
+                #     (
+                #         len(strategy_followed) > 0,
+                #         strategy_assigned[0] == strategy_followed[0],
+                #         individual.eligible.iloc[0],
+                #     ),
+                # )
+                if (
+                    len(strategy_followed) > 0
+                    and strategy_assigned[0] == strategy_followed[0]
+                    and individual.eligible.iloc[0]
+                ):
+                    individual["old_id"] = individual["id"]
                     individual["id"] = new_id
                     new_id += 1
                     individual["trtrand"] = inx
                     individual["xo_t_do"] = self.setup_xo_t_do(
                         strategy_assigned, strategy_followed, individual["eligible"]
                     )
-                    individuals.append(individual.loc[individual["time"] <= individual["fault_time"]].copy())
+                    # individuals.append(individual.loc[individual["time"] <= individual["fault_time"]].copy())
+                    individuals.append(
+                        individual.loc[
+                            individual["time"]
+                            < ceil(individual["fault_time"].iloc[0] / self.timesteps_per_observation)
+                            * self.timesteps_per_observation
+                        ].copy()
+                    )
         if len(individuals) == 0:
             raise ValueError("No individuals followed either strategy.")
-
         self.df = pd.concat(individuals)
+        self.df.to_csv("/tmp/test.csv")
+
+        if len(self.df.loc[self.df["trtrand"] == 0]) == 0:
+            raise ValueError(f"No individuals followed the control strategy {self.control_strategy}")
+        if len(self.df.loc[self.df["trtrand"] == 1]) == 0:
+            raise ValueError(f"No individuals followed the treatment strategy {self.treatment_strategy}")
 
     def estimate_hazard_ratio(self):
         """
@@ -201,7 +276,7 @@ def estimate_hazard_ratio(self):
         if self.df["fault_t_do"].sum() == 0:
             raise ValueError("No recorded faults")
 
-        preprocessed_data = self.df.loc[self.df["xo_t_do"] == 0].copy()
+        preprocessed_data = self.df.copy()
 
         # Use logistic regression to predict switching given baseline covariates
         fit_bl_switch = smf.logit(self.fit_bl_switch_formula, data=self.df).fit()
@@ -242,7 +317,8 @@ def estimate_hazard_ratio(self):
         ).min(axis=1)
 
         assert (preprocessed_data["tin"] <= preprocessed_data["tout"]).all(), (
-            f"Left before joining\n" f"{preprocessed_data.loc[preprocessed_data['tin'] >= preprocessed_data['tout']]}"
+            f"Left before joining\n"
+            f"{preprocessed_data.loc[preprocessed_data['tin'] >= preprocessed_data['tout'], ['id', 'time', 'fault_time', 'tin', 'tout']]}"
         )
 
         # IPCW step 4: Use these weights in a weighted analysis of the outcome model
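
The trailing comment refers to IPCW step 4: refitting the outcome model with the censoring weights. That step sits outside this hunk, so the following is only a generic illustration of a weighted, left-truncated Cox fit, assuming lifelines and hypothetical data rather than this estimator's actual code:

import pandas as pd
from lifelines import CoxPHFitter

# Hypothetical preprocessed data: entry/exit times, event indicator,
# treatment arm, and inverse-probability-of-censoring weights from step 3.
data = pd.DataFrame({
    "tin": [0, 5, 0, 10],
    "tout": [20, 25, 15, 25],
    "fault_t_do": [1, 0, 1, 0],
    "trtrand": [1, 1, 0, 0],
    "weight": [1.2, 0.9, 1.1, 0.8],
})

cph = CoxPHFitter()
cph.fit(
    data,
    duration_col="tout",
    entry_col="tin",
    event_col="fault_t_do",
    weights_col="weight",
    formula="trtrand",
    robust=True,  # robust variance is the usual choice with non-integer IPC weights
)
print(cph.hazard_ratios_)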
