Skip to content

Commit 957b0ac

Browse files
committed
Add softmax option to plot_pdp
1 parent d41e239 commit 957b0ac

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

pymc_bart/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def plot_pdp(
290290
var_idx: Optional[list[int]] = None,
291291
var_discrete: Optional[list[int]] = None,
292292
func: Optional[Callable] = None,
293+
softmax_link: Optional[bool] = False,
293294
samples: int = 200,
294295
ref_line: bool = True,
295296
random_seed: Optional[int] = None,
@@ -330,6 +331,9 @@ def plot_pdp(
330331
List of the indices of the covariate treated as discrete.
331332
func : Optional[Callable], by default None.
332333
Arbitrary function to apply to the predictions. Defaults to the identity function.
334+
softmax_link: Optional[bool] = False,
335+
If True the predictions are transformed using the softmax function. Only works when
336+
likelihood is categorical. Defaults to False.
333337
samples : int
334338
Number of posterior samples used in the predictions. Defaults to 200
335339
ref_line : bool
@@ -396,6 +400,12 @@ def identity(x):
396400
p_d = _sample_posterior(
397401
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
398402
)
403+
if softmax_link is True:
404+
from scipy.special import softmax
405+
406+
# categories are the last dimension
407+
p_d = softmax(p_d, axis=-1)
408+
399409
with warnings.catch_warnings():
400410
warnings.filterwarnings("ignore", message="hdi currently interprets 2d data")
401411
new_x = fake_X[:, var]

0 commit comments

Comments
 (0)