#555 introduced `DynamicPPL.ReshapeTransform`, which is very nice, but there seems to be a bug in ReverseDiff.jl that causes gradient computation to fail when `ReshapeTransform` is composed with a broadcasted function. I reported the upstream bug at JuliaDiff/ReverseDiff.jl#265. In the context of DynamicPPL, this occurs when we have something like the following:
```julia
using DynamicPPL: invlink_transform, ReshapeTransform
using Distributions
using ReverseDiff

f(x) = invlink_transform(InverseGamma(2, 3))(x)
g(x) = ReshapeTransform(())(x)
h = f ∘ g
ReverseDiff.gradient(h, [1.0])
```
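For what it's worth, if the problem really is just the reshape-to-`()` step composed with a broadcasted function, a DynamicPPL-free version of the same pattern might look roughly like this (a sketch under that assumption; `reshape_to_scalar` and `broadcast_exp` are hypothetical stand-ins, and I'd expect it to hit the same ReverseDiff code path as the example above):

```julia
using ReverseDiff

# Stand-ins (assumptions for the sketch): plain `reshape(x, ())` in place of
# ReshapeTransform(()), and a broadcasted `exp` in place of the inverse link.
reshape_to_scalar(x) = reshape(x, ())   # singleton vector -> zero-dimensional array
broadcast_exp(y) = sum(exp.(y))         # broadcasted function; `sum` keeps the output scalar
ReverseDiff.gradient(broadcast_exp ∘ reshape_to_scalar, [1.0])
```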
I suspect we should be able to change the implementation of `ReshapeTransform` to circumvent this, though. I don't actually know all the possible shapes `ReshapeTransform` has to handle, nor whether different input/output shapes would give different ReverseDiff errors. However, I dug into a couple of the failing tests in Turing.jl, and it seems that both of them stem from `ReshapeTransform` being given singleton arrays (e.g. `[1.0]` above). Furthermore, the error message observed in all the other failing tests is the same (although I didn't verify that they ultimately stem from singleton arrays). So I think we could special-case this behaviour to keep ReverseDiff on our side.