@@ -214,60 +214,72 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
214
214
end # prod
215
215
216
216
@testset " foldl(f, ::Array)" begin
217
+ # `foldl(op, itr; init)` goes to `mapfoldl_impl(identity, op, init, itr)`. The rule is
218
+ # now attached there, as this is the simplest way to handle the `init` keyword.
219
+ @eval using Base: mapfoldl_impl
220
+ @eval _INIT = VERSION >= v " 1.5" ? Base. _InitialValue () : NamedTuple ()
221
+
217
222
# Simple
218
- y1, b1 = rrule (CFG, foldl, * , [1 , 2 , 3 ]; init = 1 )
223
+ y1, b1 = rrule (CFG, mapfoldl_impl, identity, * , 1 , [1 , 2 , 3 ])
219
224
@test y1 == 6
220
- b1 (7 ) == (NoTangent (), NoTangent (), [42 , 21 , 14 ])
225
+ @test b1 (7 )[1 : 3 ] == (NoTangent (), NoTangent (), NoTangent ())
226
+ @test b1 (7 )[4 ] isa ChainRulesCore. NotImplemented
227
+ @test b1 (7 )[5 ] == [42 , 21 , 14 ]
221
228
222
- y2, b2 = rrule (CFG, foldl, * , [1 2 ; 0 4 ]) # without init, needs vcat
229
+ y2, b2 = rrule (CFG, mapfoldl_impl, identity, * , _INIT , [1 2 ; 0 4 ]) # without init, needs vcat
223
230
@test y2 == 0
224
- b2 (8 ) == ( NoTangent (), NoTangent (), [0 0 ; 64 0 ]) # matrix, needs reshape
231
+ @test b2 (8 )[ 5 ] == [0 0 ; 64 0 ] # matrix, needs reshape
225
232
226
233
# Test execution order
227
234
c5 = Counter ()
228
- y5, b5 = rrule (CFG, foldl, c5 , [5 , 7 , 11 ])
235
+ y5, b5 = rrule (CFG, mapfoldl_impl, identity, c5, _INIT , [5 , 7 , 11 ])
229
236
@test c5 == Counter (2 )
230
237
@test y5 == ((5 + 7 )* 1 + 11 )* 2 == foldl (Counter (), [5 , 7 , 11 ])
231
- @test b5 (1 ) == ( NoTangent (), NoTangent (), [12 * 32 , 12 * 42 , 22 ])
238
+ @test b5 (1 )[ 5 ] == [12 * 32 , 12 * 42 , 22 ]
232
239
@test c5 == Counter (42 )
233
240
234
241
c6 = Counter ()
235
- y6, b6 = rrule (CFG, foldl, c6, [5 , 7 , 11 ], init = 3 )
242
+ y6, b6 = rrule (CFG, mapfoldl_impl, identity, c6, 3 , [5 , 7 , 11 ])
236
243
@test c6 == Counter (3 )
237
244
@test y6 == (((3 + 5 )* 1 + 7 )* 2 + 11 )* 3 == foldl (Counter (), [5 , 7 , 11 ], init= 3 )
238
- @test b6 (1 ) == ( NoTangent (), NoTangent (), [63 * 33 * 13 , 43 * 13 , 23 ])
245
+ @test b6 (1 )[ 5 ] == [63 * 33 * 13 , 43 * 13 , 23 ]
239
246
@test c6 == Counter (63 )
240
247
241
248
# Test gradient of function
242
- y7, b7 = rrule (CFG, foldl, Multiplier (3 ), [5 , 7 , 11 ])
249
+ y7, b7 = rrule (CFG, mapfoldl_impl, identity, Multiplier (3 ), _INIT , [5 , 7 , 11 ])
243
250
@test y7 == foldl ((x,y)-> x* y* 3 , [5 , 7 , 11 ])
244
- @test b7 (1 ) == (NoTangent (), Tangent {Multiplier{Int}} (x = 2310 ,), [693 , 495 , 315 ])
251
+ b7_1 = b7 (1 )
252
+ @test b7_1[3 ] == Tangent {Multiplier{Int}} (x = 2310 ,)
253
+ @test b7_1[5 ] == [693 , 495 , 315 ]
245
254
246
- y8, b8 = rrule (CFG, foldl, Multiplier (13 ), [5 , 7 , 11 ], init = 3 )
255
+ y8, b8 = rrule (CFG, mapfoldl_impl, identity, Multiplier (13 ), 3 , [5 , 7 , 11 ])
247
256
@test y8 == 2_537_535 == foldl ((x,y)-> x* y* 13 , [5 , 7 , 11 ], init= 3 )
248
- @test b8 (1 ) == (NoTangent (), Tangent {Multiplier{Int}} (x = 585585 ,), [507507 , 362505 , 230685 ])
257
+ b8_1 = b8 (1 )
258
+ @test b8_1[3 ] == Tangent {Multiplier{Int}} (x = 585585 ,)
259
+ @test b8_1[5 ] == [507507 , 362505 , 230685 ]
249
260
# To find these numbers:
250
261
# ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
251
262
# ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string
252
263
253
264
# Finite differencing
254
- test_rrule (foldl, / , 1 .+ rand (3 ,4 ))
255
- test_rrule (foldl, * , rand (ComplexF64, 3 , 4 ); fkwargs = (; init = rand (ComplexF64) ))
256
- test_rrule (foldl, + , rand (ComplexF64, 7 ); fkwargs = (; init = rand (ComplexF64) ))
257
- test_rrule (foldl, max, rand (3 ); fkwargs = (; init = 999 ))
265
+ test_rrule (mapfoldl_impl, identity, / , _INIT , 1 .+ rand (3 ,4 ))
266
+ test_rrule (mapfoldl_impl, identity, * , rand (ComplexF64), rand (ComplexF64, 3 , 4 ))
267
+ test_rrule (mapfoldl_impl, identity, + , rand (ComplexF64), rand (ComplexF64, 7 ))
268
+ test_rrule (mapfoldl_impl, identity, max, 999 , rand (3 ))
258
269
end
259
270
@testset " foldl(f, ::Tuple)" begin
260
271
y1, b1 = rrule (CFG, foldl, * , (1 ,2 ,3 ); init= 1 )
272
+ y1, b1 = rrule (CFG, mapfoldl_impl, identity, * , 1 , (1 ,2 ,3 ))
261
273
@test y1 == 6
262
- b1 (7 ) == ( NoTangent (), NoTangent (), Tangent {NTuple{3,Int}} (42 , 21 , 14 ) )
274
+ @test b1 (7 )[ 5 ] == Tangent {NTuple{3,Int}} (42 , 21 , 14 )
263
275
264
- y2, b2 = rrule (CFG, foldl, * , (1 , 2 , 0 , 4 ))
276
+ y2, b2 = rrule (CFG, mapfoldl_impl, identity, * , _INIT , (1 , 2 , 0 , 4 ))
265
277
@test y2 == 0
266
- b2 (8 ) == ( NoTangent (), NoTangent (), Tangent {NTuple{4,Int}} (0 , 0 , 64 , 0 ) )
278
+ @test b2 (8 )[ 5 ] == Tangent {NTuple{4,Int}} (0 , 0 , 64 , 0 )
267
279
268
280
# Finite differencing
269
- test_rrule (foldl, / , Tuple (1 .+ rand (5 )))
270
- test_rrule (foldl, * , Tuple (rand (ComplexF64, 5 )))
281
+ test_rrule (mapfoldl_impl, identity, / , _INIT , Tuple (1 .+ rand (5 )))
282
+ test_rrule (mapfoldl_impl, identity, * , _INIT , Tuple (rand (ComplexF64, 5 )))
271
283
end
272
284
end
273
285
0 commit comments