Skip to content

Commit c255ce0

Browse files
refactor estimate_coefficent to only return pd.Series
1 parent 0eefd4b commit c255ce0

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

causal_testing/testing/estimators.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -351,19 +351,16 @@ def estimate_coefficient(self) -> float:
351351
"""
352352
model = self._run_linear_regression()
353353
newline = "\n"
354-
treatment = [self.treatment]
355354
if self.treatment in self.df.dtypes and str(self.df.dtypes[self.treatment]) == "object":
356355
design_info = dmatrix(self.formula.split("~")[1], self.df).design_info
357356
treatment = design_info.column_names[design_info.term_name_slices[self.treatment]]
357+
else:
358+
treatment = [self.treatment]
358359
assert set(treatment).issubset(
359360
model.params.index.tolist()
360361
), f"{treatment} not in\n{' ' + str(model.params.index).replace(newline, newline + ' ')}"
361362
unit_effect = model.params[treatment] # Unit effect is the coefficient of the treatment
362363
[ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
363-
if self.treatment not in self.df.dtypes or str(self.df.dtypes[self.treatment]) != "object":
364-
unit_effect = unit_effect[0]
365-
ci_low = ci_low[0]
366-
ci_high = ci_high[0]
367364
return unit_effect, [ci_low, ci_high]
368365

369366
def estimate_ate(self) -> tuple[float, list[float, float], float]:

0 commit comments

Comments
 (0)