Skip to content

Commit 31ca261

Browse files
committed
Change branch
1 parent c17c46f commit 31ca261

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

causal_testing/main.py

Lines changed: 23 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -138,6 +138,13 @@ def load_dag(self) -> CausalDAG:
138138
logger.error(f"Failed to load DAG: {str(e)}")
139139
raise
140140

141+
def _read_dataframe(self, data_path):
142+
if str(data_path).endswith(".csv"):
143+
return pd.read_csv(data_path)
144+
if str(data_path).endswith(".pqt"):
145+
return pd.read_parquet(data_path)
146+
raise ValueError(f"Invalid file type {data_path}. Can only read CSV (.csv) or parquet (.pqt) files.")
147+
141148
def load_data(self, query: Optional[str] = None) -> pd.DataFrame:
142149
"""Load and combine all data sources with optional filtering.
143150
@@ -148,7 +155,7 @@ def load_data(self, query: Optional[str] = None) -> pd.DataFrame:
148155
logger.info(f"Loading data from {len(self.paths.data_paths)} source(s)")
149156

150157
try:
151-
dfs = [pd.read_csv(data_path) for data_path in self.paths.data_paths]
158+
dfs = [self._read_dataframe(data_path) for data_path in self.paths.data_paths]
152159
data = pd.concat(dfs, axis=0, ignore_index=True)
153160
logger.info(f"Initial data shape: {data.shape}")
154161

@@ -171,20 +178,19 @@ def create_variables(self) -> None:
171178
172179
173180
"""
174-
for node in self.dag.graph.nodes():
175-
dtype = self.data[node].dtype.type if node in self.data.columns else str
181+
for node_name, node_data in self.dag.graph.nodes(data=True):
182+
if node_name not in self.data.columns and not node_data.get("hidden", False):
183+
raise ValueError(f"Node {node_name} missing from data. Should it be marked as hidden?")
176184

177-
# If node has no incoming edges, it's an input
178-
if self.dag.graph.in_degree(node) == 0:
179-
self.variables["inputs"][node] = Input(name=node, datatype=dtype)
185+
dtype = self.data.dtypes.get(node_name)
180186

181-
# If node has outgoing edges, it can be an input
182-
if self.dag.graph.out_degree(node) > 0:
183-
self.variables["inputs"][node] = Input(name=node, datatype=dtype)
187+
# If node has no incoming edges, it's an input
188+
if self.dag.graph.in_degree(node_name) == 0:
189+
self.variables["inputs"][node_name] = Input(name=node_name, datatype=dtype)
184190

185-
# If node has incoming edges, it can be an output
186-
if self.dag.graph.in_degree(node) > 0:
187-
self.variables["outputs"][node] = Output(name=node, datatype=dtype)
191+
# Otherwise it's an output
192+
if self.dag.graph.in_degree(node_name) > 0:
193+
self.variables["outputs"][node_name] = Output(name=node_name, datatype=dtype)
188194

189195
def create_scenario_and_specification(self) -> None:
190196
"""Create scenario and causal specification objects from loaded data.
@@ -259,7 +265,7 @@ def create_base_test(self, test: dict) -> BaseTestCase:
259265
:return: BaseTestCase object
260266
:raises: KeyError if required variables are not found in inputs or outputs
261267
"""
262-
treatment_name = test["mutations"][0]
268+
treatment_name = test["treatment_variable"]
263269
outcome_name = next(iter(test["expected_effect"].keys()))
264270

265271
# Look for treatment variable in both inputs and outputs
@@ -333,12 +339,11 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
333339
raise ValueError(f"Unknown estimator: {test['estimator']}")
334340

335341
# Create the estimator with correct parameters
336-
adjustment_set = self.causal_specification.causal_dag.identification(base_test)
337342
estimator = estimator_class(
338343
base_test_case=base_test,
339-
treatment_value=1.0, # hardcode these for now
340-
control_value=0.0,
341-
adjustment_set=adjustment_set,
344+
treatment_value=test.get("treatment_value"),
345+
control_value=test.get("control_value"),
346+
adjustment_set=test.get("adjustment_set", self.causal_specification.causal_dag.identification(base_test)),
342347
df=self.data,
343348
effect_modifiers=None,
344349
formula=test.get("formula"),
@@ -416,7 +421,7 @@ def save_results(self, results: List[CausalTestResult]) -> None:
416421
"name": test_config["name"],
417422
"estimate_type": test_config["estimate_type"],
418423
"effect": test_config.get("effect", "direct"),
419-
"mutations": test_config["mutations"],
424+
"treatment_variable": test_config["treatment_variable"],
420425
"expected_effect": test_config["expected_effect"],
421426
"formula": test_config.get("formula"),
422427
"alpha": test_config.get("alpha", 0.05),

0 commit comments

Comments
 (0)