Skip to content

Commit cec0486

Browse files
authored
Merge pull request #497 from probcomp/20230109-ztangent-fix_dist_dsl_grads
Fix `logpdf_grad` errors in `@dist` DSL.
2 parents 4af5d8d + 449c640 commit cec0486

File tree

4 files changed

+134
-27
lines changed

4 files changed

+134
-27
lines changed

src/modeling_library/dist_dsl/dist_dsl.jl

Lines changed: 89 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,73 @@ all_indices(arg::SimpleArg) = [arg.i]
2626
all_indices(arg::TransformedArg) = vcat([all_indices(a) for a in arg.f_args]...)
2727

2828
# Evaluate user-facing args to concrete values passed to the base distribution
29-
eval_arg(x::Any, args) = x
30-
eval_arg(x::SimpleArg, args) = typecheck_arg(x, args[x.i])
31-
eval_arg(x::TransformedArg, args) =
32-
x.arg_passer(x.orig_f, [eval_arg(a, args) for a in x.f_args]...)
29+
# A base argument that is neither a `SimpleArg` nor a `TransformedArg` is a
# constant: it is passed through untouched, independent of `args`.
function eval_arg(base_arg::Any, args)
    return base_arg
end

# A `SimpleArg` refers to the user-facing argument at position `base_arg.i`;
# the value is routed through `typecheck_arg` before being returned.
function eval_arg(base_arg::SimpleArg, args)
    user_value = args[base_arg.i]
    return typecheck_arg(base_arg, user_value)
end

# A `TransformedArg` applies `orig_f` (via `arg_passer`) to its own
# sub-arguments, each of which is evaluated recursively against `args`.
function eval_arg(base_arg::TransformedArg, args)
    evaluated = (eval_arg(sub_arg, args) for sub_arg in base_arg.f_args)
    return base_arg.arg_passer(base_arg.orig_f, evaluated...)
end
33+
34+
# Evaluate gradients of base distribution args with respect to user-facing args
35+
# Gradient of a *constant* base distribution argument with respect to each
# user-facing argument in `args`. The derivative is always zero; only its
# shape matters, and it follows the convention of the other
# `eval_arg_gradient` methods:
#
# - scalar base argument: one entry per user argument, shaped like that
#   argument (`zero(arg)`), so that `g * base_grad` has the argument's shape;
# - array base argument: a zero Jacobian with `length(base_arg)` rows and
#   `length(arg)` columns (one column for a scalar argument), so that
#   `g' * vec(base_grad)` is well defined and has `length(arg)` entries.
#
# Arguments that are neither `Real` nor arrays of `Real` yield `nothing`.
# Returns a vector with one entry per element of `args`.
function eval_arg_gradient(base_arg::Any, base_type::Type, args)
    # Jacobian-shaped output is only needed (and only possible) when the
    # constant really is an array whose length can be queried.
    base_is_array = base_type <: AbstractArray && base_arg isa AbstractArray
    grads = map(enumerate(args)) do (_, arg)
        if arg isa Real
            base_is_array ? fill(zero(arg), length(base_arg), 1) : zero(arg)
        elseif arg isa AbstractArray && eltype(arg) <: Real
            if base_is_array
                fill(zero(eltype(arg)), length(base_arg), length(arg))
            else
                zero(arg)
            end
        else
            nothing
        end
    end
    return grads
end
45+
46+
# Gradient of a base distribution argument that is simply the user-facing
# argument at position `base_arg.i`, with respect to every user-facing
# argument in `args`.
#
# Shapes follow the convention of the `TransformedArg` methods:
#
# - scalar base argument: the entry for a scalar argument is `one`/`zero`,
#   and the entry for an array argument is `zero(arg)` (a gradient shaped
#   like the argument);
# - array base argument: each entry is a Jacobian with `length(base)` rows
#   and `length(arg)` columns (one column for a scalar argument). It is the
#   identity for the argument the base refers to and zero otherwise.
#
# Sizing the Jacobian from the base argument (rather than from the argument
# being differentiated) keeps `g' * vec(base_grad)` valid when user-facing
# arguments have different lengths, e.g. a mean vector and a covariance matrix.
#
# Arguments that are neither `Real` nor arrays of `Real` yield `nothing`.
# Returns a vector with one entry per element of `args`.
function eval_arg_gradient(base_arg::SimpleArg{T}, base_type::Type, args) where {T}
    base_val = args[base_arg.i]
    base_is_array = base_val isa AbstractArray
    grads = map(enumerate(args)) do (i, arg)
        is_base = i == base_arg.i
        if arg isa Real
            if base_is_array
                # A scalar argument can never be the array base itself.
                fill(zero(arg), length(base_val), 1)
            else
                is_base ? one(arg) : zero(arg)
            end
        elseif arg isa AbstractArray && eltype(arg) <: Real
            N, V = length(arg), eltype(arg)
            if is_base
                Matrix{V}(LinearAlgebra.I, N, N)
            elseif base_is_array
                zeros(V, length(base_val), N)
            else
                # Scalar base is unaffected by this array argument.
                zero(arg)
            end
        else
            nothing
        end
    end
    return grads
end
59+
60+
# Compute gradients when base arg is a scalar type
61+
# Gradients of a scalar-valued transformed base argument with respect to each
# user-facing argument in `args`, computed with ReverseDiff one argument at a
# time while the remaining arguments are held fixed.
#
# Each entry is a scalar for a `Real` argument, an array shaped like the
# argument for an array of `Real`, and `nothing` for anything else.
# Returns a vector with one entry per element of `args`.
function eval_arg_gradient(base_arg::TransformedArg, base_type::Type{<:Real}, args)
    # Replace the i-th user-facing argument and re-evaluate the base argument.
    splice_arg(arg, i) = [args[1:i-1]..., arg, args[i+1:end]...]
    per_arg_eval(arg, i) = eval_arg(base_arg, splice_arg(arg, i))
    grads = map(enumerate(args)) do (i, arg)
        if arg isa Real
            # ReverseDiff differentiates with respect to arrays, so the scalar
            # is wrapped in a 1-element vector; it must be unwrapped again
            # (`a[1]`) so the transformation still receives a scalar.
            ReverseDiff.gradient(a -> per_arg_eval(a[1], i), [arg])[1]
        elseif arg isa AbstractArray && eltype(arg) <: Real
            ReverseDiff.gradient(a -> per_arg_eval(a, i), arg)
        else
            nothing
        end
    end
    return grads
end
75+
76+
# Compute Jacobians when base arg is an array type
77+
# Jacobians of an array-valued transformed base argument with respect to each
# user-facing argument in `args`, computed with ReverseDiff one argument at a
# time while the remaining arguments are held fixed.
#
# Each entry is the Jacobian returned by `ReverseDiff.jacobian` for a `Real`
# argument (differentiated through a 1-element wrapper vector) or for an
# array of `Real`, and `nothing` for anything else.
# Returns a vector with one entry per element of `args`.
function eval_arg_gradient(base_arg::TransformedArg, base_type::Type{<:AbstractArray{<:Real}}, args)
    # Replace the i-th user-facing argument and re-evaluate the base argument.
    splice_arg(arg, i) = [args[1:i-1]..., arg, args[i+1:end]...]
    per_arg_eval(arg, i) = eval_arg(base_arg, splice_arg(arg, i))
    grads = map(enumerate(args)) do (i, arg)
        if arg isa Real
            # The scalar is wrapped in a 1-element vector for ReverseDiff and
            # unwrapped (`a[1]`) before evaluation, so the transformation sees
            # a scalar rather than a 1-element array.
            ReverseDiff.jacobian(a -> per_arg_eval(a[1], i), [arg])
        elseif arg isa AbstractArray && eltype(arg) <: Real
            ReverseDiff.jacobian(a -> per_arg_eval(a, i), arg)
        else
            nothing
        end
    end
    return grads
end
3391

3492
# Type of SimpleArg must match arg, otherwise a MethodError will be thrown
35-
typecheck_arg(x::SimpleArg{T}, arg::T) where {T} = arg
93+
# Accept a value whose type matches the `SimpleArg`'s type parameter and
# return it unchanged. A mismatching value finds no method.
function typecheck_arg(::SimpleArg{T}, arg::T) where {T}
    return arg
end

# Accept a ReverseDiff tracked scalar whose first type parameter matches a
# real-typed `SimpleArg`.
function typecheck_arg(::SimpleArg{T}, arg::ReverseDiff.TrackedReal{T}) where {T <: Real}
    return arg
end

# Accept a ReverseDiff tracked array whose fourth type parameter matches the
# `SimpleArg`'s type parameter.
function typecheck_arg(
    ::SimpleArg{T}, arg::ReverseDiff.TrackedArray{V, D, N, T}
) where {V, D, N, T}
    return arg
end
3696

3797
# DistWithArgs
3898
struct DistWithArgs{T}
@@ -71,22 +131,32 @@ function logpdf_grad(d::CompiledDistWithArgs{T}, x::T, args...) where T
71131
concrete_args = [eval_arg(arg, args) for arg in d.arglist]
72132
base_has_arg_grads = has_argument_grads(d.base)
73133
base_grads = logpdf_grad(d.base, x, concrete_args...)
74-
75-
base_arg_grads = [g for (i, g) in enumerate(base_grads[2:end])
76-
if base_has_arg_grads[i]]
77-
argvec = collect(args)
78-
eval_arg_grads = hcat([ReverseDiff.gradient(xs -> eval_arg(arg, xs), argvec)
79-
for (i, arg) in enumerate(d.arglist) if base_has_arg_grads[i]]...)
80-
81-
retval = [base_grads[1]]
82-
for i in 1:d.n_args
83-
if self_has_arg_grads[i]
84-
push!(retval, eval_arg_grads[i,:]' * base_arg_grads)
85-
else
86-
push!(retval, nothing)
134+
base_arg_grads = base_grads[2:end]
135+
136+
# Set gradient with respect to output
137+
self_output_grad = base_grads[1]
138+
139+
# Backpropagate gradients from base arguments to arguments
140+
self_arg_grads = [self_has_arg_grads[i] ? zero(arg) : nothing
141+
for (i, arg) in enumerate(args)]
142+
143+
for (i, base_arg) in enumerate(d.arglist)
144+
base_has_arg_grads[i] || continue
145+
base_grad = base_arg_grads[i]
146+
base_arg_type = typeof(concrete_args[i])
147+
eval_arg_grad = eval_arg_gradient(base_arg, base_arg_type, args)
148+
for (j, g) in enumerate(eval_arg_grad)
149+
(isnothing(g) || !self_has_arg_grads[j]) && continue
150+
if base_grad isa AbstractArray
151+
increment = reshape(g' * vec(base_grad), size(self_arg_grads[j]))
152+
else
153+
increment = g * base_grad
154+
end
155+
self_arg_grads[j] = self_arg_grads[j] .+ increment
87156
end
88157
end
89-
retval
158+
159+
return (self_output_grad, self_arg_grads...)
90160
end
91161

92162
function random(d::CompiledDistWithArgs{T}, args...)::T where T

src/modeling_library/dist_dsl/relabeled_distribution.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ function logpdf(d::WithLabelArg{T, U}, x::T, collection, base_args...) where {T,
1919
end
2020

2121
function logpdf_grad(d::WithLabelArg{T, U}, x::T, collection, base_args...) where {T, U}
22-
base_arg_grads = fill(nothing, length(base_args))
22+
base_arg_grads = Vector{Any}(nothing, length(base_args))
2323

2424
for p in pairs(collection)
2525
(index, item) = (p.first, p.second)
2626
if item == x
2727
new_grads = logpdf_grad(d.base, index, base_args...)
28-
for (arg_idx, grad) in enumerate(new_grads)
28+
for (arg_idx, grad) in enumerate(new_grads[2:end])
2929
if base_arg_grads[arg_idx] === nothing
3030
base_arg_grads[arg_idx] = grad
3131
elseif grad !== nothing
@@ -73,13 +73,13 @@ function logpdf(d::RelabeledDistribution{T, U}, x::T, base_args...) where {T, U}
7373
end
7474

7575
function logpdf_grad(d::RelabeledDistribution{T, U}, x::T, base_args...) where {T, U}
76-
base_arg_grads = fill(nothing, length(base_args))
76+
base_arg_grads = Vector{Any}(nothing, length(base_args))
7777

7878
for p in pairs(d.collection)
7979
(index, item) = (p.first, p.second)
8080
if item == x
8181
new_grads = logpdf_grad(d.base, index, base_args...)
82-
for (arg_idx, grad) in enumerate(new_grads)
82+
for (arg_idx, grad) in enumerate(new_grads[2:end])
8383
if base_arg_grads[arg_idx] === nothing
8484
base_arg_grads[arg_idx] = grad
8585
elseif grad !== nothing

src/modeling_library/dist_dsl/transformed_distribution.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ function logpdf_grad(d::TransformedDistribution{T, U}, x::T, args...) where {T,
4141

4242
if is_discrete(d.base) || !has_output_grad(d.base)
4343
# TODO: should this be nothing or 0?
44-
[base_grad[1], fill(nothing, d.nArgs)..., base_grad[2:end]...]
44+
return (base_grad[1], fill(nothing, d.nArgs)..., base_grad[2:end]...)
4545
else
4646
transformation_grad = d.backward_grad(x, args[1:d.nArgs]...)
4747
correction_grad = ReverseDiff.gradient(v -> logpdf_correction(d, v[1], v[2:end]), [x, args[1:d.nArgs]...])
4848
# TODO: Will this sort of thing work if the arguments w.r.t. which we are taking
4949
# gradients are themselves vector-valued?
5050
full_grad = (transformation_grad .* base_grad[1]) .+ correction_grad
51-
[full_grad..., base_grad[2:end]...]
51+
return (full_grad..., base_grad[2:end]...)
5252
end
5353
end
5454

@@ -62,8 +62,8 @@ end
6262

6363
# Return a tuple of gradient-availability flags: one flag for each of the
# first `d.nArgs` arguments, followed by the flags reported by the base
# distribution. The leading flags are all `false` when the base distribution
# is discrete or has no output gradient, and all `true` otherwise.
function has_argument_grads(d::TransformedDistribution{T, U}) where {T, U}
    leading_flag = !is_discrete(d.base) && has_output_grad(d.base)
    return (fill(leading_flag, d.nArgs)..., has_argument_grads(d.base)...)
end

test/modeling_library/dist_dsl.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,51 @@
22
# Test transformed distributions
33
@dist f(x) = exp(normal(x, 0.001))
44
@test isapprox(1, f(0); atol = 5)
5+
@test isapprox(logpdf(f, 1., 0.), logpdf(normal, 0., 0., 0.001))
6+
7+
# Test gradients of transformed distributions
8+
@dist shifted_normal(mu, sigma) = normal(mu, sigma) + 1.
9+
@test logpdf(shifted_normal, 1., 0., 1.) == logpdf(normal, 0., 0., 1.)
10+
@test logpdf_grad(shifted_normal, 0., 0., 1.) == logpdf_grad(normal, -1., 0., 1.)
11+
12+
# Test gradients of transformed distributions with no parameters
13+
@dist shifted_std_normal() = normal(0., 1.) + 1.
14+
@test logpdf(shifted_std_normal, 1.) == logpdf(normal, 0., 0., 1.)
15+
@test logpdf_grad(shifted_std_normal, 0.) == (logpdf_grad(normal, -1., 0., 1.)[1],)
16+
17+
# Test gradients of multivariate distributions
18+
@dist vec_normal(mu, sigma) = broadcasted_normal(broadcast(+, mu, 1.0), broadcast(*, sigma, 2.0))
19+
@test logpdf(vec_normal, zeros(2), zeros(2), ones(2)) ==
20+
logpdf(broadcasted_normal, zeros(2), ones(2), 2 .* ones(2))
21+
transformed_grads = logpdf_grad(vec_normal, zeros(2), zeros(2), ones(2))
22+
orig_grads = logpdf_grad(broadcasted_normal, zeros(2), ones(2), 2 .* ones(2))
23+
@test transformed_grads[1] == orig_grads[1]
24+
@test transformed_grads[2] == orig_grads[2]
25+
@test transformed_grads[3] == 2.0 * orig_grads[3]
26+
27+
# Test gradients of multivariate distributions with multi-dimensional arguments
28+
@dist transformed_mvnormal(mu, sigma) = mvnormal(broadcast(+, mu, 1.0), broadcast(*, sigma, 2.0))
29+
@test logpdf(transformed_mvnormal, zeros(2), zeros(2), ones(2, 2)) ==
30+
logpdf(mvnormal, zeros(2), ones(2), 2 .* ones(2, 2))
31+
transformed_grads = logpdf_grad(transformed_mvnormal, zeros(2), zeros(2), ones(2, 2))
32+
orig_grads = logpdf_grad(mvnormal, zeros(2), ones(2), 2 .* ones(2, 2))
33+
@test transformed_grads[1] == orig_grads[1]
34+
@test transformed_grads[2] == orig_grads[2]
35+
@test transformed_grads[3] == 2.0 * orig_grads[3]
536

637
# Test relabeled distributions with labels provided as an Array
738
@dist labeled_cat(labels, probs) = labels[categorical(probs)]
839
@test labeled_cat([:a, :b], [0., 1.]) == :b
940
@test isapprox(logpdf(labeled_cat, :b, [:a, :b], [0.5, 0.5]), log(0.5))
41+
@test logpdf_grad(labeled_cat, :b, [:a, :b], [0.5, 0.5]) == (nothing, logpdf_grad(categorical, 2, [0.5, 0.5])...)
1042
@test logpdf(labeled_cat, :c, [:a, :b], [0.5, 0.5]) == -Inf
1143

1244
# Test relabeled distributions with labels provided in a Dict
1345
dict = Dict(1 => :a, 2 => :b)
1446
@dist dict_cat(probs) = dict[categorical(probs)]
1547
@test dict_cat([0., 1.]) == :b
1648
@test isapprox(logpdf(dict_cat, :b, [0.5, 0.5]), log(0.5))
49+
@test logpdf_grad(dict_cat, :b, [0.5, 0.5]) == logpdf_grad(categorical, 2, [0.5, 0.5])
1750
@test logpdf(dict_cat, :c, [0.5, 0.5]) == -Inf
1851

1952
# Test relabeled distributions with Enum labels
@@ -22,11 +55,13 @@
2255
@test enum_cat([0., 1.]) == orange
2356
@test isapprox(logpdf(enum_cat, orange, [0.5, 0.5]), log(0.5))
2457
@test logpdf(enum_cat, orange, [1.0]) == -Inf
58+
@test logpdf_grad(enum_cat, orange, [0.5, 0.5]) == logpdf_grad(categorical, 2, [0.5, 0.5])
2559

2660
# Regression test for https://github.com/probcomp/Gen/issues/253
2761
@dist real_minus_uniform(a, b) = 1 - Gen.uniform(a, b)
2862
@test real_minus_uniform(1, 2) < 0
2963
@test logpdf(real_minus_uniform, -0.5, 1, 2) == 0.0
64+
@test logpdf_grad(real_minus_uniform, -0.5, 1, 2) == logpdf_grad(uniform, 1.5, 1, 2)
3065
end
3166

3267
# User-defined type for testing purposes
@@ -40,13 +75,15 @@ end
4075
@test symbol_cat([:a, :b], [0., 1.]) == :b
4176
@test_throws MethodError symbol_cat(["a", "b"], [0., 1.])
4277
@test logpdf(symbol_cat, :c, [:a, :b], [0.5, 0.5]) == -Inf
78+
@test logpdf_grad(symbol_cat, :b, [:a, :b], [0.5, 0.5]) == (nothing, logpdf_grad(categorical, 2, [0.5, 0.5])...)
4379
@test_throws MethodError logpdf(symbol_cat, "c", [:a, :b], [0.5, 0.5])
4480

4581
# Test typed parameters
4682
@dist int_bounded_uniform(low::Int, high::Int) = uniform(low, high)
4783
@test 0.0 <= int_bounded_uniform(0, 1) <= 1
4884
@test_throws MethodError int_bounded_uniform(-0.5, 0.5)
4985
@test logpdf(int_bounded_uniform, 0.5, 0, 1) == 0
86+
@test logpdf_grad(int_bounded_uniform, 0.5, 0, 1) == logpdf_grad(uniform, 0.5, 0, 1)
5087
@test_throws MethodError logpdf(int_bounded_uniform, 0.0, -0.5, 0.5)
5188

5289
# Test relabeled distributions with user-defined types

0 commit comments

Comments
 (0)