Skip to content

Commit e676dc2

Browse files
committed
Some frontend improvements
1 parent c3b5ccb commit e676dc2

File tree

3 files changed

+32
-15
lines changed

3 files changed

+32
-15
lines changed

causal_testing/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def main() -> None:
3232

3333
# Load and run tests
3434
framework.load_tests()
35-
results = framework.run_tests()
35+
results = framework.run_tests(silent=args.silent)
3636

3737
# Save results
3838
framework.save_results(results)

causal_testing/main.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from dataclasses import dataclass
77
from pathlib import Path
88
from typing import Dict, Any, Optional, List, Union, Sequence
9+
from tqdm import tqdm
910

1011
import pandas as pd
1112

@@ -16,11 +17,12 @@
1617
from .testing.causal_test_case import CausalTestCase
1718
from .testing.base_test_case import BaseTestCase
1819
from .testing.causal_test_outcome import NoEffect, SomeEffect, Positive, Negative
19-
from .testing.causal_test_result import CausalTestResult
20+
from .testing.causal_test_result import CausalTestResult, TestValue
2021
from .estimation.linear_regression_estimator import LinearRegressionEstimator
2122
from .estimation.logistic_regression_estimator import LogisticRegressionEstimator
2223

2324
logger = logging.getLogger(__name__)
25+
logger.setLevel(logging.ERROR)
2426

2527

2628
@dataclass
@@ -338,8 +340,6 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
338340
if estimator_class is None:
339341
raise ValueError(f"Unknown estimator: {test['estimator']}")
340342

341-
print(test)
342-
343343
# Create the estimator with correct parameters
344344
estimator = estimator_class(
345345
base_test_case=base_test,
@@ -366,7 +366,7 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
366366
estimator=estimator,
367367
)
368368

369-
def run_tests(self) -> List[CausalTestResult]:
369+
def run_tests(self, silent=False) -> List[CausalTestResult]:
370370
"""
371371
Run all test cases and return their results.
372372
@@ -380,14 +380,21 @@ def run_tests(self) -> List[CausalTestResult]:
380380
raise ValueError("No tests loaded. Call load_tests() first.")
381381

382382
results = []
383-
for test_case in self.test_cases:
383+
for test_case in tqdm(self.test_cases):
384384
try:
385385
result = test_case.execute_test()
386386
results.append(result)
387387
logger.info(f"Test completed: {test_case}")
388388
except Exception as e:
389-
logger.error(f"Error running test {test_case}: {str(e)}")
390-
raise
389+
if silent:
390+
logger.error(f"Error running test {test_case}: {str(e)}")
391+
raise
392+
result = CausalTestResult(
393+
estimator=test_case.estimator,
394+
test_value=TestValue("Error", str(e)),
395+
)
396+
results.append(result)
397+
logger.info(f"Test errored: {test_case}")
391398

392399
return results
393400

@@ -463,12 +470,19 @@ def setup_logging(verbose: bool = False) -> None:
463470
def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
464471
"""Parse command line arguments."""
465472
parser = argparse.ArgumentParser(description="Causal Testing Framework")
466-
parser.add_argument("--dag_path", help="Path to the DAG file (.dot)", required=True)
467-
parser.add_argument("--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
468-
parser.add_argument("--test_config", help="Path to test configuration file (.json)", required=True)
469-
parser.add_argument("--output", help="Path for output file (.json)", required=True)
470-
parser.add_argument("--verbose", help="Enable verbose logging", action="store_true")
471-
parser.add_argument("--ignore-cycles", help="Ignore cycles in DAG", action="store_true")
472-
parser.add_argument("--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
473+
parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
474+
parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
475+
parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
476+
parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
477+
parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
478+
parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
479+
parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
480+
parser.add_argument(
481+
"-s",
482+
"--silent",
483+
action="store_true",
484+
help="Do not crash on error. If set to true, errors are recorded as test results.",
485+
default=False,
486+
)
473487

474488
return parser.parse_args(args)

causal_testing/specification/causal_dag.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ class CausalDAG(nx.DiGraph):
132132

133133
def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
134134
super().__init__(**attr)
135+
self.ignore_cycles = ignore_cycles
135136
if dot_path:
136137
with open(dot_path, "r", encoding="utf-8") as file:
137138
dot_content = file.read().replace("\n", "")
@@ -556,6 +557,8 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
556557
:return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
557558
estimate as opposed to a purely associational estimate.
558559
"""
560+
if self.ignore_cycles:
561+
return self.graph.predecessors(base_test_case.treatment_variable.name)
559562
minimal_adjustment_sets = []
560563
if base_test_case.effect == "total":
561564
minimal_adjustment_sets = self.enumerate_minimal_adjustment_sets(

0 commit comments

Comments
 (0)