Skip to content

Commit 150f942

Browse files
committed
changes
1 parent 99d0eb6 commit 150f942

File tree

2 files changed

+81
-170
lines changed

2 files changed

+81
-170
lines changed
Lines changed: 78 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -1,191 +1,100 @@
1-
import os
2-
import logging
1+
"""
2+
Causal testing framework example script.
33
4-
import pandas as pd
4+
This example demonstrates the use of a custom EmpiricalMeanEstimator
5+
for causal testing and runs causal tests using the framework.
6+
"""
57

6-
from causal_testing.specification.causal_dag import CausalDAG
7-
from causal_testing.specification.scenario import Scenario
8-
from causal_testing.specification.variable import Input, Output
9-
from causal_testing.specification.causal_specification import CausalSpecification
10-
from causal_testing.testing.causal_test_case import CausalTestCase
11-
from causal_testing.testing.causal_effect import ExactValue, Positive
12-
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
13-
from causal_testing.estimation.abstract_estimator import Estimator
14-
from causal_testing.testing.base_test_case import BaseTestCase
8+
import os
9+
from typing import Tuple, Optional, Any
1510

11+
from causal_testing.main import (
12+
CausalTestingFramework,
13+
CausalTestingPaths,
14+
setup_logging
15+
)
16+
from causal_testing.estimation.abstract_estimator import Estimator
1617

17-
logger = logging.getLogger(__name__)
18-
logging.basicConfig(level=logging.DEBUG, format="%(message)s")
18+
setup_logging(verbose=True)
1919

2020

2121
class EmpiricalMeanEstimator(Estimator):
22+
"""
23+
Estimator that computes treatment effects using empirical means.
24+
25+
This estimator calculates the Average Treatment Effect (ATE) and risk ratio
26+
by directly comparing the means of treatment and control groups.
27+
"""
28+
2229
def add_modelling_assumptions(self):
30+
"""Add modeling assumptions for this estimator."""
31+
self.modelling_assumptions += (
32+
"The data must contain runs with the exact configuration of interest."
33+
)
34+
35+
def estimate_ate(self) -> Tuple[float, Optional[Any]]:
2336
"""
24-
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
25-
must hold if the resulting causal inference is to be considered valid.
26-
"""
27-
self.modelling_assumptions += "The data must contain runs with the exact configuration of interest."
37+
Estimate the Average Treatment Effect.
2838
29-
def estimate_ate(self) -> float:
30-
"""Estimate the outcomes under control and treatment.
31-
:return: The empirical average treatment effect.
39+
Returns:
40+
Tuple containing the ATE estimate and optional additional data.
3241
"""
33-
control_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.control_value)[
34-
self.base_test_case.outcome_variable.name
35-
].dropna()
36-
treatment_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.treatment_value)[
37-
self.base_test_case.outcome_variable.name
38-
].dropna()
42+
control_results = self.df.where(
43+
self.df[self.base_test_case.treatment_variable.name] == self.control_value
44+
)[self.base_test_case.outcome_variable.name].dropna()
45+
46+
treatment_results = self.df.where(
47+
self.df[self.base_test_case.treatment_variable.name] == self.treatment_value
48+
)[self.base_test_case.outcome_variable.name].dropna()
49+
3950
return treatment_results.mean() - control_results.mean(), None
4051

41-
def estimate_risk_ratio(self) -> float:
42-
"""Estimate the outcomes under control and treatment.
43-
:return: The empirical average treatment effect.
52+
def estimate_risk_ratio(self) -> Tuple[float, Optional[Any]]:
4453
"""
45-
control_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.control_value)[
46-
self.base_test_case.outcome_variable.name
47-
].dropna()
48-
treatment_results = self.df.where(self.df[self.base_test_case.treatment_variable.name] == self.treatment_value)[
49-
self.base_test_case.outcome_variable.name
50-
].dropna()
54+
Estimate the risk ratio.
55+
56+
Returns:
57+
Tuple containing the risk ratio estimate and optional additional data.
58+
"""
59+
control_results = self.df.where(
60+
self.df[self.base_test_case.treatment_variable.name] == self.control_value
61+
)[self.base_test_case.outcome_variable.name].dropna()
62+
63+
treatment_results = self.df.where(
64+
self.df[self.base_test_case.treatment_variable.name] == self.treatment_value
65+
)[self.base_test_case.outcome_variable.name].dropna()
66+
5167
return treatment_results.mean() / control_results.mean(), None
5268

5369

54-
# 1. Read in the Causal DAG
5570
ROOT = os.path.realpath(os.path.dirname(__file__))
56-
causal_dag = CausalDAG(f"{ROOT}/dag.dot")
57-
58-
# 2. Create variables
59-
width = Input("width", float)
60-
height = Input("height", float)
61-
intensity = Input("intensity", float)
62-
63-
num_lines_abs = Output("num_lines_abs", float)
64-
num_lines_unit = Output("num_lines_unit", float)
65-
num_shapes_abs = Output("num_shapes_abs", float)
66-
num_shapes_unit = Output("num_shapes_unit", float)
67-
68-
# 3. Create scenario
69-
scenario = Scenario(
70-
variables={
71-
width,
72-
height,
73-
intensity,
74-
num_lines_abs,
75-
num_lines_unit,
76-
num_shapes_abs,
77-
num_shapes_unit,
78-
}
79-
)
8071

81-
# 4. Construct a causal specification from the scenario and causal DAG
82-
causal_specification = CausalSpecification(scenario, causal_dag)
83-
84-
observational_data_path = f"{ROOT}/data/random/data_random_1000.csv"
85-
86-
87-
def test_poisson_intensity_num_shapes(save=False):
88-
intensity_num_shapes_results = []
89-
base_test_case = BaseTestCase(treatment_variable=intensity, outcome_variable=num_shapes_unit)
90-
observational_df = pd.read_csv(observational_data_path, index_col=0).astype(float)
91-
causal_test_cases = [
92-
(
93-
CausalTestCase(
94-
base_test_case=base_test_case,
95-
expected_causal_effect=ExactValue(4, atol=0.5),
96-
estimate_type="risk_ratio",
97-
estimator=EmpiricalMeanEstimator(
98-
base_test_case=base_test_case,
99-
treatment_value=treatment_value,
100-
control_value=control_value,
101-
adjustment_set=causal_specification.causal_dag.identification(base_test_case),
102-
df=pd.read_csv(f"{ROOT}/data/smt_100/data_smt_wh{wh}_100.csv", index_col=0).astype(float),
103-
effect_modifiers=None,
104-
alpha=0.05,
105-
query="",
106-
),
107-
),
108-
CausalTestCase(
109-
base_test_case=base_test_case,
110-
expected_causal_effect=ExactValue(4, atol=0.5),
111-
estimate_type="risk_ratio",
112-
estimator=LinearRegressionEstimator(
113-
base_test_case=base_test_case,
114-
treatment_value=treatment_value,
115-
control_value=control_value,
116-
adjustment_set=causal_specification.causal_dag.identification(base_test_case),
117-
df=observational_df,
118-
effect_modifiers=None,
119-
formula="num_shapes_unit ~ I(intensity ** 2) + intensity - 1",
120-
alpha=0.05,
121-
query="",
122-
),
123-
),
124-
)
125-
for control_value, treatment_value in [(1, 2), (2, 4), (4, 8), (8, 16)]
126-
for wh in range(1, 11)
127-
]
128-
129-
test_results = [(smt.execute_test(), observational.execute_test()) for smt, observational in causal_test_cases]
130-
131-
intensity_num_shapes_results += [
132-
{
133-
"width": obs_causal_test_result.estimator.control_value,
134-
"height": obs_causal_test_result.estimator.treatment_value,
135-
"control": obs_causal_test_result.estimator.control_value,
136-
"treatment": obs_causal_test_result.estimator.treatment_value,
137-
"smt_risk_ratio": smt_causal_test_result.test_value.value,
138-
"obs_risk_ratio": obs_causal_test_result.test_value.value[0],
139-
}
140-
for smt_causal_test_result, obs_causal_test_result in test_results
141-
]
142-
intensity_num_shapes_results = pd.DataFrame(intensity_num_shapes_results)
143-
if save:
144-
intensity_num_shapes_results.to_csv("intensity_num_shapes_results_random_1000.csv")
145-
logger.info("%s", intensity_num_shapes_results)
146-
147-
148-
def test_poisson_width_num_shapes(save=False):
149-
base_test_case = BaseTestCase(treatment_variable=width, outcome_variable=num_shapes_unit)
150-
df = pd.read_csv(observational_data_path, index_col=0).astype(float)
151-
causal_test_cases = [
152-
CausalTestCase(
153-
base_test_case=base_test_case,
154-
expected_causal_effect=Positive(),
155-
estimate_type="ate_calculated",
156-
effect_modifier_configuration={"intensity": i},
157-
estimator=LinearRegressionEstimator(
158-
base_test_case=base_test_case,
159-
treatment_value=w + 1.0,
160-
control_value=float(w),
161-
adjustment_set=causal_specification.causal_dag.identification(base_test_case),
162-
df=df,
163-
effect_modifiers={"intensity": i},
164-
formula="num_shapes_unit ~ width + I(intensity ** 2)+I(width ** -1)+intensity-1",
165-
alpha=0.05,
166-
),
167-
)
168-
for i in range(1, 17)
169-
for w in range(1, 10)
170-
]
171-
test_results = [test.execute_test() for test in causal_test_cases]
172-
width_num_shapes_results = [
173-
{
174-
"control": causal_test_result.estimator.control_value,
175-
"treatment": causal_test_result.estimator.treatment_value,
176-
"intensity": causal_test_result.effect_modifier_configuration["intensity"],
177-
"ate": causal_test_result.test_value.value[0],
178-
"ci_low": causal_test_result.confidence_intervals[0][0],
179-
"ci_high": causal_test_result.confidence_intervals[1][0],
180-
}
181-
for causal_test_result in test_results
182-
]
183-
width_num_shapes_results = pd.DataFrame(width_num_shapes_results)
184-
if save:
185-
width_num_shapes_results.to_csv("width_num_shapes_results_random_1000.csv")
186-
logger.info("%s", width_num_shapes_results)
72+
73+
def run_causal_tests() -> None:
74+
"""
75+
Run causal tests using the framework.
76+
77+
Sets up paths, initialises the framework, loads tests, runs them,
78+
and saves the results.
79+
"""
80+
dag_path = os.path.join(ROOT, "dag.dot")
81+
data_paths = [os.path.join(ROOT, "data/random/data_random_1000.csv")]
82+
test_config_path = os.path.join(ROOT, "causal_tests.json")
83+
output_path = os.path.join(ROOT, "causal_test_results.json")
84+
85+
paths = CausalTestingPaths(
86+
dag_path=dag_path,
87+
data_paths=data_paths,
88+
test_config_path=test_config_path,
89+
output_path=output_path
90+
)
91+
92+
framework = CausalTestingFramework(paths=paths, ignore_cycles=False)
93+
framework.setup()
94+
framework.load_tests()
95+
results = framework.run_tests()
96+
framework.save_results(results)
18797

18898

18999
if __name__ == "__main__":
190-
test_poisson_intensity_num_shapes(save=False)
191-
test_poisson_width_num_shapes(save=True)
100+
run_causal_tests()

examples/poisson-line-process/poisson_line_process.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,9 @@
797797
}
798798
],
799799
"source": [
800-
"width_results_df = pd.DataFrame(width_results)\n",
800+
"# Convert the results to a dataframe and print the first 10 rows\n",
801+
"\n",
802+
"width_results_df = pd.DataFrame(width_results) \n",
801803
"\n",
802804
"width_results_df.head(10)"
803805
]

0 commit comments

Comments
 (0)