@@ -213,60 +213,72 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
213
213
end # prod
214
214
215
215
@testset " foldl(f, ::Array)" begin
216
+ # `foldl(op, itr; init)` goes to `mapfoldl_impl(identity, op, init, itr)`. The rule is
217
+ # now attached there, as this is the simplest way to handle the `init` keyword.
218
+ @eval using Base: mapfoldl_impl
219
+ @eval _INIT = VERSION >= v " 1.5" ? Base. _InitialValue () : NamedTuple ()
220
+
216
221
# Simple
217
- y1, b1 = rrule (CFG, foldl, * , [1 , 2 , 3 ]; init = 1 )
222
+ y1, b1 = rrule (CFG, mapfoldl_impl, identity, * , 1 , [1 , 2 , 3 ])
218
223
@test y1 == 6
219
- b1 (7 ) == (NoTangent (), NoTangent (), [42 , 21 , 14 ])
224
+ @test b1 (7 )[1 : 3 ] == (NoTangent (), NoTangent (), NoTangent ())
225
+ @test b1 (7 )[4 ] isa ChainRulesCore. NotImplemented
226
+ @test b1 (7 )[5 ] == [42 , 21 , 14 ]
220
227
221
- y2, b2 = rrule (CFG, foldl, * , [1 2 ; 0 4 ]) # without init, needs vcat
228
+ y2, b2 = rrule (CFG, mapfoldl_impl, identity, * , _INIT , [1 2 ; 0 4 ]) # without init, needs vcat
222
229
@test y2 == 0
223
- b2 (8 ) == ( NoTangent (), NoTangent (), [0 0 ; 64 0 ]) # matrix, needs reshape
230
+ @test b2 (8 )[ 5 ] == [0 0 ; 64 0 ] # matrix, needs reshape
224
231
225
232
# Test execution order
226
233
c5 = Counter ()
227
- y5, b5 = rrule (CFG, foldl, c5 , [5 , 7 , 11 ])
234
+ y5, b5 = rrule (CFG, mapfoldl_impl, identity, c5, _INIT , [5 , 7 , 11 ])
228
235
@test c5 == Counter (2 )
229
236
@test y5 == ((5 + 7 )* 1 + 11 )* 2 == foldl (Counter (), [5 , 7 , 11 ])
230
- @test b5 (1 ) == ( NoTangent (), NoTangent (), [12 * 32 , 12 * 42 , 22 ])
237
+ @test b5 (1 )[ 5 ] == [12 * 32 , 12 * 42 , 22 ]
231
238
@test c5 == Counter (42 )
232
239
233
240
c6 = Counter ()
234
- y6, b6 = rrule (CFG, foldl, c6, [5 , 7 , 11 ], init = 3 )
241
+ y6, b6 = rrule (CFG, mapfoldl_impl, identity, c6, 3 , [5 , 7 , 11 ])
235
242
@test c6 == Counter (3 )
236
243
@test y6 == (((3 + 5 )* 1 + 7 )* 2 + 11 )* 3 == foldl (Counter (), [5 , 7 , 11 ], init= 3 )
237
- @test b6 (1 ) == ( NoTangent (), NoTangent (), [63 * 33 * 13 , 43 * 13 , 23 ])
244
+ @test b6 (1 )[ 5 ] == [63 * 33 * 13 , 43 * 13 , 23 ]
238
245
@test c6 == Counter (63 )
239
246
240
247
# Test gradient of function
241
- y7, b7 = rrule (CFG, foldl, Multiplier (3 ), [5 , 7 , 11 ])
248
+ y7, b7 = rrule (CFG, mapfoldl_impl, identity, Multiplier (3 ), _INIT , [5 , 7 , 11 ])
242
249
@test y7 == foldl ((x,y)-> x* y* 3 , [5 , 7 , 11 ])
243
- @test b7 (1 ) == (NoTangent (), Tangent {Multiplier{Int}} (x = 2310 ,), [693 , 495 , 315 ])
250
+ b7_1 = b7 (1 )
251
+ @test b7_1[3 ] == Tangent {Multiplier{Int}} (x = 2310 ,)
252
+ @test b7_1[5 ] == [693 , 495 , 315 ]
244
253
245
- y8, b8 = rrule (CFG, foldl, Multiplier (13 ), [5 , 7 , 11 ], init = 3 )
254
+ y8, b8 = rrule (CFG, mapfoldl_impl, identity, Multiplier (13 ), 3 , [5 , 7 , 11 ])
246
255
@test y8 == 2_537_535 == foldl ((x,y)-> x* y* 13 , [5 , 7 , 11 ], init= 3 )
247
- @test b8 (1 ) == (NoTangent (), Tangent {Multiplier{Int}} (x = 585585 ,), [507507 , 362505 , 230685 ])
256
+ b8_1 = b8 (1 )
257
+ @test b8_1[3 ] == Tangent {Multiplier{Int}} (x = 585585 ,)
258
+ @test b8_1[5 ] == [507507 , 362505 , 230685 ]
248
259
# To find these numbers:
249
260
# ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
250
261
# ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string
251
262
252
263
# Finite differencing
253
- test_rrule (foldl, / , 1 .+ rand (3 ,4 ))
254
- test_rrule (foldl, * , rand (ComplexF64, 3 , 4 ); fkwargs = (; init = rand (ComplexF64) ))
255
- test_rrule (foldl, + , rand (ComplexF64, 7 ); fkwargs = (; init = rand (ComplexF64) ))
256
- test_rrule (foldl, max, rand (3 ); fkwargs = (; init = 999 ))
264
+ test_rrule (mapfoldl_impl, identity, / , _INIT , 1 .+ rand (3 ,4 ))
265
+ test_rrule (mapfoldl_impl, identity, * , rand (ComplexF64), rand (ComplexF64, 3 , 4 ))
266
+ test_rrule (mapfoldl_impl, identity, + , rand (ComplexF64), rand (ComplexF64, 7 ))
267
+ test_rrule (mapfoldl_impl, identity, max, 999 , rand (3 ))
257
268
end
258
269
@testset " foldl(f, ::Tuple)" begin
259
270
y1, b1 = rrule (CFG, foldl, * , (1 ,2 ,3 ); init= 1 )
271
+ y1, b1 = rrule (CFG, mapfoldl_impl, identity, * , 1 , (1 ,2 ,3 ))
260
272
@test y1 == 6
261
- b1 (7 ) == ( NoTangent (), NoTangent (), Tangent {NTuple{3,Int}} (42 , 21 , 14 ) )
273
+ @test b1 (7 )[ 5 ] == Tangent {NTuple{3,Int}} (42 , 21 , 14 )
262
274
263
- y2, b2 = rrule (CFG, foldl, * , (1 , 2 , 0 , 4 ))
275
+ y2, b2 = rrule (CFG, mapfoldl_impl, identity, * , _INIT , (1 , 2 , 0 , 4 ))
264
276
@test y2 == 0
265
- b2 (8 ) == ( NoTangent (), NoTangent (), Tangent {NTuple{4,Int}} (0 , 0 , 64 , 0 ) )
277
+ @test b2 (8 )[ 5 ] == Tangent {NTuple{4,Int}} (0 , 0 , 64 , 0 )
266
278
267
279
# Finite differencing
268
- test_rrule (foldl, / , Tuple (1 .+ rand (5 )))
269
- test_rrule (foldl, * , Tuple (rand (ComplexF64, 5 )))
280
+ test_rrule (mapfoldl_impl, identity, / , _INIT , Tuple (1 .+ rand (5 )))
281
+ test_rrule (mapfoldl_impl, identity, * , _INIT , Tuple (rand (ComplexF64, 5 )))
270
282
end
271
283
end
272
284
0 commit comments