Skip to content

Commit 41ea949

Browse files
committed
pytest
1 parent c20899d commit 41ea949

File tree

7 files changed

+103
-8
lines changed

7 files changed

+103
-8
lines changed

tests/resources/data/data.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
index,test_input,test_input_no_dist,test_output
2-
0,1,1,2
2+
0,1.0,1.0,2.0
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
index,test_input,test_input_no_dist,test_output,test_meta
2-
0,1,1,2,3
2+
0,1.0,1.0,2.0,3
File renamed without changes.
File renamed without changes.

tests/data/temporal_data.csv renamed to tests/resources/data/temporal_data.csv

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,8 @@ t,outcome,id,time
5959
1,0,11,2
6060
1,0,11,3
6161
1,0,11,4
62+
0,1,12,0
63+
1,1,12,1
64+
0,1,12,2
65+
1,0,12,3
66+
0,0,12,4

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,20 @@
33
from statistics import StatisticsError
44
import scipy
55
import os
6+
import pandas as pd
67

7-
from causal_testing.testing.estimators import LinearRegressionEstimator
8+
from causal_testing.testing.estimators import LinearRegressionEstimator, IPCWEstimator
89
from causal_testing.testing.base_test_case import BaseTestCase
910
from causal_testing.testing.causal_test_case import CausalTestCase
1011
from causal_testing.testing.causal_test_suite import CausalTestSuite
1112
from causal_testing.testing.causal_test_adequacy import DAGAdequacy
12-
from causal_testing.testing.causal_test_outcome import NoEffect, Positive
13+
from causal_testing.testing.causal_test_outcome import NoEffect, Positive, SomeEffect
1314
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
1415
from causal_testing.specification.variable import Input, Output, Meta
1516
from causal_testing.specification.scenario import Scenario
1617
from causal_testing.specification.causal_specification import CausalSpecification
18+
from causal_testing.specification.capabilities import TreatmentSequence
19+
from causal_testing.testing.causal_test_adequacy import DataAdequacy
1720

1821

1922
class TestCausalTestAdequacy(unittest.TestCase):
@@ -106,6 +109,48 @@ def test_data_adequacy_cateogorical(self):
106109
{"kurtosis": {"test_input_no_dist[T.b]": 0.0}, "bootstrap_size": 100, "passing": 100, "successful": 100},
107110
)
108111

112+
def test_data_adequacy_group_by(self):
    """Measure data adequacy of a temporal (hazard ratio) causal test with ``group_by="id"``.

    An ``IPCWEstimator`` is fitted on the temporal fixture data, the causal test is executed,
    and ``DataAdequacy`` is then measured with rows grouped by the ``id`` column. The resulting
    adequacy summary is compared against a fixed expected dictionary.
    """
    timesteps_per_intervention = 1
    # Control never applies the treatment ("t" = 0); treatment always applies it ("t" = 1).
    control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
    treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
    outcome = "outcome"
    # The same formula is reused for both switch models below.
    fit_bl_switch_formula = "xo_t_do ~ time"

    df = pd.read_csv("tests/resources/data/temporal_data.csv")
    # Boolean status column handed to the estimator: True wherever outcome == 1.
    df["ok"] = df["outcome"] == 1

    estimation_model = IPCWEstimator(
        df,
        timesteps_per_intervention,
        control_strategy,
        treatment_strategy,
        outcome,
        "ok",
        fit_bl_switch_formula=fit_bl_switch_formula,
        fit_bltd_switch_formula=fit_bl_switch_formula,
        eligibility=None,
    )
    base_test_case = BaseTestCase(
        treatment_variable=control_strategy,
        outcome_variable=outcome,
        effect="temporal",
    )
    causal_test_case = CausalTestCase(
        base_test_case=base_test_case,
        expected_causal_effect=SomeEffect(),
        control_value=control_strategy,
        treatment_value=treatment_strategy,
        estimate_type="hazard_ratio",
    )

    causal_test_result = causal_test_case.execute_test(estimation_model, None)
    # NOTE(review): group_by="id" presumably keeps each individual's rows together when the
    # data is resampled -- confirm against DataAdequacy.
    adequacy_metric = DataAdequacy(causal_test_case, estimation_model, group_by="id")
    adequacy_metric.measure_adequacy()
    causal_test_result.adequacy = adequacy_metric

    self.assertEqual(
        causal_test_result.adequacy.to_dict(),
        {"kurtosis": {"trtrand": 0.0}, "bootstrap_size": 100, "passing": 0, "successful": 95},
    )
153+
109154
def test_dag_adequacy_dependent(self):
110155
base_test_case = BaseTestCase(
111156
treatment_variable="test_input",

tests/testing_tests/test_estimators.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def load_nhefs_df():
3535
"""Get the NHEFS data from chapter 12 and put into a dataframe. NHEFS = National Health and Nutrition Examination
3636
Survey Data I Epidemiological Follow-up Study."""
3737

38-
nhefs_df = pd.read_csv("tests/data/nhefs.csv")
38+
nhefs_df = pd.read_csv("tests/resources/data/nhefs.csv")
3939
nhefs_df["one"] = 1
4040
nhefs_df["zero"] = 0
4141
edu_dummies = pd.get_dummies(nhefs_df.education, prefix="edu")
@@ -79,7 +79,7 @@ class TestLogisticRegressionEstimator(unittest.TestCase):
7979

8080
@classmethod
8181
def setUpClass(cls) -> None:
82-
cls.scarf_df = pd.read_csv("tests/data/scarf_data.csv")
82+
cls.scarf_df = pd.read_csv("tests/resources/data/scarf_data.csv")
8383

8484
# Yes, this probably shouldn't be in here, but it uses the scarf data so it makes more sense to put it
8585
# here than duplicating the scarf data for a single test
@@ -446,7 +446,7 @@ def setUpClass(cls) -> None:
446446
df = pd.DataFrame({"X1": np.random.uniform(-1000, 1000, 1000), "X2": np.random.uniform(-1000, 1000, 1000)})
447447
df["Y"] = 2 * df["X1"] - 3 * df["X2"] + 2 * df["X1"] * df["X2"] + 10
448448
cls.df = df
449-
cls.scarf_df = pd.read_csv("tests/data/scarf_data.csv")
449+
cls.scarf_df = pd.read_csv("tests/resources/data/scarf_data.csv")
450450

451451
def test_X1_effect(self):
452452
"""When we fix the value of X2 to 0, the effect of X1 on Y should become ~2 (because X2 terms are cancelled)."""
@@ -485,7 +485,7 @@ def test_estimate_hazard_ratio(self):
485485
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
486486
outcome = "outcome"
487487
fit_bl_switch_formula = "xo_t_do ~ time"
488-
df = pd.read_csv("tests/data/temporal_data.csv")
488+
df = pd.read_csv("tests/resources/data/temporal_data.csv")
489489
df["ok"] = df["outcome"] == 1
490490
estimation_model = IPCWEstimator(
491491
df,
@@ -500,3 +500,48 @@ def test_estimate_hazard_ratio(self):
500500
)
501501
estimate, intervals = estimation_model.estimate_hazard_ratio()
502502
self.assertEqual(estimate["trtrand"], 1.0)
503+
504+
def test_invalid_treatment_strategies(self):
505+
timesteps_per_intervention = 1
506+
control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
507+
treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
508+
outcome = "outcome"
509+
fit_bl_switch_formula = "xo_t_do ~ time"
510+
df = pd.read_csv("tests/resources/data/temporal_data.csv")
511+
df["t"] = (["1", "0"] * len(df))[: len(df)]
512+
df["ok"] = df["outcome"] == 1
513+
with self.assertRaises(ValueError):
514+
estimation_model = IPCWEstimator(
515+
df,
516+
timesteps_per_intervention,
517+
control_strategy,
518+
treatment_strategy,
519+
outcome,
520+
"ok",
521+
fit_bl_switch_formula=fit_bl_switch_formula,
522+
fit_bltd_switch_formula=fit_bl_switch_formula,
523+
eligibility=None,
524+
)
525+
526+
def test_invalid_fault_t_do(self):
    """``estimate_hazard_ratio`` must raise ``ValueError`` once ``fault_t_do`` is zeroed out.

    The estimator is built from the unmodified temporal fixture (construction is expected to
    succeed); afterwards its ``fault_t_do`` column is set to 0 for every row before estimating.
    """
    timesteps_per_intervention = 1
    control_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 0), ("t", 0), ("t", 0)])
    treatment_strategy = TreatmentSequence(timesteps_per_intervention, [("t", 1), ("t", 1), ("t", 1)])
    outcome = "outcome"
    fit_bl_switch_formula = "xo_t_do ~ time"

    df = pd.read_csv("tests/resources/data/temporal_data.csv")
    df["ok"] = df["outcome"] == 1

    estimation_model = IPCWEstimator(
        df,
        timesteps_per_intervention,
        control_strategy,
        treatment_strategy,
        outcome,
        "ok",
        fit_bl_switch_formula=fit_bl_switch_formula,
        fit_bltd_switch_formula=fit_bl_switch_formula,
        eligibility=None,
    )

    # Corrupt the estimator's working frame after construction.
    estimation_model.df["fault_t_do"] = 0

    # Only the raised error matters here, so the return value is not unpacked.
    with self.assertRaises(ValueError):
        estimation_model.estimate_hazard_ratio()

0 commit comments

Comments
 (0)