Description
I noticed the changes in #1583 cause a slowdown in models with small datasets, such as the MMM workflow example, when using the Numba / C backends (I didn't measure JAX, as the JAX samplers take gradients via JAX).
The speed regression in the Numba backend (both in PyMC / nutpie) should be fixed by pymc-devs/pytensor#1378. I confirmed the model is faster than before for the 7 channels / ~188 days of data used in the intro example. I also used freeze_dims_and_data, which is the default in the non-PyMC samplers. The gradient of the convolution with respect to the smallest input is much faster that way.
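To illustrate why the gradient with respect to the smallest input is cheap once shapes are static, here is a plain-Python sketch (not the actual PyTensor implementation; the function names are made up). For a causal convolution `out[t] = sum_k w[k] * x[t - k]`, the gradient of `sum(out)` with respect to the small kernel `w` is just `m` partial sums of the long input, so its cost scales with the kernel length, not the data length:

```python
def conv_grad_wrt_kernel(x, m):
    """Gradient of sum_t out[t] w.r.t. each kernel tap w[k],
    where out[t] = sum_k w[k] * x[t - k] (causal convolution).

    dL/dw[k] = sum over valid t of x[t - k] = sum(x[: n - k]).
    Only m values are produced, regardless of how long x is.
    """
    n = len(x)
    return [sum(x[: n - k]) for k in range(m)]


# Small worked example: x has 4 points, kernel has 2 taps.
grad = conv_grad_wrt_kernel([1.0, 2.0, 3.0, 4.0], 2)
# grad[0] = 1+2+3+4 = 10, grad[1] = 1+2+3 = 6
```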
Fixing the speed regression on the C backend requires more effort, as neither Blockwise nor Convolve1d has a native C implementation right now.
The old batched_convolution was basically an unrolled convolution loop (with a native C implementation in the C backend). It's not surprising this can sometimes be faster for small convolutions (at the cost of higher compile time), but the worst case can be much worse (as shown in the timings in the original PR).
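To make the unrolling trade-off concrete, here is a plain-Python sketch (not the actual batched_convolution code; names are invented for illustration). The generic version loops over the kernel at runtime; the unrolled version fixes the kernel length at "compile" time and writes out each tap explicitly, which is what made the old implementation fast for small kernels but expensive to compile (and potentially much slower) for large ones:

```python
def conv_generic(x, w):
    """Causal convolution with a runtime loop over kernel taps:
    out[t] = sum_k w[k] * x[t - k]."""
    n, m = len(x), len(w)
    out = [0.0] * n
    for t in range(n):
        for k in range(min(m, t + 1)):
            out[t] += w[k] * x[t - k]
    return out


def conv_unrolled_3(x, w):
    """Same computation with the kernel loop unrolled for a fixed
    length of 3, mirroring how an unrolled loop trades compile time
    (and code size) for runtime speed."""
    assert len(w) == 3
    n = len(x)
    out = [0.0] * n
    out[0] = w[0] * x[0]
    if n > 1:
        out[1] = w[0] * x[1] + w[1] * x[0]
    for t in range(2, n):
        out[t] = w[0] * x[t] + w[1] * x[t - 1] + w[2] * x[t - 2]
    return out
```

Both produce identical results for a length-3 kernel; the unrolled form simply has no inner loop to dispatch, which is where the small-kernel speedup came from.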