Skip to content

Commit 9362af7

Browse files
committed
Fix logpdf_grad for transformed cont. dists.
1 parent bf65928 commit 9362af7

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

src/modeling_library/dist_dsl/dist_dsl.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,11 @@ function logpdf_grad(d::CompiledDistWithArgs{T}, x::T, args...) where T
7676
base_arg_grads = [g for (i, g) in enumerate(base_grads[2:end])
7777
if base_has_arg_grads[i]]
7878
argvec = collect(args)
79-
eval_arg_grads = hcat([ReverseDiff.gradient(xs -> eval_arg(arg, xs), argvec)
80-
for (i, arg) in enumerate(d.arglist) if base_has_arg_grads[i]]...)
79+
if !isempty(argvec)
80+
eval_arg_grads = [ReverseDiff.gradient(xs -> eval_arg(arg, xs), argvec) for
81+
(i, arg) in enumerate(d.arglist) if base_has_arg_grads[i]]
82+
eval_arg_grads = reduce(hcat, eval_arg_grads)
83+
end
8184

8285
retval = [base_grads[1]]
8386
for i in 1:d.n_args
@@ -87,7 +90,7 @@ function logpdf_grad(d::CompiledDistWithArgs{T}, x::T, args...) where T
8790
push!(retval, nothing)
8891
end
8992
end
90-
retval
93+
return Tuple(retval)
9194
end
9295

9396
function random(d::CompiledDistWithArgs{T}, args...)::T where T

src/modeling_library/dist_dsl/transformed_distribution.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ function logpdf_grad(d::TransformedDistribution{T, U}, x::T, args...) where {T,
4141

4242
if is_discrete(d.base) || !has_output_grad(d.base)
4343
# TODO: should this be nothing or 0?
44-
[base_grad[1], fill(nothing, d.nArgs)..., base_grad[2:end]...]
44+
return (base_grad[1], fill(nothing, d.nArgs)..., base_grad[2:end]...)
4545
else
4646
transformation_grad = d.backward_grad(x, args[1:d.nArgs]...)
4747
correction_grad = ReverseDiff.gradient(v -> logpdf_correction(d, v[1], v[2:end]), [x, args[1:d.nArgs]...])
4848
# TODO: Will this sort of thing work if the arguments w.r.t. which we are taking
4949
# gradients are themselves vector-valued?
5050
full_grad = (transformation_grad .* base_grad[1]) .+ correction_grad
51-
[full_grad..., base_grad[2:end]...]
51+
return (full_grad..., base_grad[2:end]...)
5252
end
5353
end
5454

@@ -62,8 +62,8 @@ end
6262

6363
function has_argument_grads(d::TransformedDistribution{T, U}) where {T, U}
6464
if is_discrete(d.base) || !has_output_grad(d.base)
65-
[fill(false, d.nArgs)..., has_argument_grads(d.base)...]
65+
(fill(false, d.nArgs)..., has_argument_grads(d.base)...)
6666
else
67-
[fill(true, d.nArgs)..., has_argument_grads(d.base)...]
67+
(fill(true, d.nArgs)..., has_argument_grads(d.base)...)
6868
end
6969
end

0 commit comments

Comments
 (0)