This is an old idea, and not even my own idea (Tor and Niko have both talked about it), but I don't think it's made it to a GitHub issue yet.
The tl;dr is to overload or reimplement `logpdf(dist, x)` in a way that is more performant than Distributions' own implementation. This could function as a way to sidestep inefficiencies in Distributions, but we could also optimise for the probabilistic programming setting by dropping constant additive terms. This is analogous to the `propto` argument in (Bridge)Stan (https://roualdes.us/bridgestan/latest/languages/julia.html#BridgeStan.log_density).
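To make "dropping constant additive terms" concrete, here is a minimal sketch for `Normal`: its full log-density is `-((x - μ)/σ)^2 / 2 - log(σ) - log(2π)/2`, and the final `log(2π)/2` term is a constant that a propto-style variant could drop (what counts as "constant" depends on which quantities are being sampled, but this term never varies):

```julia
using Distributions

# Full log-density of Normal(μ, σ): -((x - μ)/σ)^2 / 2 - log(σ) - log(2π)/2.
# The log(2π)/2 term is constant, so a propto variant could omit it.
d = Normal(1.0, 2.0)
x = 0.5
full = logpdf(d, x)
dropped = -abs2((x - 1.0) / 2.0) / 2 - log(2.0)

# The two values differ only by that constant:
@assert dropped - full ≈ log(2π) / 2
```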
Note that for samplers like NUTS we could run the sampling itself entirely with these simplified logpdfs, and only recalculate the 'true' logpdfs (i.e., the results stored in the chain) afterwards. That's possible because at the end of every iteration of NUTS we recalculate the logpdfs anyway (inside `ParamsWithStats`); we do this in order to drop the Jacobian terms (cf. TuringLang/Turing.jl#2617). So this could in fact be completely hidden from the (Turing) user. I think this is a very strong argument for at least trying: it could be something that's completely internal, and it wouldn't require us to make any API changes to DynamicPPL.
I'm not certain how much performance we could squeeze out of this, but I think it would be worth a try. It would certainly be very cheap to define

```julia
propto_logpdf(d::Distribution, x) = logpdf(d, x)
```

and then override distributions one by one. In fact, I think Claude would be very capable at this part of it.
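A sketch of what the fallback plus a per-distribution override could look like (`propto_logpdf` is the name proposed in this issue, not an existing API; the `Normal` override below is a hypothetical example that drops only the `log(2π)/2` constant):

```julia
using Distributions

# Generic fallback: identical to the full logpdf.
propto_logpdf(d::Distribution, x) = logpdf(d, x)

# Hand-written override for Normal, dropping the constant log(2π)/2 term.
propto_logpdf(d::Normal, x) = -abs2((x - d.μ) / d.σ) / 2 - log(d.σ)

# The override differs from the full logpdf by a constant, for any x, μ, σ:
d = Normal(0.3, 1.7)
@assert propto_logpdf(d, 0.0) - logpdf(d, 0.0) ≈ log(2π) / 2
@assert propto_logpdf(d, 5.0) - logpdf(d, 5.0) ≈ log(2π) / 2
```

Distributions without an override silently fall through to the exact `logpdf`, so the set of optimised distributions can grow incrementally.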
More likely, the tricky engineering work will be in modifying accumulators and `LogDensityFunction` to use the propto terms, and then threading support for them through the Turing machinery.
I'm also not certain that DPPL is the right place for this; I think there's a strong argument for it belonging in a standalone library.