Skip to content

Commit 6973cfd

Browse files
committed
Fix/improve random and logpdf
1 parent b61eee7 commit 6973cfd

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/modeling_library/product.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,11 @@ function extract_args_for_component(dist::ProductDistribution, component_args_fl
8080
return component_args_flat[start_arg:start_arg+n-1]
8181
end
8282

83-
Gen.random(dist::ProductDistribution, component_args_flat...) =
84-
[random(dist.distributions[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K]
83+
Gen.random(dist::ProductDistribution, args...) =
84+
Tuple(random(d, extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions))
8585

86-
Gen.logpdf(dist::ProductDistribution, x, component_args_flat...) =
87-
sum(Gen.logpdf(dist.distributions[k], x[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K)
86+
Gen.logpdf(dist::ProductDistribution, x, args...) =
87+
sum(Gen.logpdf(d, x[k], extract_args_for_component(dist, args, k)...) for (k, d) in enumerate(dist.distributions))
8888

8989
function Gen.logpdf_grad(dist::ProductDistribution, x, component_args_flat...)
9090
logpdf_grads = [Gen.logpdf_grad(dist.distributions[k], x[k], extract_args_for_component(dist, component_args_flat, k)...) for k in 1:dist.K]

0 commit comments

Comments
 (0)