@@ -82,7 +82,8 @@ def cmd_func(X):
8282 data = shap_out .data , main_effects = shap_out .main_effects ,
8383 feature_names = shap_out .feature_names )
8484 shap_outs [output_names [i ]][treatment_names [0 ]] = shap_out_new
85- return shap_outs
85+ # return plain dictionary so that erroneous accesses don't half work (see #708)
86+ return dict (shap_outs )
8687
8788
8889def _shap_explain_model_cate (cme_model , models , X , d_t , d_y , featurizer = None , feature_names = None ,
@@ -176,7 +177,8 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, fe
176177 else :
177178 shap_outs [output_names [0 ]][treatment_names [i ]] = shap_out
178179
179- return shap_outs
180+ # return plain dictionary so that erroneous accesses don't half work (see #708)
181+ return dict (shap_outs )
180182
181183
182184def _shap_explain_joint_linear_model_cate (model_final , X , d_t , d_y , fit_cate_intercept ,
@@ -258,7 +260,8 @@ def _shap_explain_joint_linear_model_cate(model_final, X, d_t, d_y, fit_cate_int
258260 feature_names = shap_out .feature_names )
259261 shap_outs [output_names [0 ]][treatment_names [i ]] = shap_out_new
260262
261- return shap_outs
263+ # return plain dictionary so that erroneous accesses don't half work (see #708)
264+ return dict (shap_outs )
262265
263266
264267def _shap_explain_multitask_model_cate (cme_model , multitask_model_cate , X , d_t , d_y , featurizer = None ,
@@ -352,7 +355,8 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t,
352355 shap_outs [output_names [j ]][treatment_names [i ]] = shap_out_new
353356 else :
354357 shap_outs [output_names [j ]][treatment_names [0 ]] = shap_out
355- return shap_outs
358+ # return plain dictionary so that erroneous accesses don't half work (see #708)
359+ return dict (shap_outs )
356360
357361
358362def _define_names (d_t , d_y , treatment_names , output_names , feature_names , input_names , featurizer ):
0 commit comments