Skip to content

Commit 2b62bd0

Browse files
committed
Rewrite logpdf_grad
1 parent 6973cfd commit 2b62bd0

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

src/modeling_library/product.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,15 @@ Gen.random(dist::ProductDistribution, args...) =
8686
Gen.logpdf(dist::ProductDistribution, x, args...) =
8787
sum(Gen.logpdf(d, x[k], extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions))
8888

89-
function Gen.logpdf_grad(dist::ProductDistribution, x, component_args_flat...)
90-
logpdf_grads = [Gen.logpdf_grad(dist.distributions[k], x[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K]
91-
x_grad = if dist.has_output_grad
92-
tuple((grads[1] for grads in logpdf_grads)...)
93-
else
94-
nothing
89+
function Gen.logpdf_grad(dist::ProductDistribution, x, args...)
90+
x_grad = ()
91+
arg_grads = ()
92+
for (k, d) in enumerate(dist.distributions)
93+
grads = Gen.logpdf_grad(d, x[k], extract_args_for_component(dist, args, k)...)
94+
x_grad = (x_grad..., grads[1])
95+
arg_grads = (arg_grads..., grads[2:end]...)
9596
end
96-
arg_grads = vcat((collect(grads[2:end]) for grads in logpdf_grads)...)
97+
x_grad = dist.has_output_grad ? x_grad : nothing
9798
return (x_grad, arg_grads...)
9899
end
99100

0 commit comments

Comments
 (0)