Closed
Description
Describe the issue:
When fitting a minibatched VI model, using dims on the observed RV makes Numba fall back to object mode. The warnings and the slowdown go away if the y = ... line shown below drops the use of dims.
Reproducible code example:
```python
import pymc as pm
import numpy as np

data = np.random.normal(size=100_000)

with pm.Model() as model:
    d = pm.Data("data", data)
    mb = pm.Minibatch(d, batch_size=100)
    model.add_coord("mb_dim", range(100))
    x = pm.Normal("x", 0, 1)
    y = pm.Normal("y", x, observed=mb, total_size=len(data), dims="mb_dim")

with model:
    trace = pm.fit(100_000, compile_kwargs={"mode": "NUMBA"})
```

Error message:

```
site-packages/pytensor/link/numba/dispatch/basic.py:288: UserWarning: Numba will use object mode to run MinibatchRandomVariable's perform method
```

PyMC version information:
Latest released versions of pytensor (2.31) and pymc (5.22)
I installed it using pip
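To compare which Ops hit the object-mode fallback (for example, with and without dims on y), the UserWarnings emitted during pm.fit can be recorded and filtered by the message text from the error above. This is a minimal sketch that assumes the warning wording stays as shown. The helper object_mode_ops is hypothetical and is not a PyMC or PyTensor API. Here the warning is emitted by hand so the snippet runs without PyMC. In practice the pm.fit(...) call would go inside the catch_warnings block.

```python
import re
import warnings

# Matches the fallback warning quoted in the error message.
# Assumption: the wording stays "Numba will use object mode to run <Op>'s perform method".
_OBJECT_MODE_RE = re.compile(r"Numba will use object mode to run (\w+)'s perform method")


def object_mode_ops(caught):
    """Hypothetical helper: sorted, de-duplicated Op names found in recorded warnings."""
    names = set()
    for w in caught:
        match = _OBJECT_MODE_RE.search(str(w.message))
        if match:
            names.add(match.group(1))
    return sorted(names)


# Record every warning. With PyMC installed, pm.fit(...) would go here instead
# of the hand-written warnings.warn calls below.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.warn(
        "Numba will use object mode to run MinibatchRandomVariable's perform method",
        UserWarning,
    )
    warnings.warn("an unrelated warning", UserWarning)

print(object_mode_ops(caught))
```

Running the model once with dims="mb_dim" and once without, then comparing the two lists, is one way to confirm that the fallback comes from MinibatchRandomVariable only in the dims case.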