@@ -290,7 +290,6 @@ def plot_pdp(
290
290
var_idx : Optional [list [int ]] = None ,
291
291
var_discrete : Optional [list [int ]] = None ,
292
292
func : Optional [Callable ] = None ,
293
- softmax_link : Optional [bool ] = False ,
294
293
samples : int = 200 ,
295
294
ref_line : bool = True ,
296
295
random_seed : Optional [int ] = None ,
@@ -331,9 +330,6 @@ def plot_pdp(
331
330
List of the indices of the covariate treated as discrete.
332
331
func : Optional[Callable], by default None.
333
332
Arbitrary function to apply to the predictions. Defaults to the identity function.
334
- softmax_link: Optional[bool] = False,
335
- If True the predictions are transformed using the softmax function. Only works when
336
- likelihood is categorical. Defaults to False.
337
333
samples : int
338
334
Number of posterior samples used in the predictions. Defaults to 200
339
335
ref_line : bool
@@ -400,17 +396,18 @@ def identity(x):
400
396
p_d = _sample_posterior (
401
397
all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
402
398
)
403
- if softmax_link is True :
404
- from scipy .special import softmax
405
-
406
- # categories are the last dimension
407
- p_d = softmax (p_d , axis = - 1 )
399
+ # need to apply func to full array and to last dimension if it's softmax
400
+ if func .__name__ == "softmax" :
401
+ # categories are always the last dimension
402
+ # for some reason, mypy thinks that func can be identity,
403
+ # which doesn't have the axis argument
404
+ p_d = func (p_d , axis = - 1 ) # type: ignore[call-arg]
408
405
409
406
with warnings .catch_warnings ():
410
407
warnings .filterwarnings ("ignore" , message = "hdi currently interprets 2d data" )
411
408
new_x = fake_X [:, var ]
412
409
for s_i in range (shape ):
413
- p_di = func (p_d [:, :, s_i ])
410
+ p_di = p_d [:, :, s_i ] if func . __name__ == "softmax" else func (p_d [:, :, s_i ])
414
411
null_pd .append (p_di .mean ())
415
412
if var in var_discrete :
416
413
_ , idx_uni = np .unique (new_x , return_index = True )
0 commit comments