Skip to content

Commit e6d46cd

Browse files
committed
tests pass again
1 parent 1d4afee commit e6d46cd

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

causal_testing/estimation/ipcw_estimator.py

Lines changed: 34 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,8 @@ def __init__(
7575
self.fit_bltd_switch_formula = fit_bltd_switch_formula
7676
self.eligibility = eligibility
7777
self.df = df.sort_values(["id", "time"])
78+
self.len_control_group = None
79+
self.len_treatment_group = None
7880

7981
if total_time is None:
8082
total_time = (
@@ -249,13 +251,15 @@ def preprocess_data(self):
249251
treatment_group["id"] = [f"t-{id}" for id in treatment_group["id"]]
250252
assert not treatment_group["id"].isnull().any(), "Null treatment IDs"
251253

254+
premature_failures = living_runs.groupby("id", sort=False).filter(lambda gp: gp["time"].max() < trt_time)
252255
logger.debug(
253-
len(control_group.groupby("id")),
254-
"control individuals",
255-
len(treatment_group.groupby("id")),
256-
"treatment individuals",
256+
f"{len(control_group.groupby('id'))} control individuals "
257+
f"{len(treatment_group.groupby('id'))} treatment individuals "
258+
f"{len(premature_failures.groupby('id'))} premature failures"
257259
)
258260

261+
self.len_control_group = len(control_group.groupby("id"))
262+
self.len_treatment_group = len(treatment_group.groupby("id"))
259263
individuals = pd.concat([control_group, treatment_group])
260264
individuals = individuals.loc[
261265
(
@@ -274,7 +278,7 @@ def preprocess_data(self):
274278
individuals["time"]
275279
< np.ceil(individuals["fault_time"] / self.timesteps_per_observation) * self.timesteps_per_observation
276280
].reset_index()
277-
logger.debug(len(individuals.groupby("id")), "individuals")
281+
logger.debug(f"{len(individuals.groupby('id'))} individuals")
278282

279283
if len(self.df.loc[self.df["trtrand"] == 0]) == 0:
280284
raise ValueError(f"No individuals began the control strategy {self.control_strategy}")
@@ -293,20 +297,39 @@ def estimate_hazard_ratio(self):
293297

294298
# Use logistic regression to predict switching given baseline covariates
295299
logger.debug("Use logistic regression to predict switching given baseline covariates")
296-
fit_bl_switch = smf.logit(self.fit_bl_switch_formula, data=self.df).fit()
300+
fit_bl_switch_c = smf.logit(self.fit_bl_switch_formula, data=self.df.loc[self.df.trtrand == 0]).fit(
301+
method="bfgs"
302+
)
303+
fit_bl_switch_t = smf.logit(self.fit_bl_switch_formula, data=self.df.loc[self.df.trtrand == 1]).fit(
304+
method="bfgs"
305+
)
297306

298-
preprocessed_data["pxo1"] = fit_bl_switch.predict(preprocessed_data)
307+
preprocessed_data.loc[preprocessed_data["trtrand"] == 0, "pxo1"] = fit_bl_switch_c.predict(
308+
self.df.loc[self.df.trtrand == 0]
309+
)
310+
preprocessed_data.loc[preprocessed_data["trtrand"] == 1, "pxo1"] = fit_bl_switch_t.predict(
311+
self.df.loc[self.df.trtrand == 1]
312+
)
299313

300314
# Use logistic regression to predict switching given baseline and time-updated covariates (model S12)
301315
logger.debug(
302316
"Use logistic regression to predict switching given baseline and time-updated covariates (model S12)"
303317
)
304-
fit_bltd_switch = smf.logit(
318+
fit_bltd_switch_c = smf.logit(
305319
self.fit_bltd_switch_formula,
306-
data=self.df,
307-
).fit()
320+
data=self.df.loc[self.df.trtrand == 0],
321+
).fit(method="bfgs")
322+
fit_bltd_switch_t = smf.logit(
323+
self.fit_bltd_switch_formula,
324+
data=self.df.loc[self.df.trtrand == 1],
325+
).fit(method="bfgs")
308326

309-
preprocessed_data["pxo2"] = fit_bltd_switch.predict(preprocessed_data)
327+
preprocessed_data.loc[preprocessed_data["trtrand"] == 0, "pxo2"] = fit_bltd_switch_c.predict(
328+
self.df.loc[self.df.trtrand == 0]
329+
)
330+
preprocessed_data.loc[preprocessed_data["trtrand"] == 1, "pxo2"] = fit_bltd_switch_t.predict(
331+
self.df.loc[self.df.trtrand == 1]
332+
)
310333
if (preprocessed_data["pxo2"] == 1).any():
311334
raise ValueError(
312335
"Probability of switching given baseline and time-varying confounders (pxo2) cannot be one."

tests/estimation_tests/test_ipcw_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ def test_estimate_hazard_ratio(self):
3333
eligibility=None,
3434
)
3535
estimate, intervals = estimation_model.estimate_hazard_ratio()
36-
self.assertEqual(round(estimate["trtrand"], 3), 1.936)
36+
self.assertEqual(round(estimate["trtrand"], 3), 1.351)
3737

3838
def test_invalid_treatment_strategies(self):
3939
timesteps_per_intervention = 1

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,7 @@
1+
import os
12
import unittest
23
from pathlib import Path
3-
from statistics import StatisticsError
44
import scipy
5-
import os
65
import pandas as pd
76

87
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
@@ -11,11 +10,9 @@
1110
from causal_testing.testing.causal_test_case import CausalTestCase
1211
from causal_testing.testing.causal_test_suite import CausalTestSuite
1312
from causal_testing.testing.causal_test_adequacy import DAGAdequacy
14-
from causal_testing.testing.causal_test_outcome import NoEffect, Positive, SomeEffect
13+
from causal_testing.testing.causal_test_outcome import NoEffect, SomeEffect
1514
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
16-
from causal_testing.specification.variable import Input, Output, Meta
1715
from causal_testing.specification.scenario import Scenario
18-
from causal_testing.specification.causal_specification import CausalSpecification
1916
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2017

2118

@@ -145,11 +142,11 @@ def test_data_adequacy_group_by(self):
145142
adequacy_metric = DataAdequacy(causal_test_case, estimation_model, group_by="id")
146143
adequacy_metric.measure_adequacy()
147144
adequacy_dict = adequacy_metric.to_dict()
148-
self.assertEqual(round(adequacy_dict["kurtosis"]["trtrand"], 3), -0.336)
145+
self.assertEqual(round(adequacy_dict["kurtosis"]["trtrand"], 3), -0.857)
149146
adequacy_dict.pop("kurtosis")
150147
self.assertEqual(
151148
adequacy_dict,
152-
{"bootstrap_size": 100, "passing": 28, "successful": 95},
149+
{"bootstrap_size": 100, "passing": 32, "successful": 100},
153150
)
154151

155152
def test_dag_adequacy_dependent(self):

0 commit comments

Comments
 (0)