Skip to content

Commit 01589f6

Browse files
Revert changes to data checks in observational data collector
1 parent b084b9c commit 01589f6

File tree

4 files changed

+7
-14
lines changed

4 files changed

+7
-14
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ class DataCollector(ABC):
1818

1919
def __init__(self, scenario: Scenario):
2020
self.scenario = scenario
21-
self.data_checked = False
22-
2321
@abstractmethod
2422
def collect_data(self, **kwargs) -> pd.DataFrame:
2523
"""
@@ -158,5 +156,4 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
158156
for var_name, var in self.scenario.variables.items():
159157
if issubclass(var.datatype, Enum):
160158
scenario_execution_data_df[var_name] = [var.datatype(x) for x in scenario_execution_data_df[var_name]]
161-
self.data_checked = True
162-
self.data = scenario_execution_data_df
159+
return scenario_execution_data_df

causal_testing/testing/causal_test_case.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,15 @@ def get_treatment_value(self):
8181
"""Return the treatment value of the treatment variable in this causal test case."""
8282
return self.treatment_value
8383

84-
def execute_test(self, estimator: type(Estimator), data_collector: ObservationalDataCollector, causal_specification: CausalSpecification) -> CausalTestResult:
84+
def execute_test(self, estimator: type(Estimator), data_collector: ObservationalDataCollector) -> CausalTestResult:
8585
"""Execute a causal test case and return the causal test result.
8686
8787
:param estimator: A reference to an Estimator class.
8888
:param causal_test_case: The CausalTestCase object to be tested
8989
:return causal_test_result: A CausalTestResult for the executed causal test case.
9090
"""
91-
if not data_collector.data_checked:
92-
data_collector.collect_data()
9391
if estimator.df is None:
94-
estimator.df = data_collector.data
92+
estimator.df = data_collector.collect_data()
9593
treatment_variable = self.treatment_variable
9694
treatments = treatment_variable.name
9795
outcome_variable = self.outcome_variable

examples/poisson-line-process/example_poisson_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def causal_test_intensity_num_shapes(
119119
)
120120

121121
# 9. Execute the test
122-
causal_test_result = causal_test_case.execute_test(estimator, data_collector, causal_specification)
122+
causal_test_result = causal_test_case.execute_test(estimator, data_collector)
123123

124124
return causal_test_result
125125

tests/data_collection_tests/test_observational_data_collector.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,13 @@ def test_not_all_variables_in_data(self):
4444
def test_all_variables_in_data(self):
4545
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2})
4646
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
47-
observational_data_collector.collect_data(index_col=0)
48-
df = observational_data_collector.data
47+
df = observational_data_collector.collect_data(index_col=0)
4948
assert df.equals(self.observational_df), f"\n{df}\nwas not equal to\n{self.observational_df}"
5049

5150
def test_data_constraints(self):
5251
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2}, {self.X1.z3 > 2})
5352
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
54-
observational_data_collector.collect_data(index_col=0)
55-
df = observational_data_collector.data
53+
df = observational_data_collector.collect_data(index_col=0)
5654
expected = self.observational_df.loc[[2, 3]]
5755
assert df.equals(expected), f"\n{df}\nwas not equal to\n{expected}"
5856

@@ -64,7 +62,7 @@ def populate_m(data):
6462
scenario = Scenario({self.X1, meta})
6563
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df)
6664
observational_data_collector.collect_data()
67-
data = observational_data_collector.data
65+
data = observational_data_collector.collect_data()
6866
assert all((m == 2 * x1 for x1, m in zip(data["X1"], data["M"])))
6967

7068
def tearDown(self) -> None:

0 commit comments

Comments
 (0)