
Commit 97f6f07

test main
1 parent 7871296 commit 97f6f07

4 files changed: +130 -134 lines changed

causal_testing/main.py

Lines changed: 61 additions & 127 deletions
@@ -103,7 +103,6 @@ def setup(self) -> None:
         Set up the framework by loading DAG, runtime csv data, creating the scenario and causal specification.
 
         :raises: FileNotFoundError if required files are missing
-        :raises: Exception if setup process fails
         """
 
         logger.info("Setting up Causal Testing Framework...")
@@ -128,17 +127,11 @@ def setup(self) -> None:
     def load_dag(self) -> CausalDAG:
         """
         Load the causal DAG from the specified file path.
-
-        :raises: Exception if DAG loading fails
         """
         logger.info(f"Loading DAG from {self.paths.dag_path}")
-        try:
-            dag = CausalDAG(str(self.paths.dag_path), ignore_cycles=self.ignore_cycles)
-            logger.info(f"DAG loaded with {len(dag.graph.nodes)} nodes and {len(dag.graph.edges)} edges")
-            return dag
-        except Exception as e:
-            logger.error(f"Failed to load DAG: {str(e)}")
-            raise
+        dag = CausalDAG(str(self.paths.dag_path), ignore_cycles=self.ignore_cycles)
+        logger.info(f"DAG loaded with {len(dag.graph.nodes)} nodes and {len(dag.graph.edges)} edges")
+        return dag
 
     def _read_dataframe(self, data_path):
         if str(data_path).endswith(".csv"):
@@ -152,33 +145,22 @@ def load_data(self, query: Optional[str] = None) -> pd.DataFrame:
 
         :param query: Optional pandas query string to filter the loaded data
         :return: Combined pandas DataFrame containing all loaded and filtered data
-        :raises: Exception if data loading or query application fails
         """
         logger.info(f"Loading data from {len(self.paths.data_paths)} source(s)")
 
-        try:
-            dfs = [self._read_dataframe(data_path) for data_path in self.paths.data_paths]
-            data = pd.concat(dfs, axis=0, ignore_index=True)
-            logger.info(f"Initial data shape: {data.shape}")
-
-            if query:
-                try:
-                    logger.info(f"Attempting to apply query: '{query}'")
-                    data = data.query(query)
-                except Exception as e:
-                    logger.error(f"Failed to apply query '{query}': {str(e)}")
-                    raise
+        dfs = [self._read_dataframe(data_path) for data_path in self.paths.data_paths]
+        data = pd.concat(dfs, axis=0, ignore_index=True)
+        logger.info(f"Initial data shape: {data.shape}")
+
+        if query:
+            logger.info(f"Attempting to apply query: '{query}'")
+            data = data.query(query)
 
-            return data
-        except Exception as e:
-            logger.error(f"Failed to load data: {str(e)}")
-            raise
+        return data
 
     def create_variables(self) -> None:
         """
         Create variable objects from DAG nodes based on their connectivity.
-
-
         """
         for node_name, node_data in self.dag.graph.nodes(data=True):
             if node_name not in self.data.columns and not node_data.get("hidden", False):
@@ -195,63 +177,25 @@ def create_variables(self) -> None:
                 self.variables["outputs"][node_name] = Output(name=node_name, datatype=dtype)
 
     def create_scenario_and_specification(self) -> None:
-        """Create scenario and causal specification objects from loaded data.
-
-
-        :raises: ValueError if scenario constraints filter out all data points
-        """
+        """Create scenario and causal specification objects from loaded data."""
         # Create scenario
         all_variables = list(self.variables["inputs"].values()) + list(self.variables["outputs"].values())
         self.scenario = Scenario(variables=all_variables)
 
         # Set up treatment variables
         self.scenario.setup_treatment_variables()
 
-        # Apply scenario constraints to data
-        self.apply_scenario_constraints()
-
         # Create causal specification
         self.causal_specification = CausalSpecification(scenario=self.scenario, causal_dag=self.dag)
 
-    def apply_scenario_constraints(self) -> None:
-        """
-        Apply scenario constraints to the loaded data.
-
-        :raises: ValueError if all data points are filtered out by constraints
-        """
-        if not self.scenario.constraints:
-            logger.info("No scenario constraints to apply")
-            return
-
-        original_rows = len(self.data)
-
-        # Apply each constraint directly as a query string
-        for constraint in self.scenario.constraints:
-            self.data = self.data.query(str(constraint))
-            logger.debug(f"Applied constraint: {constraint}")
-
-        filtered_rows = len(self.data)
-        if filtered_rows < original_rows:
-            logger.info(f"Scenario constraints filtered data from {original_rows} to {filtered_rows} rows")
-
-        if filtered_rows == 0:
-            raise ValueError("Scenario constraints filtered out all data points. Check your constraints and data.")
-
     def load_tests(self) -> None:
         """
         Load and prepare test configurations from file.
-
-
-        :raises: Exception if test configuration loading fails
         """
         logger.info(f"Loading test configurations from {self.paths.test_config_path}")
 
-        try:
-            with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
-                test_configs = json.load(f)
-        except Exception as e:
-            logger.error(f"Failed to load test configurations: {str(e)}")
-            raise
+        with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
+            test_configs = json.load(f)
 
         self.test_cases = self.create_test_cases(test_configs)
 
@@ -400,63 +344,53 @@ def save_results(self, results: List[CausalTestResult]) -> None:
         """Save test results to JSON file in the expected format."""
         logger.info(f"Saving results to {self.paths.output_path}")
 
-        try:
-            # Load original test configs to preserve test metadata
-            with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
-                test_configs = json.load(f)
-
-            # Combine test configs with their results
-            json_results = []
-            for test_config, test_case, result in zip(test_configs["tests"], self.test_cases, results):
-                # Handle effect estimate - could be a Series or other format
-                effect_estimate = result.test_value.value
-                if isinstance(effect_estimate, pd.Series):
-                    effect_estimate = effect_estimate.to_dict()
-
-                # Handle confidence intervals - convert to list if needed
-                ci_low = result.ci_low()
-                ci_high = result.ci_high()
-                if isinstance(ci_low, pd.Series):
-                    ci_low = ci_low.tolist()
-                if isinstance(ci_high, pd.Series):
-                    ci_high = ci_high.tolist()
-
-                # Determine if test failed based on expected vs actual effect
-                test_passed = (
-                    test_case.expected_causal_effect.apply(result) if result.test_value.type != "Error" else False
-                )
-
-                output = {
-                    "name": test_config["name"],
-                    "estimate_type": test_config["estimate_type"],
-                    "effect": test_config.get("effect", "direct"),
-                    "treatment_variable": test_config["treatment_variable"],
-                    "expected_effect": test_config["expected_effect"],
-                    "formula": test_config.get("formula"),
-                    "alpha": test_config.get("alpha", 0.05),
-                    "skip": test_config.get("skip", False),
-                    "passed": test_passed,
-                    "result": {
-                        "treatment": result.estimator.base_test_case.treatment_variable.name,
-                        "outcome": result.estimator.base_test_case.outcome_variable.name,
-                        "adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
-                        "effect_measure": result.test_value.type,
-                        "effect_estimate": effect_estimate,
-                        "ci_low": ci_low,
-                        "ci_high": ci_high,
-                    },
-                }
-                json_results.append(output)
-
-            # Save to file
-            with open(self.paths.output_path, "w", encoding="utf-8") as f:
-                json.dump(json_results, f, indent=2)
-
-            logger.info("Results saved successfully")
-
-        except Exception as e:
-            logger.error(f"Failed to save results: {str(e)}")
-            raise
+        # Load original test configs to preserve test metadata
+        with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
+            test_configs = json.load(f)
+
+        # Combine test configs with their results
+        json_results = []
+        for test_config, test_case, result in zip(test_configs["tests"], self.test_cases, results):
+            # Handle effect estimate - could be a Series or other format
+            effect_estimate = result.test_value.value
+            if isinstance(effect_estimate, pd.Series):
+                effect_estimate = effect_estimate.to_dict()
+
+            # Handle confidence intervals - convert to list if needed
+            ci_low = result.ci_low()
+            ci_high = result.ci_high()
+
+            # Determine if test failed based on expected vs actual effect
+            test_passed = test_case.expected_causal_effect.apply(result) if result.test_value.type != "Error" else False
+
+            output = {
+                "name": test_config["name"],
+                "estimate_type": test_config["estimate_type"],
+                "effect": test_config.get("effect", "direct"),
+                "treatment_variable": test_config["treatment_variable"],
+                "expected_effect": test_config["expected_effect"],
+                "formula": test_config.get("formula"),
+                "alpha": test_config.get("alpha", 0.05),
+                "skip": test_config.get("skip", False),
+                "passed": test_passed,
+                "result": {
+                    "treatment": result.estimator.base_test_case.treatment_variable.name,
+                    "outcome": result.estimator.base_test_case.outcome_variable.name,
+                    "adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
+                    "effect_measure": result.test_value.type,
+                    "effect_estimate": effect_estimate,
+                    "ci_low": ci_low,
+                    "ci_high": ci_high,
+                },
+            }
+            json_results.append(output)
+
+        # Save to file
+        with open(self.paths.output_path, "w", encoding="utf-8") as f:
+            json.dump(json_results, f, indent=2)
+
+        logger.info("Results saved successfully")
+        return json_results
 
 
 def setup_logging(verbose: bool = False) -> None:
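
With the blanket try/except wrappers removed above, failures in setup, load_dag, load_data, load_tests and save_results now reach the caller unchanged instead of being logged and re-raised internally. Below is a minimal sketch of calling code under that behaviour; the file paths and the surrounding error handling are illustrative assumptions, not part of this commit.

from causal_testing.main import CausalTestingPaths, CausalTestingFramework

# Illustrative paths; substitute real project files.
paths = CausalTestingPaths(
    dag_path="dag.dot",
    data_paths=["runtime_data.csv"],
    test_config_path="causal_tests.json",
    output_path="results.json",
)

framework = CausalTestingFramework(paths)
try:
    # setup() still documents FileNotFoundError for missing inputs; with the
    # internal wrappers gone, the exception now propagates here directly.
    framework.setup()
    framework.load_tests()
except FileNotFoundError as err:
    print(f"Required input file is missing: {err}")
    raise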

tests/main_tests/test_main.py

Lines changed: 60 additions & 7 deletions
@@ -1,8 +1,9 @@
 import unittest
 import shutil
 import json
+import pandas as pd
 from pathlib import Path
-from causal_testing.main import CausalTestingPaths, CausalTestingFramework
+from causal_testing.main import CausalTestingPaths, CausalTestingFramework, parse_args
 
 
 class TestCausalTestingPaths(unittest.TestCase):
@@ -44,18 +45,70 @@ def setUp(self):
         self.data_paths = ["tests/resources/data/data.csv"]
         self.test_config_path = "tests/resources/data/tests.json"
         self.output_path = Path("results/results.json")
-
-    def test_ctf(self):
-        # Create paths object
-        paths = CausalTestingPaths(
+        self.paths = CausalTestingPaths(
             dag_path=self.dag_path,
             data_paths=self.data_paths,
             test_config_path=self.test_config_path,
             output_path=self.output_path,
         )
 
-        # Create and setup framework
-        framework = CausalTestingFramework(paths)
+    def test_load_data(self):
+        csv_framework = CausalTestingFramework(self.paths)
+        csv_df = csv_framework.load_data()
+
+        pqt_framework = CausalTestingFramework(
+            CausalTestingPaths(
+                dag_path=self.dag_path,
+                data_paths=["tests/resources/data/data.pqt"],
+                test_config_path=self.test_config_path,
+                output_path=self.output_path,
+            )
+        )
+        pqt_df = pqt_framework.load_data()
+        pd.testing.assert_frame_equal(csv_df, pqt_df)
+
+    def test_load_data_invalid(self):
+        framework = CausalTestingFramework(
+            CausalTestingPaths(
+                dag_path=self.dag_path,
+                data_paths=[self.dag_path],
+                test_config_path=self.test_config_path,
+                output_path=self.output_path,
+            )
+        )
+        with self.assertRaises(ValueError):
+            framework.load_data()
+
+    def test_load_data_query(self):
+        framework = CausalTestingFramework(self.paths)
+        self.assertFalse((framework.load_data()["test_input"] > 4).all())
+        self.assertTrue((framework.load_data("test_input > 4")["test_input"] > 4).all())
+
+    def test_load_dag_missing_node(self):
+        framework = CausalTestingFramework(self.paths)
+        framework.setup()
+        framework.dag.graph.add_node("missing")
+        with self.assertRaises(ValueError):
+            framework.create_variables()
+
+    def test_create_base_test_case_missing_treatment(self):
+        framework = CausalTestingFramework(self.paths)
+        framework.setup()
+        with self.assertRaises(KeyError) as e:
+            framework.create_base_test(
+                {"treatment_variable": "missing", "expected_effect": {"test_outcome": "NoEffect"}}
+            )
+        self.assertEqual("\"Treatment variable 'missing' not found in inputs or outputs\"", str(e.exception))
+
+    def test_create_base_test_case_missing_outcome(self):
+        framework = CausalTestingFramework(self.paths)
+        framework.setup()
+        with self.assertRaises(KeyError) as e:
+            framework.create_base_test({"treatment_variable": "test_input", "expected_effect": {"missing": "NoEffect"}})
+        self.assertEqual("\"Outcome variable 'missing' not found in inputs or outputs\"", str(e.exception))
+
+    def test_ctf(self):
+        framework = CausalTestingFramework(self.paths)
         framework.setup()
 
         # Load and run tests

tests/resources/data/data.pqt

3.81 KB
Binary file not shown.
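
The new data.pqt fixture is the Parquet twin of data.csv; test_load_data above asserts that both load to identical frames. The commit does not show how the file was produced, but a sketch like the following would regenerate it with pandas, assuming a Parquet engine such as pyarrow is installed:

import pandas as pd

# Rebuild the Parquet copy of the CSV fixture so that
# pd.testing.assert_frame_equal(csv_df, pqt_df) in test_load_data holds.
df = pd.read_csv("tests/resources/data/data.csv")
df.to_parquet("tests/resources/data/data.pqt", index=False)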

tests/resources/data/tests.json

Lines changed: 9 additions & 0 deletions
@@ -7,5 +7,14 @@
         "effect_modifiers": [],
         "expected_effect": {"test_output": "NoEffect"},
         "skip": false
+    },
+    {
+        "name": "test2",
+        "treatment_variable": "test_input",
+        "estimator": "LinearRegressionEstimator",
+        "estimate_type": "coefficient",
+        "effect_modifiers": [],
+        "expected_effect": {"test_output": "NoEffect"},
+        "skip": true
     }]
 }
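
Per the save_results changes above, each entry written to results.json combines these test-config fields with the estimator's output. The sketch below shows the shape of one such entry for the new test2 config; every value is an illustrative placeholder, not output from an actual run.

# Shape of a single results.json entry as assembled by save_results.
# All values are illustrative placeholders.
example_entry = {
    "name": "test2",
    "estimate_type": "coefficient",
    "effect": "direct",
    "treatment_variable": "test_input",
    "expected_effect": {"test_output": "NoEffect"},
    "formula": None,
    "alpha": 0.05,
    "skip": True,
    "passed": True,
    "result": {
        "treatment": "test_input",
        "outcome": "test_output",
        "adjustment_set": [],
        "effect_measure": "coefficient",
        "effect_estimate": 0.0,
        "ci_low": -0.1,
        "ci_high": 0.1,
    },
}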
