Skip to content

Functionality for test-specific queries #316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 1, 2025
15 changes: 12 additions & 3 deletions causal_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,26 @@
if estimator_class is None:
raise ValueError(f"Unknown estimator: {test['estimator']}")

# Filter the data if a test-specific query exists, otherwise use the original data
test_query = test.get("query", self.query)

if test_query:
logger.info(f"Applying test-specific query for '{test['name']}': {test_query}")
filtered_df = self.data.query(test_query)

Check warning on line 289 in causal_testing/main.py

View check run for this annotation

Codecov / codecov/patch

causal_testing/main.py#L288-L289

Added lines #L288 - L289 were not covered by tests
else:
filtered_df = self.data

# Create the estimator with correct parameters
estimator = estimator_class(
base_test_case=base_test,
treatment_value=test.get("treatment_value"),
control_value=test.get("control_value"),
adjustment_set=test.get("adjustment_set", self.causal_specification.causal_dag.identification(base_test)),
df=self.data,
df=filtered_df,
effect_modifiers=None,
formula=test.get("formula"),
alpha=test.get("alpha", 0.05),
query="",
query=test_query,
)

# Get effect type and create expected effect
Expand Down Expand Up @@ -424,7 +433,7 @@
"effect": test_config.get("effect", "direct"),
"treatment_variable": test_config["treatment_variable"],
"expected_effect": test_config["expected_effect"],
"formula": test_config.get("formula"),
"formula": result.estimator.formula if hasattr(result.estimator, "formula") else None,
"alpha": test_config.get("alpha", 0.05),
"skip": test_config.get("skip", False),
"passed": test_passed,
Expand Down