Skip to content

Commit a37eb3c

Browse files
committed
Avoid returning defaultdict directly
1 parent e67bff7 commit a37eb3c

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

econml/_shap.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def cmd_func(X):
8282
data=shap_out.data, main_effects=shap_out.main_effects,
8383
feature_names=shap_out.feature_names)
8484
shap_outs[output_names[i]][treatment_names[0]] = shap_out_new
85-
return shap_outs
85+
# return plain dictionary so that erroneous accesses don't half work (see #708)
86+
return dict(shap_outs)
8687

8788

8889
def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, feature_names=None,
@@ -176,7 +177,8 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, fe
176177
else:
177178
shap_outs[output_names[0]][treatment_names[i]] = shap_out
178179

179-
return shap_outs
180+
# return plain dictionary so that erroneous accesses don't half work (see #708)
181+
return dict(shap_outs)
180182

181183

182184
def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_intercept,
@@ -258,7 +260,8 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int
258260
feature_names=shap_out.feature_names)
259261
shap_outs[output_names[0]][treatment_names[i]] = shap_out_new
260262

261-
return shap_outs
263+
# return plain dictionary so that erroneous accesses don't half work (see #708)
264+
return dict(shap_outs)
262265

263266

264267
def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t, d_y, featurizer=None,
@@ -352,7 +355,8 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
352355
shap_outs[output_names[j]][treatment_names[i]] = shap_out_new
353356
else:
354357
shap_outs[output_names[j]][treatment_names[0]] = shap_out
355-
return shap_outs
358+
# return plain dictionary so that erroneous accesses don't half work (see #708)
359+
return dict(shap_outs)
356360

357361

358362
def _define_names(d_t, d_y, treatment_names, output_names, feature_names, input_names, featurizer):

econml/utilities.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1360,7 +1360,9 @@ def transpose_dictionary(d):
13601360
for key1, value in d.items():
13611361
for key2, val in value.items():
13621362
output[key2][key1] = val
1363-
return output
1363+
1364+
# return plain dictionary so that erroneous accesses don't half work (see e.g. #708)
1365+
return dict(output)
13641366

13651367

13661368
def reshape_arrays_2dim(length, *args):

0 commit comments

Comments
 (0)