
Commit 20f39ef

Fixed various issues with JSON frontend and estimators
1 parent 9105d31 commit 20f39ef

5 files changed: +50 −38 lines

causal_testing/generation/abstract_causal_test_case.py (27 additions, 18 deletions)

@@ -31,6 +31,7 @@ def __init__(
         expected_causal_effect: dict[Variable:CausalTestOutcome],
         effect_modifiers: set[Variable] = None,
         estimate_type: str = "ate",
+        effect: str = "total"
     ):
         assert treatment_variable in scenario.variables.values(), (
             "Treatment variables must be a subset of variables."
@@ -44,6 +45,7 @@ def __init__(
         self.treatment_variable = treatment_variable
         self.expected_causal_effect = expected_causal_effect
         self.estimate_type = estimate_type
+        self.effect = effect

         if effect_modifiers is not None:
             self.effect_modifiers = effect_modifiers
@@ -103,6 +105,7 @@ def _generate_concrete_tests(
             for c in self.intervention_constraints:
                 optimizer.assert_and_track(c, str(c))

+
             for v in run_columns:
                 optimizer.add_soft(self.scenario.variables[v].z3 == self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v]))

@@ -124,6 +127,7 @@ def _generate_concrete_tests(
                 outcome_variables=list(self.expected_causal_effect.keys()),
                 estimate_type=self.estimate_type,
                 effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
+                effect=self.effect
             )

             for v in self.scenario.inputs():
@@ -134,19 +138,20 @@ def _generate_concrete_tests(
                     + f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
                 )

-            concrete_tests.append(concrete_test)
-            # Control run
-            control_run = {
-                v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
-            }
-            control_run["bin"] = index
-            runs.append(control_run)
-            # Treatment run
-            if rct:
-                treatment_run = control_run.copy()
-                treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
-                treatment_run["bin"] = index
-                runs.append(treatment_run)
+            if not any([vars(t) == vars(concrete_test) for t in concrete_tests]):
+                concrete_tests.append(concrete_test)
+                # Control run
+                control_run = {
+                    v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
+                }
+                control_run["bin"] = index
+                runs.append(control_run)
+                # Treatment run
+                if rct:
+                    treatment_run = control_run.copy()
+                    treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
+                    treatment_run["bin"] = index
+                    runs.append(treatment_run)

         return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])

@@ -185,7 +190,9 @@ def generate_concrete_tests(
         pre_break = False
         for i in range(hard_max):
             concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
-            concrete_tests += concrete_tests_
+            for t_ in concrete_tests_:
+                if not any([vars(t_) == vars(t) for t in concrete_tests]):
+                    concrete_tests.append(t_)
             runs = pd.concat([runs, runs_])
             assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"

@@ -212,11 +219,13 @@ def generate_concrete_tests(
                     for var in effect_modifier_configs.columns
                 }
             )
-            print("=== test ===")
             control_values = [test.control_input_configuration[self.treatment_variable] for test in concrete_tests]
             treatment_values = [test.treatment_input_configuration[self.treatment_variable] for test in concrete_tests]

-            if issubclass(self.treatment_variable.datatype, Enum) and set(zip(control_values, treatment_values)).issubset(itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)):
+            if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(set(zip(control_values, treatment_values))):
+                pre_break = True
+                break
+            if issubclass(self.treatment_variable.datatype, Enum) and set(itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)).issubset(zip(control_values, treatment_values)):
                 pre_break = True
                 break
             elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
@@ -225,9 +234,9 @@ def generate_concrete_tests(

         if target_ks_score is not None and not pre_break:
             logger.error(
-                "Hard max of %s reached but could not achieve target ks_score of %s. Got %s.",
-                hard_max,
+                "Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests",
                 target_ks_score,
                 ks_stats,
+                len(concrete_tests)
             )
         return concrete_tests, runs
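
The new guard deduplicates tests by comparing attribute dictionaries rather than object identity, since two separately constructed test cases never compare equal with `in`. A minimal sketch of the idea, using a hypothetical Test class in place of the real concrete test type:

class Test:  # hypothetical stand-in for a concrete test case
    def __init__(self, control, treatment):
        self.control = control
        self.treatment = treatment

concrete_tests = [Test(0, 1)]
candidate = Test(0, 1)  # equal attributes, but a distinct object

# `candidate in concrete_tests` would be False (identity comparison by
# default), so compare attribute dictionaries instead.
if not any(vars(t) == vars(candidate) for t in concrete_tests):
    concrete_tests.append(candidate)

assert len(concrete_tests) == 1  # the duplicate was not appended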

causal_testing/json_front/json_class.py (11 additions, 9 deletions)

@@ -75,7 +75,7 @@ def set_variables(self, inputs: dict, outputs: dict, metas: dict):
         :param metas:
         """
         self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
-        self.outputs = [Output(i["name"], i["type"]) for i in outputs]
+        self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
         self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []

     def setup(self):
@@ -103,6 +103,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
             if "effect_modifiers" in test
             else {},
             estimate_type=test["estimate_type"],
+            effect=test.get("effect", "total")
         )
         return abstract_test

@@ -126,7 +127,7 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
             logger.info([(v.name, v.distribution) for v in [abstract_test.treatment_variable]])
             logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
             failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
-            logger.info("%s/%s failed", failures, len(concrete_tests))
+            logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])

     def _execute_tests(self, concrete_tests, estimators, test, f_flag):
         failures = 0
@@ -153,11 +154,12 @@ def _populate_metas(self):
             meta.populate(self.data)

         for var in self.metas + self.outputs:
-            fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
-            fitter.fit()
-            (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
-            var.distribution = getattr(scipy.stats, dist)(**params)
-            logger.info(var.name + f"{dist}({params})")
+            if not var.distribution:
+                fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
+                fitter.fit()
+                (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
+                var.distribution = getattr(scipy.stats, dist)(**params)
+                logger.info(var.name + f" {dist}({params})")

     def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
         """Executes a singular test case, prints the results and returns the test case result
@@ -180,7 +182,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
         if causal_test_result.ci_low() and causal_test_result.ci_high():
             result_string = f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
         else:
-            result_string = causal_test_result.test_value.value
+            result_string = f"{causal_test_result.test_value.value} no confidence intervals"
         if f_flag:
             assert test_passes, (
                 f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
@@ -191,7 +193,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
             logger.warning(
                 " FAILED- expected %s, got %s",
                 causal_test_case.expected_causal_effect,
-                result_string
+                result_string
             )
             return failed
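
The `_populate_metas` change only fits a distribution when the user has not supplied one. A short, self-contained sketch of the same Fitter workflow on synthetic data (the data and variable are illustrative, not from the repo):

import numpy as np
import scipy.stats
from fitter import Fitter, get_common_distributions

# Synthetic stand-in for self.data[var.name]
data = np.random.default_rng(1).normal(loc=5.0, scale=2.0, size=500)

fitter = Fitter(data, distributions=get_common_distributions())
fitter.fit()

# get_best maps the best-fitting distribution name to its fitted parameters
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
distribution = getattr(scipy.stats, dist)(**params)  # frozen scipy distribution
print(dist, params)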

causal_testing/specification/causal_dag.py (3 additions, 1 deletion)

@@ -255,7 +255,7 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
         gam.add_edges_from(edges_to_add)

         min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
-        # min_seps.remove(set(outcomes))
+        min_seps.remove(set(outcomes))
         return min_seps

     def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
@@ -278,6 +278,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
         :param outcomes: A list of strings representing outcomes.
         :return: A list of strings representing the minimal adjustment set.
         """
+
         # 1. Construct the proper back-door graph's ancestor moral graph
         proper_backdoor_graph = self.get_proper_backdoor_graph(treatments, outcomes)
         ancestor_proper_backdoor_graph = proper_backdoor_graph.get_ancestor_graph(treatments, outcomes)
@@ -316,6 +317,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
             for adj in minimum_adjustment_sets
             if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, adj)
         ]
+
         return valid_minimum_adjustment_sets

     def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], adjustment_set: set[str]) -> bool:
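
Re-enabling `min_seps.remove(set(outcomes))` drops the outcome set itself from the candidate separators, since adjusting for the outcomes is not a usable adjustment set. A toy illustration with synthetic separator sets (note that `list.remove` raises ValueError if the element is absent, so this relies on `set(outcomes)` always appearing among the minimal separators):

outcomes = ["Y"]
min_seps = [{"Y"}, {"Z1"}, {"Z1", "Z2"}]  # toy minimal separator sets

min_seps.remove(set(outcomes))  # discard the trivial separator {Y}
assert min_seps == [{"Z1"}, {"Z1", "Z2"}]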

causal_testing/specification/variable.py (3 additions, 1 deletion)

@@ -155,11 +155,13 @@ def cast(self, val: Any) -> T:
         assert val is not None, f"Invalid value None for variable {self}"
         if isinstance(val, self.datatype):
             return val
+        if isinstance(val, BoolRef) and self.datatype == bool:
+            return str(val) == "True"
         if isinstance(val, RatNumRef) and self.datatype == float:
             return float(val.numerator().as_long() / val.denominator().as_long())
         if hasattr(val, "is_string_value") and val.is_string_value() and self.datatype == str:
             return val.as_string()
-        if (isinstance(val, float) or isinstance(val, int)) and (self.datatype == int or self.datatype == float):
+        if (isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)) and (self.datatype == int or self.datatype == float or self.datatype == bool):
             return self.datatype(val)
         if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef):
             return self.datatype(str(val))
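
The new BoolRef branch is needed because values pulled out of a z3 model are BoolRef objects, not Python bools, so the earlier `isinstance(val, self.datatype)` check misses them. A minimal sketch (assuming the z3-solver package is installed; names are illustrative):

from z3 import Bool, BoolRef, Solver, sat

solver = Solver()
flag = Bool("flag")
solver.add(flag)  # constrain flag to be true
assert solver.check() == sat

val = solver.model()[flag]  # a z3 BoolRef, not a Python bool
assert isinstance(val, BoolRef)
assert not isinstance(val, bool)  # isinstance(val, bool) cannot catch it
assert str(val) == "True"  # the string comparison the commit introduces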

causal_testing/testing/estimators.py (6 additions, 9 deletions)

@@ -361,16 +361,20 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
         :return: The average treatment effect and the 95% Wald confidence intervals.
         """
         model = self._run_linear_regression()
+
         # Create an empty individual for the control and treated
         individuals = pd.DataFrame(1, index=["control", "treated"], columns=model.params.index)
-        individuals.loc["control", list(self.treatment)] = self.control_values
-        individuals.loc["treated", list(self.treatment)] = self.treatment_values
         # This is a temporary hack
         for t in self.square_terms:
             individuals[t + "^2"] = individuals[t] ** 2
         for a, b in self.product_terms:
             individuals[f"{a}*{b}"] = individuals[a] * individuals[b]

+        # It is ABSOLUTELY CRITICAL that these go last, otherwise we can't index
+        # the effect with "ate = t_test_results.effect[0]"
+        individuals.loc["control", list(self.treatment)] = self.control_values
+        individuals.loc["treated", list(self.treatment)] = self.treatment_values
+
         # Perform a t-test to compare the predicted outcome of the control and treated individual (ATE)
         t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
         ate = t_test_results.effect[0]
@@ -385,7 +389,6 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
         """
         model = self._run_linear_regression()
         self.model = model
-        print(model.summary())


         x = pd.DataFrame()
@@ -399,18 +402,12 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
             x["1/" + t] = 1 / x[t]
         for a, b in self.product_terms:
             x[f"{a}*{b}"] = x[a] * x[b]
-        print("full")
-        print(x)
         for col in x:
             if str(x.dtypes[col]) == "object":
                 x = pd.get_dummies(x, columns=[col], drop_first=True)
-        print("dummy")
-        print(x)
         x = x[model.params.index]
-
         y = model.get_prediction(x).summary_frame()

-        print("control", y.iloc[1], "treatment", y.iloc[0])
         return y.iloc[1], y.iloc[0]

     def estimate_risk_ratio(self) -> tuple[float, list[float, float]]:
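
The reordering in `estimate_ate` moves the control/treated assignments after the derived-column construction so the contrast row stays aligned with `model.params`. A self-contained sketch of the underlying statsmodels pattern the method relies on, using synthetic data and hypothetical variable names:

import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(0)
df = pd.DataFrame({"x": rng.normal(size=200)})
df["y"] = 2.0 + 3.0 * df["x"] + rng.normal(scale=0.1, size=200)

model = smf.ols("y ~ x", data=df).fit()

# One row per synthetic individual, columns in the same order as model.params.
individuals = pd.DataFrame(1.0, index=["control", "treated"], columns=model.params.index)
individuals.loc["control", "x"] = 0.0  # control value of the treatment variable
individuals.loc["treated", "x"] = 1.0  # treatment value

# t_test on the contrast (treated - control) yields the effect and its 95% CI.
t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
ate = t_test_results.effect[0]  # close to the true slope of 3.0
ci_low, ci_high = t_test_results.conf_int()[0]
print(f"ATE = {ate:.3f}, 95% CI = ({ci_low:.3f}, {ci_high:.3f})")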
