Skip to content

Commit 2749923

Browse files
committed
test __main__
1 parent 97f6f07 commit 2749923

File tree

2 files changed

+40
-25
lines changed

2 files changed

+40
-25
lines changed

causal_testing/__main__.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,26 @@ def main() -> None:
1717
# Setup logging
1818
setup_logging(args.verbose)
1919

20-
try:
21-
# Create paths object
22-
paths = CausalTestingPaths(
23-
dag_path=args.dag_path,
24-
data_paths=args.data_paths,
25-
test_config_path=args.test_config,
26-
output_path=args.output,
27-
)
28-
29-
# Create and setup framework
30-
framework = CausalTestingFramework(paths, ignore_cycles=args.ignore_cycles, query=args.query)
31-
framework.setup()
32-
33-
# Load and run tests
34-
framework.load_tests()
35-
results = framework.run_tests(silent=args.silent)
36-
37-
# Save results
38-
framework.save_results(results)
39-
40-
logging.info("Causal testing completed successfully.")
41-
42-
except Exception as e:
43-
logging.error("Error during causal testing: %s", str(e))
44-
raise
20+
# Create paths object
21+
paths = CausalTestingPaths(
22+
dag_path=args.dag_path,
23+
data_paths=args.data_paths,
24+
test_config_path=args.test_config,
25+
output_path=args.output,
26+
)
27+
28+
# Create and setup framework
29+
framework = CausalTestingFramework(paths, ignore_cycles=args.ignore_cycles, query=args.query)
30+
framework.setup()
31+
32+
# Load and run tests
33+
framework.load_tests()
34+
results = framework.run_tests(silent=args.silent)
35+
36+
# Save results
37+
framework.save_results(results)
38+
39+
logging.info("Causal testing completed successfully.")
4540

4641

4742
if __name__ == "__main__":

tests/main_tests/test_main.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import pandas as pd
55
from pathlib import Path
66
from causal_testing.main import CausalTestingPaths, CausalTestingFramework, parse_args
7+
from causal_testing.__main__ import main
8+
from unittest.mock import patch
79

810

911
class TestCausalTestingPaths(unittest.TestCase):
@@ -128,6 +130,24 @@ def test_ctf(self):
128130

129131
self.assertEqual(tests_passed, [True])
130132

133+
def test_parse_args(self):
134+
with unittest.mock.patch(
135+
"sys.argv",
136+
[
137+
"causal_testing",
138+
"--dag_path",
139+
str(self.dag_path),
140+
"--data_paths",
141+
str(self.data_paths[0]),
142+
"--test_config",
143+
str(self.test_config_path),
144+
"--output",
145+
str(self.output_path.parent / "main.json"),
146+
],
147+
):
148+
main()
149+
self.assertTrue((self.output_path.parent / "main.json").exists())
150+
131151
def tearDown(self):
132152
if self.output_path.parent.exists():
133153
shutil.rmtree(self.output_path.parent)

0 commit comments

Comments
 (0)