Skip to content

Commit 444f73e

Browse files
committed
handle func upstream
1 parent 88847b6 commit 444f73e

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

pymc_bart/utils.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -393,21 +393,17 @@ def identity(x):
393393
for var in range(len(var_idx)):
394394
excluded = indices[:]
395395
excluded.remove(var)
396-
p_d = _sample_posterior(
397-
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
396+
p_d = func(
397+
_sample_posterior(
398+
all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
399+
)
398400
)
399-
# need to apply func to full array and to last dimension if it's softmax
400-
if func.__name__ == "softmax":
401-
# categories are always the last dimension
402-
# for some reason, mypy thinks that func can be identity,
403-
# which doesn't have the axis argument
404-
p_d = func(p_d, axis=-1) # type: ignore[call-arg]
405401

406402
with warnings.catch_warnings():
407403
warnings.filterwarnings("ignore", message="hdi currently interprets 2d data")
408404
new_x = fake_X[:, var]
409405
for s_i in range(shape):
410-
p_di = p_d[:, :, s_i] if func.__name__ == "softmax" else func(p_d[:, :, s_i])
406+
p_di = p_d[:, :, s_i]
411407
null_pd.append(p_di.mean())
412408
if var in var_discrete:
413409
_, idx_uni = np.unique(new_x, return_index=True)

0 commit comments

Comments
 (0)