Skip to content

Commit df15faa

Browse files
committed
Part way to fixing categoricals
1 parent b53b8d8 commit df15faa

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def _generate_concrete_tests(
9393
var = self.scenario.variables[name]
9494
samples[var.name] = lhsmdu.inverseTransformSample(var.distribution, samples[var.name])
9595

96-
print(samples["ego_vehicle"])
9796
for index, row in samples.iterrows():
9897
optimizer = z3.Optimize()
9998
for c in self.scenario.constraints:

causal_testing/specification/variable.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def z3_types(datatype):
2222
if datatype in types:
2323
return types[datatype]
2424
if issubclass(datatype, Enum):
25-
dtype, _ = EnumSort(datatype.__name__, [x.name for x in datatype])
25+
dtype, _ = EnumSort(datatype.__name__, [x.value for x in datatype])
2626
return lambda x: Const(x, dtype)
2727
if hasattr(datatype, "to_z3"):
2828
return datatype.to_z3()
@@ -162,15 +162,14 @@ def cast(self, val: Any) -> T:
162162
if (isinstance(val, float) or isinstance(val, int)) and (self.datatype == int or self.datatype == float):
163163
return self.datatype(val)
164164
if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef):
165-
return self.datatype[str(val)]
165+
return self.datatype(str(val))
166166
return self.datatype(str(val))
167167

168168
def z3_val(self, z3_var, val: Any) -> T:
169169
native_val = self.cast(val)
170-
print(val, type(val), native_val, type(native_val))
171170
if isinstance(native_val, Enum):
172171
values = [z3_var.sort().constructor(c)() for c in range(z3_var.sort().num_constructors())]
173-
values = [v for v in values if type(val)[str(v)] == val]
172+
values = [v for v in values if type(val)(str(v)) == val]
174173
assert len(values) == 1, f"Expected {values} to be length 1"
175174
return values[0]
176175
return native_val

causal_testing/testing/estimators.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,9 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
453453
cols += [x for x in self.adjustment_set if x not in cols]
454454
treatment_and_adjustments_cols = reduced_df[cols + ["Intercept"]]
455455
outcome_col = reduced_df[list(self.outcome)]
456+
for col in treatment_and_adjustments_cols:
457+
if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
458+
treatment_and_adjustments_cols = pd.get_dummies(treatment_and_adjustments_cols, columns=[col], drop_first=True)
456459
regression = sm.OLS(outcome_col, treatment_and_adjustments_cols)
457460
model = regression.fit()
458461
return model

0 commit comments

Comments
 (0)