Skip to content

Commit 88847b6

Browse files
committed
Use func for softmax
1 parent 733e66b commit 88847b6

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

pymc_bart/utils.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,6 @@ def plot_pdp(
290290
var_idx: Optional[list[int]] = None,
291291
var_discrete: Optional[list[int]] = None,
292292
func: Optional[Callable] = None,
293-
softmax_link: Optional[bool] = False,
294293
samples: int = 200,
295294
ref_line: bool = True,
296295
random_seed: Optional[int] = None,
@@ -331,9 +330,6 @@ def plot_pdp(
331330
List of the indices of the covariate treated as discrete.
332331
func : Optional[Callable], by default None.
333332
Arbitrary function to apply to the predictions. Defaults to the identity function.
334-
softmax_link: Optional[bool] = False,
335-
If True the predictions are transformed using the softmax function. Only works when
336-
likelihood is categorical. Defaults to False.
337333
samples : int
338334
Number of posterior samples used in the predictions. Defaults to 200
339335
ref_line : bool
@@ -400,17 +396,18 @@ def identity(x):
400396
p_d = _sample_posterior(
401397
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
402398
)
403-
if softmax_link is True:
404-
from scipy.special import softmax
405-
406-
# categories are the last dimension
407-
p_d = softmax(p_d, axis=-1)
399+
# need to apply func to full array and to last dimension if it's softmax
400+
if func.__name__ == "softmax":
401+
# categories are always the last dimension
402+
# for some reason, mypy thinks that func can be identity,
403+
# which doesn't have the axis argument
404+
p_d = func(p_d, axis=-1) # type: ignore[call-arg]
408405

409406
with warnings.catch_warnings():
410407
warnings.filterwarnings("ignore", message="hdi currently interprets 2d data")
411408
new_x = fake_X[:, var]
412409
for s_i in range(shape):
413-
p_di = func(p_d[:, :, s_i])
410+
p_di = p_d[:, :, s_i] if func.__name__ == "softmax" else func(p_d[:, :, s_i])
414411
null_pd.append(p_di.mean())
415412
if var in var_discrete:
416413
_, idx_uni = np.unique(new_x, return_index=True)

0 commit comments

Comments
 (0)