123123 x = broadcasted_normal (fill (0 ), fill (1 ))
124124
125125 # logpdf_grad
126- f = (x, mu, std) -> logpdf (broadcasted_normal, x, mu, std)
126+ f (x, mu, std) = logpdf (broadcasted_normal, x, mu, std)
127127 args = (fill (0.4 ), fill (0.2 ), fill (0.3 ))
128128 actual = logpdf_grad (broadcasted_normal, args... )
129+
130+ @test actual[1 ] isa AbstractArray && size (actual[1 ]) == ()
131+ @test actual[2 ] isa AbstractArray && size (actual[2 ]) == ()
132+ @test actual[3 ] isa AbstractArray && size (actual[3 ]) == ()
133+
129134 @test isapprox (actual[1 ], finite_diff (f, args, 1 , dx; broadcast= true ))
130135 @test isapprox (actual[2 ], finite_diff (f, args, 2 , dx; broadcast= true ))
131136 @test isapprox (actual[3 ], finite_diff (f, args, 3 , dx; broadcast= true ))
@@ -144,27 +149,37 @@ end
144149 broadcasted_normal (mu, std)
145150
146151 # logpdf_grad
147- f = (x, mu, std) -> logpdf (broadcasted_normal, x, mu, std )
148- args = (mu, std, x )
152+ f (x_, mu_, std_) = logpdf (broadcasted_normal, x_, mu_, std_ )
153+ args = (x, mu, std )
149154 actual = logpdf_grad (broadcasted_normal, args... )
150- @test isapprox (actual[1 ], finite_diff (f, args, 1 , dx; broadcast= true ))
151- @test isapprox (actual[2 ], finite_diff (f, args, 2 , dx; broadcast= true ))
152- @test isapprox (actual[3 ], finite_diff (f, args, 3 , dx; broadcast= true ))
155+
156+ @test actual[1 ] isa AbstractArray && size (actual[1 ]) == (2 , 3 )
157+ @test actual[2 ] isa AbstractArray && size (actual[2 ]) == (2 , 3 )
158+ @test actual[3 ] isa AbstractArray && size (actual[3 ]) == (2 , 3 )
159+
160+ @test isapprox (actual[1 ], finite_diff_arr_fullarg (f, args, 1 , dx); rtol= 1e-7 )
161+ @test isapprox (actual[2 ], finite_diff_arr_fullarg (f, args, 2 , dx); rtol= 1e-7 )
162+ @test isapprox (actual[3 ], finite_diff_arr_fullarg (f, args, 3 , dx); rtol= 1e-7 )
153163end
154164
155165@testset " broadcasted normal" begin
156166
157167 # # Return shape of `broadcasted_normal`
158168 @test size (broadcasted_normal ([0. 0. 0. ], 1. )) == (1 , 3 )
159169 @test size (broadcasted_normal (zeros (1 , 3 , 4 ), ones (2 , 1 , 4 ))) == (2 , 3 , 4 )
170+ @test size (broadcasted_normal (zeros (1 , 3 ), ones (2 , 1 , 1 ))) == (2 , 3 , 1 )
160171 @test_throws DimensionMismatch broadcasted_normal ([0 0 0 ], [1 1 ])
172+ # Numpy and Julia use different conventions for which direction the
173+ # implicit 1-padding goes. In Julia, it's not `(1, 2, 3)` but rather
174+ # `(2, 3, 1)` that is broadcast-compatible with the shape `(2, 3)`.
175+ @test_throws DimensionMismatch broadcasted_normal (zeros (2 , 3 ), ones (1 , 2 , 3 ))
161176
162177 # # Return shape of `logpdf` and `logpdf_grad`
163178 @test size (logpdf (broadcasted_normal,
164179 ones (2 , 4 ), ones (2 , 1 ), ones (1 , 4 ))) == ()
165- @test all ( size (g) == ()
166- for g in logpdf_grad (
167- broadcasted_normal, ones (2 , 4 ), ones (2 , 1 ), ones (1 , 4 )))
180+ @test [ size (g) for g in logpdf_grad (
181+ broadcasted_normal, ones ( 2 , 4 ), ones ( 2 , 1 ), ones ( 1 , 4 ))
182+ ] == [ (2 , 4 ), (2 , 1 ), (1 , 4 )]
168183 # `x` has the wrong shape
169184 @test_throws DimensionMismatch logpdf (broadcasted_normal,
170185 ones (1 , 2 ), ones (1 ,3 ), ones (2 ,1 ))
@@ -182,21 +197,20 @@ end
182197 @test_throws DimensionMismatch logpdf_grad (broadcasted_normal,
183198 ones (2 , 1 ), ones (1 ,2 ), ones (1 ,3 ))
184199
185- # # Equivalence of broadcast to supplying bigger arrays for `mu` and `std`
200+ # # For `logpdf`, equivalence of broadcast to supplying bigger arrays for
201+ # # `mu` and `std`
186202 compact = OrderedDict (:x => reshape ([ 0.2 0.3 0.4 0.5 ;
187203 0.5 0.4 0.3 0.2 ],
188- (2 , 4 )),
204+ (2 , 4 , 1 )),
189205 :mu => reshape ([0.7 0.7 0.8 0.6 ],
190206 (1 , 4 )),
191207 :std => reshape ([0.2 , 0.1 ],
192- (2 , 1 )))
208+ (2 , 1 , 1 )))
193209 expanded = OrderedDict (:x => compact[:x ],
194- :mu => repeat (compact[:mu ], outer= (2 , 1 )),
195- :std => repeat (compact[:std ], outer= (1 , 4 )))
210+ :mu => repeat (compact[:mu ], outer= (2 , 1 , 1 )),
211+ :std => repeat (compact[:std ], outer= (1 , 4 , 1 )))
196212 @test (logpdf (broadcasted_normal, values (compact)... ) ==
197213 logpdf (broadcasted_normal, values (expanded)... ))
198- @test (logpdf_grad (broadcasted_normal, values (compact)... ) ==
199- logpdf_grad (broadcasted_normal, values (expanded)... ))
200214end
201215
202216@testset " multivariate normal" begin
0 commit comments