Skip to content

Commit 5210cf3

Browse files
committed
feat: functionality for test-specific queries
1 parent a5fa736 commit 5210cf3

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

causal_testing/main.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,17 +281,26 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
281281
if estimator_class is None:
282282
raise ValueError(f"Unknown estimator: {test['estimator']}")
283283

284+
# Filter the data if a test-specific query exists, otherwise use the original data
285+
test_query = test.get("query", self.query)
286+
287+
if test_query:
288+
logger.info(f"Applying test-specific query for '{test['name']}': {test_query}")
289+
filtered_df = self.data.query(test_query)
290+
else:
291+
filtered_df = self.data
292+
284293
# Create the estimator with correct parameters
285294
estimator = estimator_class(
286295
base_test_case=base_test,
287296
treatment_value=test.get("treatment_value"),
288297
control_value=test.get("control_value"),
289298
adjustment_set=test.get("adjustment_set", self.causal_specification.causal_dag.identification(base_test)),
290-
df=self.data,
299+
df=filtered_df,
291300
effect_modifiers=None,
292301
formula=test.get("formula"),
293302
alpha=test.get("alpha", 0.05),
294-
query="",
303+
query=test_query,
295304
)
296305

297306
# Get effect type and create expected effect

0 commit comments

Comments
 (0)