Skip to content

Commit 7611771

Browse files
committed
Pushing so others can see - VERY BROKEN
1 parent bc13d84 commit 7611771

File tree

4 files changed

+36
-41
lines changed

4 files changed

+36
-41
lines changed

causal_testing/json_front/json_class.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self, output_path: str, output_overwrite: bool = False):
5656
self.output_path = Path(output_path)
5757
self.check_file_exists(self.output_path, output_overwrite)
5858

59-
def set_paths(self, json_path: str, dag_path: str, data_paths: str):
59+
def set_paths(self, json_path: str, dag_path: str, data_paths: str=[]):
6060
"""
6161
Takes a path of the directory containing all scenario specific files and creates individual paths for each file
6262
:param json_path: string path representation to .json file containing test specifications
@@ -73,7 +73,12 @@ def setup(self, scenario: Scenario):
7373
self.causal_specification = CausalSpecification(
7474
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
7575
)
76-
self._json_parse()
76+
# Parse the JSON test plan
77+
with open(self.input_paths.json_path, encoding="utf-8") as f:
78+
self.test_plan = json.load(f)
79+
# Populate the data
80+
if self.input_paths.data_paths:
81+
self.data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
7782
self._populate_metas()
7883

7984
def _create_abstract_test_case(self, test, mutates, effects):
@@ -144,6 +149,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
144149
+ "==============\n"
145150
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
146151
)
152+
print(msg)
147153
else:
148154
abstract_test = self._create_abstract_test_case(test, mutates, effects)
149155
concrete_tests, _ = abstract_test.generate_concrete_tests(5, 0.05)
@@ -198,15 +204,6 @@ def _execute_tests(self, concrete_tests, test, f_flag):
198204
failures += 1
199205
return failures, details
200206

201-
def _json_parse(self):
202-
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
203-
with open(self.input_paths.json_path, encoding="utf-8") as f:
204-
self.test_plan = json.load(f)
205-
for data_file in self.input_paths.data_paths:
206-
df = pd.read_csv(data_file, header=0)
207-
self.data.append(df)
208-
self.data = pd.concat(self.data)
209-
210207
def _populate_metas(self):
211208
"""
212209
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
@@ -236,7 +233,7 @@ def _execute_test_case(
236233

237234
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
238235

239-
if causal_test_result.ci_low() and causal_test_result.ci_high():
236+
if causal_test_result.ci_low() is not None and causal_test_result.ci_high() is not None:
240237
result_string = (
241238
f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < "
242239
f"{causal_test_result.ci_high()}"
@@ -351,7 +348,6 @@ def get_args(test_args=None) -> argparse.Namespace:
351348
parser.add_argument(
352349
"--data_path",
353350
help="Specify path to file containing runtime data",
354-
required=True,
355351
nargs="+",
356352
)
357353
parser.add_argument(

causal_testing/testing/causal_test_outcome.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ class SomeEffect(CausalTestOutcome):
2727

2828
def apply(self, res: CausalTestResult) -> bool:
2929
if res.test_value.type in {"ate", "coefficient"}:
30-
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
30+
return any([0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())])
31+
# return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
3132
if res.test_value.type == "risk_ratio":
3233
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
3334
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
@@ -36,10 +37,11 @@ def apply(self, res: CausalTestResult) -> bool:
3637
class NoEffect(CausalTestOutcome):
3738
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
3839

39-
def apply(self, res: CausalTestResult) -> bool:
40+
def apply(self, res: CausalTestResult, threshold: float = 1e-10) -> bool:
4041
print("RESULT", res)
4142
if res.test_value.type in {"ate", "coefficient"}:
42-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
43+
return all([ci_low < 0< ci_high for ci_low, ci_high in zip(res.ci_low(), res.ci_high())]) or all([abs(v) < 1e-10 for v in res.test_value.value])
44+
# return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
4345
if res.test_value.type == "risk_ratio":
4446
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
4547
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

causal_testing/testing/causal_test_result.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,20 @@ def __init__(
4343
self.effect_modifier_configuration = {}
4444

4545
def __str__(self):
46+
def push(s, inc=" "):
47+
return inc + str(s).replace("\n", "\n"+inc)
4648
base_str = (
4749
f"Causal Test Result\n==============\n"
4850
f"Treatment: {self.estimator.treatment}\n"
4951
f"Control value: {self.estimator.control_value}\n"
5052
f"Treatment value: {self.estimator.treatment_value}\n"
5153
f"Outcome: {self.estimator.outcome}\n"
5254
f"Adjustment set: {self.adjustment_set}\n"
53-
f"{self.test_value.type}: {self.test_value.value}\n"
55+
f"{self.test_value.type}:\n{push(self.test_value.value)}\n"
5456
)
5557
confidence_str = ""
5658
if self.confidence_intervals:
57-
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
59+
confidence_str += f"Confidence intervals:\n{push(pd.DataFrame(self.confidence_intervals).transpose().to_string(header=False))}\n"
5860
return base_str + confidence_str
5961

6062
def to_dict(self):
@@ -76,14 +78,14 @@ def to_dict(self):
7678

7779
def ci_low(self):
7880
"""Return the lower bracket of the confidence intervals."""
79-
if self.confidence_intervals and all(self.confidence_intervals):
80-
return min(self.confidence_intervals)
81+
if self.confidence_intervals:
82+
return self.confidence_intervals[0]
8183
return None
8284

8385
def ci_high(self):
8486
"""Return the higher bracket of the confidence intervals."""
85-
if self.confidence_intervals and all(self.confidence_intervals):
86-
return max(self.confidence_intervals)
87+
if self.confidence_intervals:
88+
return self.confidence_intervals[1]
8789
return None
8890

8991
def ci_valid(self) -> bool:

causal_testing/testing/estimators.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -335,10 +335,15 @@ def estimate_unit_ate(self) -> float:
335335
:return: The unit average treatment effect and the 95% Wald confidence intervals.
336336
"""
337337
model = self._run_linear_regression()
338-
assert self.treatment in model.params, f"{self.treatment} not in {model.params}"
339-
unit_effect = model.params[[self.treatment]].values[0] # Unit effect is the coefficient of the treatment
340-
[ci_low, ci_high] = self._get_confidence_intervals(model)
341-
338+
newline = "\n"
339+
print(model.conf_int())
340+
treatment = [self.treatment]
341+
if str(self.df.dtypes[self.treatment]) == "object":
342+
reference = min(self.df[self.treatment])
343+
treatment = [x.replace("[", "[T.") for x in dmatrix(f"{self.treatment}-1", self.df.query(f"{self.treatment} != '{reference}'"), return_type="dataframe").columns]
344+
assert set(treatment).issubset(model.params.index.tolist()), f"{treatment} not in\n{' '+str(model.params.index).replace(newline, newline+' ')}"
345+
unit_effect = model.params[treatment] # Unit effect is the coefficient of the treatment
346+
[ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
342347
return unit_effect, [ci_low, ci_high]
343348

344349
def estimate_ate(self) -> tuple[float, list[float, float], float]:
@@ -353,19 +358,14 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
353358
# Create an empty individual for the control and treated
354359
individuals = pd.DataFrame(1, index=["control", "treated"], columns=model.params.index)
355360

356-
# This is a temporary hack
357-
# for t in self.square_terms:
358-
# individuals[t + "^2"] = individuals[t] ** 2
359-
# for a, b in self.product_terms:
360-
# individuals[f"{a}*{b}"] = individuals[a] * individuals[b]
361-
362361
# It is ABSOLUTELY CRITICAL that these go last, otherwise we can't index
363362
# the effect with "ate = t_test_results.effect[0]"
364363
individuals.loc["control", [self.treatment]] = self.control_value
365364
individuals.loc["treated", [self.treatment]] = self.treatment_value
366365

367366
# Perform a t-test to compare the predicted outcome of the control and treated individual (ATE)
368367
t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
368+
print("t_test_results", t_test_results.effect)
369369
ate = t_test_results.effect[0]
370370
confidence_intervals = list(t_test_results.conf_int().flatten())
371371
return ate, confidence_intervals
@@ -473,21 +473,16 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
473473
cols = [self.treatment]
474474
cols += [x for x in self.adjustment_set if x not in cols]
475475
treatment_and_adjustments_cols = reduced_df[cols + ["Intercept"]]
476-
for col in treatment_and_adjustments_cols:
477-
if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
478-
treatment_and_adjustments_cols = pd.get_dummies(
479-
treatment_and_adjustments_cols, columns=[col], drop_first=True
480-
)
481476
model = smf.ols(formula=self.formula, data=self.df).fit()
482477
return model
483478

484-
def _get_confidence_intervals(self, model):
479+
def _get_confidence_intervals(self, model, treatment):
485480
confidence_intervals = model.conf_int(alpha=0.05, cols=None)
486481
ci_low, ci_high = (
487-
confidence_intervals[0][[self.treatment]],
488-
confidence_intervals[1][[self.treatment]],
482+
confidence_intervals[0].loc[treatment],
483+
confidence_intervals[1].loc[treatment],
489484
)
490-
return [ci_low.values[0], ci_high.values[0]]
485+
return [ci_low, ci_high]
491486

492487

493488
class InstrumentalVariableEstimator(Estimator):

0 commit comments

Comments
 (0)