@@ -290,6 +290,7 @@ def plot_pdp(
290290 var_idx : Optional [list [int ]] = None ,
291291 var_discrete : Optional [list [int ]] = None ,
292292 func : Optional [Callable ] = None ,
293+ softmax_link : Optional [bool ] = False ,
293294 samples : int = 200 ,
294295 ref_line : bool = True ,
295296 random_seed : Optional [int ] = None ,
@@ -330,6 +331,9 @@ def plot_pdp(
330331 List of the indices of the covariate treated as discrete.
331332 func : Optional[Callable], by default None.
332333 Arbitrary function to apply to the predictions. Defaults to the identity function.
334+ softmax_link: Optional[bool] = False,
335+ If True the predictions are transformed using the softmax function. Only works when
336+ likelihood is categorical. Defaults to False.
333337 samples : int
334338 Number of posterior samples used in the predictions. Defaults to 200
335339 ref_line : bool
@@ -396,6 +400,12 @@ def identity(x):
396400 p_d = _sample_posterior (
397401 all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
398402 )
403+ if softmax_link is True :
404+ from scipy .special import softmax
405+
406+ # categories are the last dimension
407+ p_d = softmax (p_d , axis = - 1 )
408+
399409 with warnings .catch_warnings ():
400410 warnings .filterwarnings ("ignore" , message = "hdi currently interprets 2d data" )
401411 new_x = fake_X [:, var ]
0 commit comments