
BUG: Using dims when fitting with Numba forces it to run in object mode #7802

@zaxtax

Description


Describe the issue:

When fitting a minibatched VI model, if I use dims on the observed RV, Numba falls back to object mode. The warnings and the slowdown go away if the y = ... line shown below drops the use of dims.

Reproducible code example:

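import pymc as pm
import numpy as np

data = np.random.normal(size=100_000)

with pm.Model() as model:
    d = pm.Data("data", data)
    # Draw random minibatches of 100 rows from the full dataset
    mb = pm.Minibatch(d, batch_size=100)
    model.add_coord("mb_dim", range(100))
    x = pm.Normal("x", 0, 1)
    # Passing dims="mb_dim" here is what triggers the object-mode fallback
    y = pm.Normal("y", x, observed=mb, total_size=len(data), dims="mb_dim")

with model:
    # Compile the VI step function with the Numba backend
    trace = pm.fit(100_000, compile_kwargs={"mode": "NUMBA"})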

Error message:

site-packages/pytensor/link/numba/dispatch/basic.py:288: UserWarning: Numba will use object mode to run MinibatchRandomVariable's perform method

PyMC version information:

Latest released versions of pytensor (2.31) and pymc (5.22)

I installed it using pip
