Skip to content

Commit 647927c

Browse files
committed
Cleaner
1 parent d1817b4 commit 647927c

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from abc import ABC, abstractmethod
3+
from enum import Enum
34

45
import pandas as pd
56
import z3
@@ -144,4 +145,7 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
144145
for meta in self.scenario.metas():
145146
meta.populate(execution_data_df)
146147
scenario_execution_data_df = self.filter_valid_data(execution_data_df)
148+
for vname, var in self.scenario.variables.items():
149+
if issubclass(var.datatype, Enum):
150+
scenario_execution_data_df[vname] = [var.datatype(x) for x in scenario_execution_data_df[vname]]
147151
return scenario_execution_data_df

causal_testing/testing/estimators.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ def _run_logistic_regression(self) -> RegressionResultsWrapper:
142142
cols += [x for x in self.adjustment_set if x not in cols]
143143
treatment_and_adjustments_cols = reduced_df[cols + ["Intercept"]]
144144
outcome_col = reduced_df[list(self.outcome)]
145+
for col in treatment_and_adjustments_cols:
146+
if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
147+
treatment_and_adjustments_cols = pd.get_dummies(treatment_and_adjustments_cols, columns=[col], drop_first=True)
145148
regression = sm.Logit(outcome_col, treatment_and_adjustments_cols)
146149
model = regression.fit()
147150
return model
@@ -166,6 +169,10 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
166169
x["1/" + t] = 1 / x[t]
167170
for a, b in self.product_terms:
168171
x[f"{a}*{b}"] = x[a] * x[b]
172+
173+
for col in x:
174+
if str(x.dtypes[col]) == "object":
175+
x = pd.get_dummies(x, columns=[col], drop_first=True)
169176
x = x[model.params.index]
170177

171178
y = model.predict(x)
@@ -360,6 +367,8 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
360367
"""
361368
model = self._run_linear_regression()
362369
self.model = model
370+
print(model.summary())
371+
363372

364373
x = pd.DataFrame()
365374
x[self.treatment[0]] = [self.treatment_values, self.control_values]
@@ -376,13 +385,14 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
376385
print(x)
377386
for col in x:
378387
if str(x.dtypes[col]) == "object":
379-
x[col] = [v.value for v in x[]]
380388
x = pd.get_dummies(x, columns=[col], drop_first=True)
381389
print("dummy")
382390
print(x)
383391
x = x[model.params.index]
384392

385393
y = model.get_prediction(x).summary_frame()
394+
395+
print("control", y.iloc[1], "treatment", y.iloc[0])
386396
return y.iloc[1], y.iloc[0]
387397

388398
def estimate_risk_ratio(self) -> tuple[float, list[float, float]]:
@@ -406,6 +416,7 @@ def estimate_ate_calculated(self) -> tuple[float, list[float, float]]:
406416
:return: The average treatment effect and the 95% Wald confidence intervals.
407417
"""
408418
control_outcome, treatment_outcome = self.estimate_control_treatment()
419+
assert False
409420
ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"]
410421
ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]
411422

@@ -461,8 +472,6 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
461472
cols += [x for x in self.adjustment_set if x not in cols]
462473
treatment_and_adjustments_cols = reduced_df[cols + ["Intercept"]]
463474
outcome_col = reduced_df[list(self.outcome)]
464-
print("train_data")
465-
print(treatment_and_adjustments_cols)
466475
for col in treatment_and_adjustments_cols:
467476
if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
468477
treatment_and_adjustments_cols = pd.get_dummies(treatment_and_adjustments_cols, columns=[col], drop_first=True)

0 commit comments

Comments
 (0)