@@ -290,6 +290,7 @@ def plot_pdp(
290
290
var_idx : Optional [list [int ]] = None ,
291
291
var_discrete : Optional [list [int ]] = None ,
292
292
func : Optional [Callable ] = None ,
293
+ softmax_link : Optional [bool ] = False ,
293
294
samples : int = 200 ,
294
295
ref_line : bool = True ,
295
296
random_seed : Optional [int ] = None ,
@@ -330,6 +331,9 @@ def plot_pdp(
330
331
List of the indices of the covariate treated as discrete.
331
332
func : Optional[Callable], by default None.
332
333
Arbitrary function to apply to the predictions. Defaults to the identity function.
334
+ softmax_link: Optional[bool] = False,
335
+ If True the predictions are transformed using the softmax function. Only works when
336
+ likelihood is categorical. Defaults to False.
333
337
samples : int
334
338
Number of posterior samples used in the predictions. Defaults to 200
335
339
ref_line : bool
@@ -396,6 +400,12 @@ def identity(x):
396
400
p_d = _sample_posterior (
397
401
all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
398
402
)
403
+ if softmax_link is True :
404
+ from scipy .special import softmax
405
+
406
+ # categories are the last dimension
407
+ p_d = softmax (p_d , axis = - 1 )
408
+
399
409
with warnings .catch_warnings ():
400
410
warnings .filterwarnings ("ignore" , message = "hdi currently interprets 2d data" )
401
411
new_x = fake_X [:, var ]
0 commit comments