Skip to content

Commit 449c640

Browse files
committed
Generalize backprop of base arg grads, add tests.
1 parent 9362af7 commit 449c640

File tree

3 files changed

+119
-33
lines changed

3 files changed

+119
-33
lines changed

src/modeling_library/dist_dsl/dist_dsl.jl

Lines changed: 89 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,73 @@ all_indices(arg::SimpleArg) = [arg.i]
2626
# Collect the indices referenced by every sub-argument of a transformed
# argument, concatenated in the order the sub-arguments appear in `f_args`.
function all_indices(arg::TransformedArg)
    nested = [all_indices(sub_arg) for sub_arg in arg.f_args]
    return vcat(nested...)
end
2727

2828
# Evaluate user-facing args to concrete values passed to the base distribution
29-
eval_arg(x::Any, args) = x
30-
eval_arg(x::SimpleArg, args) = typecheck_arg(x, args[x.i])
31-
eval_arg(x::TransformedArg, args) =
32-
x.arg_passer(x.orig_f, [eval_arg(a, args) for a in x.f_args]...)
29+
# Anything that is not a SimpleArg / TransformedArg is treated as a constant
# and returned unchanged; `args` is ignored.
function eval_arg(base_arg::Any, args)
    return base_arg
end

# A SimpleArg refers to the user-facing argument at position `base_arg.i`;
# the value is passed through `typecheck_arg` before being returned.
function eval_arg(base_arg::SimpleArg, args)
    return typecheck_arg(base_arg, args[base_arg.i])
end

# A TransformedArg evaluates each of its sub-arguments recursively, then hands
# `orig_f` together with the evaluated values to `arg_passer`.
function eval_arg(base_arg::TransformedArg, args)
    evaluated = (eval_arg(sub_arg, args) for sub_arg in base_arg.f_args)
    return base_arg.arg_passer(base_arg.orig_f, evaluated...)
end
33+
34+
# Evaluate gradients of base distribution args with respect to user-facing args
35+
# Gradient of a *constant* base arg with respect to each user-facing arg.
#
# A constant never depends on the inputs, so every entry is zero; only the
# shape matters, and it follows the convention of the other
# `eval_arg_gradient` methods:
#   * scalar (or non-numeric) base arg -> `zero(arg)`, shaped like the user arg
#   * real-array base arg of length M  -> an all-zero `M x length(arg)` Jacobian
#     (`M x 1` for a scalar user arg), so that a caller computing
#     `g' * vec(base_grad)` gets conforming dimensions.
# User args that are neither `Real` nor arrays of `Real` get `nothing`.
#
# `base_type` is unused here; it is part of the signature shared by all
# `eval_arg_gradient` methods. Returns one entry per element of `args`.
function eval_arg_gradient(base_arg::Any, base_type::Type, args)
    is_real_array(a) = a isa AbstractArray && eltype(a) <: Real
    # Number of Jacobian rows when the constant is array-valued, else `nothing`
    n_out = is_real_array(base_arg) ? length(base_arg) : nothing
    grads = map(enumerate(args)) do (i, arg)
        if !(arg isa Real || is_real_array(arg))
            nothing
        elseif n_out === nothing
            zero(arg) # Base arg is always constant with respect to input args
        else
            # `length` and `eltype` are defined for scalars too (1 and typeof)
            zeros(eltype(arg), n_out, length(arg))
        end
    end
    return grads
end
45+
46+
# Gradient of a base arg that is exactly the user-facing arg at position
# `base_arg.i`, with respect to every user-facing arg.
#
# The entry for the arg itself is the identity (`one(arg)` for a scalar, an
# `N x N` identity Jacobian for an array of length N). Every other arg gets a
# zero whose shape depends on the *base* arg, not just on the other arg:
#   * scalar base arg            -> `zero(arg)`, shaped like that user arg
#   * real-array base arg, len M -> an all-zero `M x length(arg)` Jacobian
# Previously the off-index zero was always `zeros(V, N, N)` / `zero(arg)`,
# which has the wrong shape as soon as the two args differ in size or in
# scalar-vs-array kind. Non-numeric user args get `nothing`.
#
# `base_type` is unused here; it is part of the signature shared by all
# `eval_arg_gradient` methods. Returns one entry per element of `args`.
function eval_arg_gradient(base_arg::SimpleArg{T}, base_type::Type, args) where {T}
    is_real_array(a) = a isa AbstractArray && eltype(a) <: Real
    # Value the base arg resolves to; `get` avoids throwing when the index is
    # out of range, in which case all entries are treated as "unaffected".
    base_val = get(args, base_arg.i, nothing)
    n_out = is_real_array(base_val) ? length(base_val) : nothing
    grads = map(enumerate(args)) do (i, arg)
        if arg isa Real # Base arg is either equal to or unaffected by input arg
            if i == base_arg.i
                one(arg)
            elseif n_out === nothing
                zero(arg)
            else
                zeros(typeof(arg), n_out, 1)
            end
        elseif is_real_array(arg)
            N, V = length(arg), eltype(arg)
            if i == base_arg.i
                Matrix{V}(LinearAlgebra.I, N, N)
            elseif n_out === nothing
                zero(arg)
            else
                zeros(V, n_out, N)
            end
        else
            nothing
        end
    end
    return grads
end
59+
60+
# Compute gradients when base arg is a scalar type
61+
# Gradient of a scalar-valued transformed base arg with respect to each
# user-facing arg, computed with ReverseDiff.
#
# For every numeric user arg, the base arg is re-evaluated as a function of
# that one arg (all other args held at their given values) and differentiated:
#   * scalar user arg -> a scalar derivative
#   * array user arg  -> a gradient array shaped like the arg
# Non-numeric user args get `nothing`. Returns one entry per element of `args`.
function eval_arg_gradient(base_arg::TransformedArg, base_type::Type{<:Real}, args)
    # Copy of `args` with position `i` replaced by `arg`
    splice_arg(arg, i) = [args[1:i-1]..., arg, args[i+1:end]...]
    per_arg_eval(arg, i) = eval_arg(base_arg, splice_arg(arg, i))
    grads = map(enumerate(args)) do (i, arg)
        if arg isa Real
            # ReverseDiff differentiates functions of arrays, so the scalar is
            # wrapped in a 1-element vector. The closure must unwrap it again
            # (`a[1]`) so the transform sees a scalar, not a 1-element array.
            ReverseDiff.gradient(a -> per_arg_eval(a[1], i), [arg])[1]
        elseif arg isa AbstractArray && eltype(arg) <: Real
            ReverseDiff.gradient(a -> per_arg_eval(a, i), arg)
        else
            nothing
        end
    end
    return grads
end
75+
76+
# Compute Jacobians when base arg is an array type
77+
# Jacobian of an array-valued transformed base arg with respect to each
# user-facing arg, computed with ReverseDiff.
#
# For every numeric user arg, the base arg is re-evaluated as a function of
# that one arg (all other args held at their given values) and differentiated:
#   * scalar user arg -> Jacobian with respect to a 1-element input
#   * array user arg  -> Jacobian with respect to the array
# Non-numeric user args get `nothing`. Returns one entry per element of `args`.
function eval_arg_gradient(
    base_arg::TransformedArg, base_type::Type{<:AbstractArray{<:Real}}, args
)
    # Copy of `args` with position `i` replaced by `arg`
    splice_arg(arg, i) = [args[1:i-1]..., arg, args[i+1:end]...]
    per_arg_eval(arg, i) = eval_arg(base_arg, splice_arg(arg, i))
    grads = map(enumerate(args)) do (i, arg)
        if arg isa Real
            # The scalar is wrapped in a 1-element vector for ReverseDiff; the
            # closure unwraps it (`a[1]`) so the transform sees a scalar rather
            # than a 1-element array standing in for it.
            ReverseDiff.jacobian(a -> per_arg_eval(a[1], i), [arg])
        elseif arg isa AbstractArray && eltype(arg) <: Real
            ReverseDiff.jacobian(a -> per_arg_eval(a, i), arg)
        else
            nothing
        end
    end
    return grads
end
3391

3492
# Type of SimpleArg must match arg, otherwise a MethodError will be thrown
35-
typecheck_arg(x::SimpleArg{T}, arg::T) where {T} = arg
36-
typecheck_arg(x::SimpleArg{T}, arg::ReverseDiff.TrackedReal{T}) where {T <: Real} = arg
93+
# Plain value whose type matches the SimpleArg's type parameter: pass through.
function typecheck_arg(base_arg::SimpleArg{T}, arg::T) where {T}
    return arg
end

# ReverseDiff tracked scalar whose first type parameter matches a Real-typed
# SimpleArg: pass through.
function typecheck_arg(
    base_arg::SimpleArg{T}, arg::ReverseDiff.TrackedReal{T}
) where {T <: Real}
    return arg
end

# ReverseDiff tracked array whose fourth type parameter matches the SimpleArg's
# type parameter: pass through.
# NOTE(review): the fourth parameter is presumably the wrapped value-array
# type; confirm against ReverseDiff's TrackedArray definition.
function typecheck_arg(
    base_arg::SimpleArg{T}, arg::ReverseDiff.TrackedArray{V, D, N, T}
) where {V, D, N, T}
    return arg
end
3796

3897
# DistWithArgs
3998
struct DistWithArgs{T}
@@ -72,25 +131,32 @@ function logpdf_grad(d::CompiledDistWithArgs{T}, x::T, args...) where T
72131
concrete_args = [eval_arg(arg, args) for arg in d.arglist]
73132
base_has_arg_grads = has_argument_grads(d.base)
74133
base_grads = logpdf_grad(d.base, x, concrete_args...)
75-
76-
base_arg_grads = [g for (i, g) in enumerate(base_grads[2:end])
77-
if base_has_arg_grads[i]]
78-
argvec = collect(args)
79-
if !isempty(argvec)
80-
eval_arg_grads = [ReverseDiff.gradient(xs -> eval_arg(arg, xs), argvec) for
81-
(i, arg) in enumerate(d.arglist) if base_has_arg_grads[i]]
82-
eval_arg_grads = reduce(hcat, eval_arg_grads)
83-
end
84-
85-
retval = [base_grads[1]]
86-
for i in 1:d.n_args
87-
if self_has_arg_grads[i]
88-
push!(retval, eval_arg_grads[i,:]' * base_arg_grads)
89-
else
90-
push!(retval, nothing)
134+
base_arg_grads = base_grads[2:end]
135+
136+
# Set gradient with respect to output
137+
self_output_grad = base_grads[1]
138+
139+
# Backpropagate gradients from base arguments to arguments
140+
self_arg_grads = [self_has_arg_grads[i] ? zero(arg) : nothing
141+
for (i, arg) in enumerate(args)]
142+
143+
for (i, base_arg) in enumerate(d.arglist)
144+
base_has_arg_grads[i] || continue
145+
base_grad = base_arg_grads[i]
146+
base_arg_type = typeof(concrete_args[i])
147+
eval_arg_grad = eval_arg_gradient(base_arg, base_arg_type, args)
148+
for (j, g) in enumerate(eval_arg_grad)
149+
(isnothing(g) || !self_has_arg_grads[j]) && continue
150+
if base_grad isa AbstractArray
151+
increment = reshape(g' * vec(base_grad), size(self_arg_grads[j]))
152+
else
153+
increment = g * base_grad
154+
end
155+
self_arg_grads[j] = self_arg_grads[j] .+ increment
91156
end
92157
end
93-
return Tuple(retval)
158+
159+
return (self_output_grad, self_arg_grads...)
94160
end
95161

96162
function random(d::CompiledDistWithArgs{T}, args...)::T where T

src/modeling_library/dist_dsl/relabeled_distribution.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ function logpdf(d::WithLabelArg{T, U}, x::T, collection, base_args...) where {T,
1919
end
2020

2121
function logpdf_grad(d::WithLabelArg{T, U}, x::T, collection, base_args...) where {T, U}
22-
base_arg_grads = fill(nothing, length(base_args))
22+
base_arg_grads = Vector{Any}(nothing, length(base_args))
2323

2424
for p in pairs(collection)
2525
(index, item) = (p.first, p.second)
2626
if item == x
2727
new_grads = logpdf_grad(d.base, index, base_args...)
28-
for (arg_idx, grad) in enumerate(new_grads)
28+
for (arg_idx, grad) in enumerate(new_grads[2:end])
2929
if base_arg_grads[arg_idx] === nothing
3030
base_arg_grads[arg_idx] = grad
3131
elseif grad !== nothing
@@ -73,13 +73,13 @@ function logpdf(d::RelabeledDistribution{T, U}, x::T, base_args...) where {T, U}
7373
end
7474

7575
function logpdf_grad(d::RelabeledDistribution{T, U}, x::T, base_args...) where {T, U}
76-
base_arg_grads = fill(nothing, length(base_args))
76+
base_arg_grads = Vector{Any}(nothing, length(base_args))
7777

7878
for p in pairs(d.collection)
7979
(index, item) = (p.first, p.second)
8080
if item == x
8181
new_grads = logpdf_grad(d.base, index, base_args...)
82-
for (arg_idx, grad) in enumerate(new_grads)
82+
for (arg_idx, grad) in enumerate(new_grads[2:end])
8383
if base_arg_grads[arg_idx] === nothing
8484
base_arg_grads[arg_idx] = grad
8585
elseif grad !== nothing

test/modeling_library/dist_dsl.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,40 @@
55
@test isapprox(logpdf(f, 1., 0.), logpdf(normal, 0., 0., 0.001))
66

77
# Test gradients of transformed distributions
8-
@dist shifted_normal(mu, sigma) = Gen.normal(mu, sigma) + 1.
9-
@test isapprox(logpdf(shifted_normal, 1., 0., 1.), logpdf(normal, 0., 0., 1.))
8+
@dist shifted_normal(mu, sigma) = normal(mu, sigma) + 1.
9+
@test logpdf(shifted_normal, 1., 0., 1.) == logpdf(normal, 0., 0., 1.)
1010
@test logpdf_grad(shifted_normal, 0., 0., 1.) == logpdf_grad(normal, -1., 0., 1.)
1111

1212
# Test gradients of transformed distributions with no parameters
13-
@dist shifted_std_normal() = Gen.normal(0., 1.) + 1.
14-
@test isapprox(logpdf(shifted_std_normal, 1.), logpdf(normal, 0., 0., 1.))
13+
@dist shifted_std_normal() = normal(0., 1.) + 1.
14+
@test logpdf(shifted_std_normal, 1.) == logpdf(normal, 0., 0., 1.)
1515
@test logpdf_grad(shifted_std_normal, 0.) == (logpdf_grad(normal, -1., 0., 1.)[1],)
1616

17+
# Test gradients of multivariate distributions
18+
@dist vec_normal(mu, sigma) = broadcasted_normal(broadcast(+, mu, 1.0), broadcast(*, sigma, 2.0))
19+
@test logpdf(vec_normal, zeros(2), zeros(2), ones(2)) ==
20+
logpdf(broadcasted_normal, zeros(2), ones(2), 2 .* ones(2))
21+
transformed_grads = logpdf_grad(vec_normal, zeros(2), zeros(2), ones(2))
22+
orig_grads = logpdf_grad(broadcasted_normal, zeros(2), ones(2), 2 .* ones(2))
23+
@test transformed_grads[1] == orig_grads[1]
24+
@test transformed_grads[2] == orig_grads[2]
25+
@test transformed_grads[3] == 2.0 * orig_grads[3]
26+
27+
# Test gradients of multivariate distributions with multi-dimensional arguments
28+
@dist transformed_mvnormal(mu, sigma) = mvnormal(broadcast(+, mu, 1.0), broadcast(*, sigma, 2.0))
29+
@test logpdf(transformed_mvnormal, zeros(2), zeros(2), ones(2, 2)) ==
30+
logpdf(mvnormal, zeros(2), ones(2), 2 .* ones(2, 2))
31+
transformed_grads = logpdf_grad(transformed_mvnormal, zeros(2), zeros(2), ones(2, 2))
32+
orig_grads = logpdf_grad(mvnormal, zeros(2), ones(2), 2 .* ones(2, 2))
33+
@test transformed_grads[1] == orig_grads[1]
34+
@test transformed_grads[2] == orig_grads[2]
35+
@test transformed_grads[3] == 2.0 * orig_grads[3]
36+
1737
# Test relabeled distributions with labels provided as an Array
1838
@dist labeled_cat(labels, probs) = labels[categorical(probs)]
1939
@test labeled_cat([:a, :b], [0., 1.]) == :b
2040
@test isapprox(logpdf(labeled_cat, :b, [:a, :b], [0.5, 0.5]), log(0.5))
21-
@test logpdf_grad(labeled_cat, :b, [:a, :b], [0.5, 0.5]) == logpdf_grad(categorical, 2, [0.5, 0.5])
41+
@test logpdf_grad(labeled_cat, :b, [:a, :b], [0.5, 0.5]) == (nothing, logpdf_grad(categorical, 2, [0.5, 0.5])...)
2242
@test logpdf(labeled_cat, :c, [:a, :b], [0.5, 0.5]) == -Inf
2343

2444
# Test relabeled distributions with labels provided in a Dict
@@ -55,7 +75,7 @@ end
5575
@test symbol_cat([:a, :b], [0., 1.]) == :b
5676
@test_throws MethodError symbol_cat(["a", "b"], [0., 1.])
5777
@test logpdf(symbol_cat, :c, [:a, :b], [0.5, 0.5]) == -Inf
58-
@test logpdf_grad(symbol_cat, :b, [:a, :b], [0.5, 0.5]) == logpdf_grad(categorical, 2, [0.5, 0.5])
78+
@test logpdf_grad(symbol_cat, :b, [:a, :b], [0.5, 0.5]) == (nothing, logpdf_grad(categorical, 2, [0.5, 0.5])...)
5979
@test_throws MethodError logpdf(symbol_cat, "c", [:a, :b], [0.5, 0.5])
6080

6181
# Test typed parameters

0 commit comments

Comments
 (0)