Skip to content

Commit b8b1b2b

Browse files
committed
Fix individual policy with no heterogeneity
1 parent 2c140f2 commit b8b1b2b

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

econml/solutions/causal_analysis/_causal_analysis.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,6 +1600,7 @@ def individualized_policy(self, Xtest, feature_index, *, n_rows=None, treatment_
16001600

16011601
Xtest = result.X_transformer.transform(Xtest)
16021602
if Xtest.shape[1] == 0:
1603+
x_rows = Xtest.shape[0]
16031604
Xtest = None
16041605

16051606
if result.feature_baseline is None:
@@ -1608,6 +1609,9 @@ def individualized_policy(self, Xtest, feature_index, *, n_rows=None, treatment_
16081609
else:
16091610
effect = result.estimator.const_marginal_effect_inference(Xtest)
16101611

1612+
if Xtest is None: # we got a scalar effect although our original X may have had more rows
1613+
effect = effect._expand_outputs(x_rows)
1614+
16111615
multi_y = (not self._vec_y) or self.classification
16121616

16131617
if multi_y and result.feature_baseline is not None and np.ndim(treatment_costs) == 2:

econml/tests/test_causal_analysis.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ def test_empty_hinds(self):
571571
eff = ca.local_causal_effect(X_df, alpha=0.05)
572572
for ind in feat_inds:
573573
pto = ca._policy_tree_output(X_df, ind)
574+
ca._individualized_policy_dict(X_df, ind)
574575

575576
def test_can_serialize(self):
576577
import pickle

0 commit comments

Comments
 (0)