Skip to content

Commit 2be60af

Browse files
author
Juan Orduz
authored
Fix ICE plot when there is a discrete variable (#107)
* fix tata plot shape * add test case for discrete variables
1 parent 592c6b2 commit 2be60af

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

pymc_bart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def plot_ice(
289289
p_di = func(p_d[:, :, s_i])
290290
if var in var_discrete:
291291
axes[count].plot(new_x, p_di.mean(0), "o", color=color_mean)
292-
axes[count].plot(new_x, p_di, ".", color=color, alpha=alpha)
292+
axes[count].plot(new_x, p_di.T, ".", color=color, alpha=alpha)
293293
else:
294294
if smooth:
295295
x_data, y_data = _smooth_mean(new_x, p_di, "ice", smooth_kwargs)

tests/test_bart.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def test_sample_posterior(self):
160160
{"instances": 2},
161161
{"var_idx": [0], "smooth": False, "color": "k"},
162162
{"grid": (1, 2), "sharey": "none", "alpha": 1},
163+
{"var_discrete": [0]}
163164
],
164165
)
165166
def test_ice(self, kwargs):
@@ -177,6 +178,7 @@ def test_ice(self, kwargs):
177178
},
178179
{"var_idx": [0], "smooth": False, "color": "k"},
179180
{"grid": (1, 2), "sharey": "none", "alpha": 1},
181+
{"var_discrete": [0]}
180182
],
181183
)
182184
def test_pdp(self, kwargs):

0 commit comments

Comments
 (0)