Skip to content

Commit acd7005

Browse files
authored
Merge pull request #433 from MrVPlusOne/fix-broadcasted-normal
Fix `logpdf_grad` for BroadcastedNormal.
2 parents d79aded + 5d3f192 commit acd7005

File tree

3 files changed

+97
-23
lines changed

3 files changed

+97
-23
lines changed

src/modeling_library/distributions/normal.jl

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Samples an `Array{Float64, max(N1, N2)}` of shape
3434
`Broadcast.broadcast_shapes(size(mu), size(std))` where each element is
3535
independently normally distributed. This is equivalent to (a reshape of) a
3636
multivariate normal with diagonal covariance matrix, but its implementation is
37-
more efficient than that of the more general `mvnormal` for this case.
37+
more efficient than that of the more general [`mvnormal`](@ref) for this case.
3838
3939
The shapes of `mu` and `std` must be broadcast-compatible.
4040
@@ -65,8 +65,6 @@ function logpdf(::BroadcastedNormal,
6565
assert_has_shape(x, broadcast_shapes_or_crash(mu, std);
6666
msg="Shape of `x` does not agree with the sample space")
6767
z = (x .- mu) ./ std
68-
var = std .* std
69-
diff = x .- mu
7068
sum(- (abs2.(z) .+ log(2π)) / 2 .- log.(std))
7169
end
7270

@@ -85,10 +83,46 @@ function logpdf_grad(::BroadcastedNormal,
8583
assert_has_shape(x, broadcast_shapes_or_crash(mu, std);
8684
msg="Shape of `x` does not agree with the sample space")
8785
z = (x .- mu) ./ std
88-
deriv_x = sum(- z ./ std)
86+
deriv_x = - z ./ std
8987
deriv_mu = -deriv_x
90-
deriv_std = sum(-1. ./ std .+ abs2.(z) ./ std)
91-
(deriv_x, deriv_mu, deriv_std)
88+
deriv_std = -1. ./ std .+ abs2.(z) ./ std
89+
(_unbroadcast_like(x, deriv_x),
90+
_unbroadcast_like(mu, deriv_mu),
91+
_unbroadcast_like(std, deriv_std))
92+
end
93+
94+
_unbroadcast_like(::Real, full_arr) = sum(full_arr)
95+
_unbroadcast_like(::AbstractArray{<:Real, 0}, full_arr::Real) = fill(full_arr)
96+
function _unbroadcast_like(a::AbstractArray{<:Real, N},
97+
full_arr::AbstractArray{T}
98+
)::AbstractArray{T, N} where {N,T}
99+
if size(a) == size(full_arr)
100+
return full_arr
101+
end
102+
return _unbroadcast_to_shape(size(a), full_arr)
103+
end
104+
105+
"""
106+
"Unbroadcasts" `full_arr` to have shape `target_shape` by:
107+
108+
* Summing over all dims that would be increased by a broadcast from shape
109+
`target_shape` to shape `size(full_arr)`
110+
* Then dropping trailing dims (which will all be 1's) as needed so that the
111+
result has shape `target_shape`.
112+
113+
Requires that `size(full_arr)` is "strictly bigger" than `target_shape`, in the
114+
sense that
115+
116+
Broadcast.broadcast_shapes(target_shape, size(full_arr)) == size(full_arr)
117+
"""
118+
function _unbroadcast_to_shape(target_shape::NTuple{target_ndims, Int},
119+
full_arr::AbstractArray{T, full_ndims}
120+
) where {T, target_ndims, full_ndims}
121+
@assert full_ndims >= target_ndims
122+
should_sum_dim(i) = (i > target_ndims) || (target_shape[i] == 1 &&
123+
size(full_arr, i) > 1)
124+
dropdims(sum(full_arr; dims=filter(should_sum_dim, 1:full_ndims));
125+
dims=Dims(target_ndims + 1 : full_ndims))
92126
end
93127

94128
random(::Normal, mu::Real, std::Real) = mu + std * randn()

test/modeling_library/distributions.jl

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,14 @@ end
123123
x = broadcasted_normal(fill(0), fill(1))
124124

125125
# logpdf_grad
126-
f = (x, mu, std) -> logpdf(broadcasted_normal, x, mu, std)
126+
f(x, mu, std) = logpdf(broadcasted_normal, x, mu, std)
127127
args = (fill(0.4), fill(0.2), fill(0.3))
128128
actual = logpdf_grad(broadcasted_normal, args...)
129+
130+
@test actual[1] isa AbstractArray && size(actual[1]) == ()
131+
@test actual[2] isa AbstractArray && size(actual[2]) == ()
132+
@test actual[3] isa AbstractArray && size(actual[3]) == ()
133+
129134
@test isapprox(actual[1], finite_diff(f, args, 1, dx; broadcast=true))
130135
@test isapprox(actual[2], finite_diff(f, args, 2, dx; broadcast=true))
131136
@test isapprox(actual[3], finite_diff(f, args, 3, dx; broadcast=true))
@@ -144,27 +149,37 @@ end
144149
broadcasted_normal(mu, std)
145150

146151
# logpdf_grad
147-
f = (x, mu, std) -> logpdf(broadcasted_normal, x, mu, std)
148-
args = (mu, std, x)
152+
f(x_, mu_, std_) = logpdf(broadcasted_normal, x_, mu_, std_)
153+
args = (x, mu, std)
149154
actual = logpdf_grad(broadcasted_normal, args...)
150-
@test isapprox(actual[1], finite_diff(f, args, 1, dx; broadcast=true))
151-
@test isapprox(actual[2], finite_diff(f, args, 2, dx; broadcast=true))
152-
@test isapprox(actual[3], finite_diff(f, args, 3, dx; broadcast=true))
155+
156+
@test actual[1] isa AbstractArray && size(actual[1]) == (2, 3)
157+
@test actual[2] isa AbstractArray && size(actual[2]) == (2, 3)
158+
@test actual[3] isa AbstractArray && size(actual[3]) == (2, 3)
159+
160+
@test isapprox(actual[1], finite_diff_arr_fullarg(f, args, 1, dx); rtol=1e-7)
161+
@test isapprox(actual[2], finite_diff_arr_fullarg(f, args, 2, dx); rtol=1e-7)
162+
@test isapprox(actual[3], finite_diff_arr_fullarg(f, args, 3, dx); rtol=1e-7)
153163
end
154164

155165
@testset "broadcasted normal" begin
156166

157167
## Return shape of `broadcasted_normal`
158168
@test size(broadcasted_normal([0. 0. 0.], 1.)) == (1, 3)
159169
@test size(broadcasted_normal(zeros(1, 3, 4), ones(2, 1, 4))) == (2, 3, 4)
170+
@test size(broadcasted_normal(zeros(1, 3), ones(2, 1, 1))) == (2, 3, 1)
160171
@test_throws DimensionMismatch broadcasted_normal([0 0 0], [1 1])
172+
# Numpy and Julia use different conventions for which direction the
173+
# implicit 1-padding goes. In Julia, it's not `(1, 2, 3)` but rather
174+
# `(2, 3, 1)` that is broadcast-compatible with the shape `(2, 3)`.
175+
@test_throws DimensionMismatch broadcasted_normal(zeros(2, 3), ones(1, 2, 3))
161176

162177
## Return shape of `logpdf` and `logpdf_grad`
163178
@test size(logpdf(broadcasted_normal,
164179
ones(2, 4), ones(2, 1), ones(1, 4))) == ()
165-
@test all(size(g) == ()
166-
for g in logpdf_grad(
167-
broadcasted_normal, ones(2, 4), ones(2, 1), ones(1, 4)))
180+
@test [size(g) for g in logpdf_grad(
181+
broadcasted_normal, ones(2, 4), ones(2, 1), ones(1, 4))
182+
] == [(2, 4), (2, 1), (1, 4)]
168183
# `x` has the wrong shape
169184
@test_throws DimensionMismatch logpdf(broadcasted_normal,
170185
ones(1, 2), ones(1,3), ones(2,1))
@@ -182,21 +197,20 @@ end
182197
@test_throws DimensionMismatch logpdf_grad(broadcasted_normal,
183198
ones(2, 1), ones(1,2), ones(1,3))
184199

185-
## Equivalence of broadcast to supplying bigger arrays for `mu` and `std`
200+
## For `logpdf`, equivalence of broadcast to supplying bigger arrays for
201+
## `mu` and `std`
186202
compact = OrderedDict(:x => reshape([ 0.2 0.3 0.4 0.5 ;
187203
0.5 0.4 0.3 0.2 ],
188-
(2, 4)),
204+
(2, 4, 1)),
189205
:mu => reshape([0.7 0.7 0.8 0.6],
190206
(1, 4)),
191207
:std => reshape([0.2, 0.1],
192-
(2, 1)))
208+
(2, 1, 1)))
193209
expanded = OrderedDict(:x => compact[:x],
194-
:mu => repeat(compact[:mu], outer=(2, 1)),
195-
:std => repeat(compact[:std], outer=(1, 4)))
210+
:mu => repeat(compact[:mu], outer=(2, 1, 1)),
211+
:std => repeat(compact[:std], outer=(1, 4, 1)))
196212
@test (logpdf(broadcasted_normal, values(compact)...) ==
197213
logpdf(broadcasted_normal, values(expanded)...))
198-
@test (logpdf_grad(broadcasted_normal, values(compact)...) ==
199-
logpdf_grad(broadcasted_normal, values(expanded)...))
200214
end
201215

202216
@testset "multivariate normal" begin

test/runtests.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@ function finite_diff(f::Function, args::Tuple, i::Int, dx::Float64;
2121
if broadcast
2222
pos_args[i] = copy(args[i]) .+ dx
2323
neg_args[i] = copy(args[i]) .- dx
24-
return (f(pos_args...) - f(neg_args...)) ./ (2. * dx)
24+
ans = (f(pos_args...) - f(neg_args...)) ./ (2. * dx)
25+
# Workaround for
26+
# https://github.com/probcomp/Gen.jl/pull/433#discussion_r669958584
27+
if args[i] isa AbstractArray && ndims(args[i]) == 0
28+
return fill(ans)
29+
end
30+
return ans
2531
else
2632
pos_args[i] += dx
2733
neg_args[i] -= dx
@@ -74,6 +80,26 @@ function finite_diff_arr(f::Function, args::Tuple, i::Int, idx, dx::Float64)
7480
return (f(pos_args...) - f(neg_args...)) / (2. * dx)
7581
end
7682

83+
"""
84+
Returns the partial derivatives of `f` with respect to all entries of
85+
`args[i]`.
86+
87+
That is, returns an array of the same shape as `args[i]`, each entry of which
88+
is [`finite_diff_arr`](@ref) applied to the corresponding entry of `args[i]`.
89+
90+
Requires that `args[i]` have nonzero rank. Due to [1], handling
91+
zero-dimensional arrays properly in this function is not feasible; the caller
92+
should handle that case on their own.
93+
94+
[1] https://github.com/JuliaLang/julia/issues/28866
95+
"""
96+
function finite_diff_arr_fullarg(f::Function, args::Tuple, i::Int, dx::Float64)
97+
@assert args[i] isa AbstractArray
98+
@assert ndims(args[i]) > 0
99+
return [finite_diff_arr(f, args, i, idx, dx)
100+
for idx in keys(args[i])]
101+
end
102+
77103
const dx = 1e-6
78104

79105
include("autodiff.jl")

0 commit comments

Comments
 (0)