@@ -290,7 +290,6 @@ def plot_pdp(
290290 var_idx : Optional [list [int ]] = None ,
291291 var_discrete : Optional [list [int ]] = None ,
292292 func : Optional [Callable ] = None ,
293- softmax_link : Optional [bool ] = False ,
294293 samples : int = 200 ,
295294 ref_line : bool = True ,
296295 random_seed : Optional [int ] = None ,
@@ -331,9 +330,6 @@ def plot_pdp(
331330 List of the indices of the covariate treated as discrete.
332331 func : Optional[Callable], by default None.
333332 Arbitrary function to apply to the predictions. Defaults to the identity function.
334- softmax_link: Optional[bool] = False,
335- If True the predictions are transformed using the softmax function. Only works when
336- likelihood is categorical. Defaults to False.
337333 samples : int
338334 Number of posterior samples used in the predictions. Defaults to 200
339335 ref_line : bool
@@ -400,17 +396,18 @@ def identity(x):
400396 p_d = _sample_posterior (
401397 all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
402398 )
403- if softmax_link is True :
404- from scipy .special import softmax
405-
406- # categories are the last dimension
407- p_d = softmax (p_d , axis = - 1 )
399+ # need to apply func to full array and to last dimension if it's softmax
400+ if func .__name__ == "softmax" :
401+ # categories are always the last dimension
402+ # for some reason, mypy thinks that func can be identity,
403+ # which doesn't have the axis argument
404+ p_d = func (p_d , axis = - 1 ) # type: ignore[call-arg]
408405
409406 with warnings .catch_warnings ():
410407 warnings .filterwarnings ("ignore" , message = "hdi currently interprets 2d data" )
411408 new_x = fake_X [:, var ]
412409 for s_i in range (shape ):
413- p_di = func (p_d [:, :, s_i ])
410+ p_di = p_d [:, :, s_i ] if func . __name__ == "softmax" else func (p_d [:, :, s_i ])
414411 null_pd .append (p_di .mean ())
415412 if var in var_discrete :
416413 _ , idx_uni = np .unique (new_x , return_index = True )
0 commit comments