@@ -41,14 +41,14 @@ function logpdf_grad(d::TransformedDistribution{T, U}, x::T, args...) where {T,
4141
4242 if is_discrete (d. base) || ! has_output_grad (d. base)
4343 # TODO : should this be nothing or 0?
44- [ base_grad[1 ], fill (nothing , d. nArgs)... , base_grad[2 : end ]. .. ]
44+ return ( base_grad[1 ], fill (nothing , d. nArgs)... , base_grad[2 : end ]. .. )
4545 else
4646 transformation_grad = d. backward_grad (x, args[1 : d. nArgs]. .. )
4747 correction_grad = ReverseDiff. gradient (v -> logpdf_correction (d, v[1 ], v[2 : end ]), [x, args[1 : d. nArgs]. .. ])
4848 # TODO : Will this sort of thing work if the arguments w.r.t. which we are taking
4949 # gradients are themselves vector-valued?
5050 full_grad = (transformation_grad .* base_grad[1 ]) .+ correction_grad
51- [ full_grad... , base_grad[2 : end ]. .. ]
51+ return ( full_grad... , base_grad[2 : end ]. .. )
5252 end
5353end
5454
6262
6363function has_argument_grads (d:: TransformedDistribution{T, U} ) where {T, U}
6464 if is_discrete (d. base) || ! has_output_grad (d. base)
65- [ fill (false , d. nArgs)... , has_argument_grads (d. base)... ]
65+ ( fill (false , d. nArgs)... , has_argument_grads (d. base)... )
6666 else
67- [ fill (true , d. nArgs)... , has_argument_grads (d. base)... ]
67+ ( fill (true , d. nArgs)... , has_argument_grads (d. base)... )
6868 end
6969end
0 commit comments