Skip to content

Functionality for test-specific queries #316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 1, 2025
25 changes: 22 additions & 3 deletions causal_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,36 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
if estimator_class is None:
raise ValueError(f"Unknown estimator: {test['estimator']}")

# Handle combined queries (global and test-specific)
test_query = test.get("query")
combined_query = None

if self.query and test_query:
combined_query = f"({self.query}) and ({test_query})"
logger.info(
f"Combining global query '{self.query}' with test-specific query "
f"'{test_query}' for test '{test['name']}'"
)
elif test_query:
combined_query = test_query
logger.info(f"Using test-specific query for '{test['name']}': {test_query}")
elif self.query:
combined_query = self.query
logger.info(f"Using global query for '{test['name']}': {self.query}")

filtered_df = self.data.query(combined_query) if combined_query else self.data

# Create the estimator with correct parameters
estimator = estimator_class(
base_test_case=base_test,
treatment_value=test.get("treatment_value"),
control_value=test.get("control_value"),
adjustment_set=test.get("adjustment_set", self.causal_specification.causal_dag.identification(base_test)),
df=self.data,
df=filtered_df,
effect_modifiers=None,
formula=test.get("formula"),
alpha=test.get("alpha", 0.05),
query="",
query=combined_query,
)

# Get effect type and create expected effect
Expand Down Expand Up @@ -424,7 +443,7 @@ def save_results(self, results: List[CausalTestResult]) -> None:
"effect": test_config.get("effect", "direct"),
"treatment_variable": test_config["treatment_variable"],
"expected_effect": test_config["expected_effect"],
"formula": test_config.get("formula"),
"formula": result.estimator.formula if hasattr(result.estimator, "formula") else None,
"alpha": test_config.get("alpha", 0.05),
"skip": test_config.get("skip", False),
"passed": test_passed,
Expand Down
61 changes: 61 additions & 0 deletions tests/main_tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,67 @@ def test_ctf(self):

self.assertEqual(tests_passed, [True])

def test_global_query(self):
    """Check that a framework-level (global) query is applied to the data.

    Builds a framework with the global query ``test_input > 0`` and asserts
    that both the framework's data and the data handed to the estimator only
    contain rows satisfying it, then that the scenario and causal
    specification can still be created.
    """
    query_framework = CausalTestingFramework(self.paths, query="test_input > 0")
    query_framework.setup()

    # The global query must keep at least one row, and every kept row must match it.
    self.assertGreater(len(query_framework.data), 0)
    self.assertTrue((query_framework.data["test_input"] > 0).all())

    with open(self.test_config_path, "r", encoding="utf-8") as f:
        test_configs = json.load(f)

    # Drop any test-specific query so that only the global query is in play.
    test_config = test_configs["tests"][0].copy()
    test_config.pop("query", None)

    base_test = query_framework.create_base_test(test_config)
    causal_test = query_framework.create_causal_test(test_config, base_test)

    self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())

    query_framework.create_variables()
    query_framework.create_scenario_and_specification()

    self.assertIsNotNone(query_framework.scenario)
    self.assertIsNotNone(query_framework.causal_specification)

def test_test_specific_query(self):
    """Check that a query given in a single test's config filters the estimator data.

    The framework is created without a global query; the test config carries
    ``test_input > 0``. The estimator's dataframe must be non-empty and
    contain only rows satisfying that query.
    """
    framework = CausalTestingFramework(self.paths)
    framework.setup()

    with open(self.test_config_path, "r", encoding="utf-8") as f:
        first_test = json.load(f)["tests"][0]

    # Work on a copy and set the test-specific query explicitly.
    test_config = first_test.copy()
    test_config["query"] = "test_input > 0"

    base_test = framework.create_base_test(test_config)
    causal_test = framework.create_causal_test(test_config, base_test)

    estimator_df = causal_test.estimator.df
    # assertGreater reports the actual row count if the filter removed everything.
    self.assertGreater(len(estimator_df), 0)
    self.assertTrue((estimator_df["test_input"] > 0).all())

def test_combined_queries(self):
    """Check that a global query and a test-specific query are both applied.

    The framework is created with the global query ``test_input > 0`` and the
    test config carries ``test_output > 0``. The estimator's dataframe must be
    non-empty and every row must satisfy both conditions.
    """
    global_framework = CausalTestingFramework(self.paths, query="test_input > 0")
    global_framework.setup()

    with open(self.test_config_path, "r", encoding="utf-8") as f:
        first_test = json.load(f)["tests"][0]

    # Work on a copy and add a test-specific query on a different column.
    test_config = first_test.copy()
    test_config["query"] = "test_output > 0"

    base_test = global_framework.create_base_test(test_config)
    causal_test = global_framework.create_causal_test(test_config, base_test)

    estimator_df = causal_test.estimator.df
    # assertGreater reports the actual row count if the filters removed everything.
    self.assertGreater(len(estimator_df), 0)
    self.assertTrue((estimator_df["test_input"] > 0).all())
    self.assertTrue((estimator_df["test_output"] > 0).all())

def test_parse_args(self):
with unittest.mock.patch(
"sys.argv",
Expand Down
2 changes: 2 additions & 0 deletions tests/resources/data/tests.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"estimator": "LinearRegressionEstimator",
"estimate_type": "coefficient",
"effect_modifiers": [],
"query": "test_input > 0",
"expected_effect": {"test_output": "NoEffect"},
"skip": false
},
Expand All @@ -14,6 +15,7 @@
"estimator": "LinearRegressionEstimator",
"estimate_type": "coefficient",
"effect_modifiers": [],
"query": "test_input <= 5",
"expected_effect": {"test_output": "NoEffect"},
"skip": true
}]
Expand Down