@@ -180,60 +180,71 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
180
180
end # prod
181
181
182
182
@testset " foldl(f, ::Array)" begin
183
+ # `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is
184
+ # now attached there, as this is the simplest way to handle `init` keyword.
185
+ @eval using Base: mapfoldl_impl
186
+ @eval _INIT = VERSION >= v " 1.5" ? Base. _InitialValue () : NamedTuple ()
187
+
183
188
# Simple
184
- y1, b1 = rrule (CFG, foldl, * , [1 , 2 , 3 ]; init = 1 )
189
+ y1, b1 = rrule (CFG, mapfoldl_impl, identity, * , 1 , [1 , 2 , 3 ])
185
190
@test y1 == 6
186
- b1 (7 ) == (NoTangent (), NoTangent (), [42 , 21 , 14 ])
191
+ @test b1 (7 )[1 : 3 ] == (NoTangent (), NoTangent (), NoTangent ())
192
+ @test b1 (7 )[4 ] isa ChainRulesCore. NotImplemented
193
+ @test b1 (7 )[5 ] == [42 , 21 , 14 ]
187
194
188
- y2, b2 = rrule (CFG, foldl, * , [1 2 ; 0 4 ]) # without init, needs vcat
195
+ y2, b2 = rrule (CFG, mapfoldl_impl, identity, * , _INIT , [1 2 ; 0 4 ]) # without init, needs vcat
189
196
@test y2 == 0
190
- b2 (8 ) == ( NoTangent (), NoTangent (), [0 0 ; 64 0 ]) # matrix, needs reshape
197
+ @test b2 (8 )[ 5 ] == [0 0 ; 64 0 ] # matrix, needs reshape
191
198
192
199
# Test execution order
193
200
c5 = Counter ()
194
- y5, b5 = rrule (CFG, foldl, c5 , [5 , 7 , 11 ])
201
+ y5, b5 = rrule (CFG, mapfoldl_impl, identity, c5, _INIT , [5 , 7 , 11 ])
195
202
@test c5 == Counter (2 )
196
203
@test y5 == ((5 + 7 )* 1 + 11 )* 2 == foldl (Counter (), [5 , 7 , 11 ])
197
- @test b5 (1 ) == ( NoTangent (), NoTangent (), [12 * 32 , 12 * 42 , 22 ])
204
+ @test b5 (1 )[ 5 ] == [12 * 32 , 12 * 42 , 22 ]
198
205
@test c5 == Counter (42 )
199
206
200
207
c6 = Counter ()
201
- y6, b6 = rrule (CFG, foldl, c6, [5 , 7 , 11 ], init = 3 )
208
+ y6, b6 = rrule (CFG, mapfoldl_impl, identity, c6, 3 , [5 , 7 , 11 ])
202
209
@test c6 == Counter (3 )
203
210
@test y6 == (((3 + 5 )* 1 + 7 )* 2 + 11 )* 3 == foldl (Counter (), [5 , 7 , 11 ], init= 3 )
204
- @test b6 (1 ) == ( NoTangent (), NoTangent (), [63 * 33 * 13 , 43 * 13 , 23 ])
211
+ @test b6 (1 )[ 5 ] == [63 * 33 * 13 , 43 * 13 , 23 ]
205
212
@test c6 == Counter (63 )
206
213
207
214
# Test gradient of function
208
- y7, b7 = rrule (CFG, foldl, Multiplier (3 ), [5 , 7 , 11 ])
215
+ y7, b7 = rrule (CFG, mapfoldl_impl, identity, Multiplier (3 ), _INIT , [5 , 7 , 11 ])
209
216
@test y7 == foldl ((x,y)-> x* y* 3 , [5 , 7 , 11 ])
210
- @test b7 (1 ) == (NoTangent (), Tangent {Multiplier{Int}} (x = 2310 ,), [693 , 495 , 315 ])
217
+ b7_1 = b7 (1 )
218
+ @test b7_1[3 ] == Tangent {Multiplier{Int}} (x = 2310 ,)
219
+ @test b7_1[5 ] == [693 , 495 , 315 ]
211
220
212
- y8, b8 = rrule (CFG, foldl, Multiplier (13 ), [5 , 7 , 11 ], init = 3 )
221
+ y8, b8 = rrule (CFG, mapfoldl_impl, identity, Multiplier (13 ), 3 , [5 , 7 , 11 ])
213
222
@test y8 == 2_537_535 == foldl ((x,y)-> x* y* 13 , [5 , 7 , 11 ], init= 3 )
214
- @test b8 (1 ) == (NoTangent (), Tangent {Multiplier{Int}} (x = 585585 ,), [507507 , 362505 , 230685 ])
223
+ b8_1 = b8 (1 )
224
+ @test b8_1[3 ] == Tangent {Multiplier{Int}} (x = 585585 ,)
225
+ @test b8_1[5 ] == [507507 , 362505 , 230685 ]
215
226
# To find these numbers:
216
227
# ForwardDiff.derivative(z -> foldl((x,y)->x*y*z, [5,7,11], init=3), 13)
217
228
# ForwardDiff.gradient(z -> foldl((x,y)->x*y*13, z, init=3), [5,7,11]) |> string
218
229
219
230
# Finite differencing
220
- test_rrule (foldl, / , 1 .+ rand (3 ,4 ))
221
- test_rrule (foldl, * , rand (ComplexF64, 3 , 4 ); fkwargs = (; init = rand (ComplexF64) ))
222
- test_rrule (foldl, + , rand (ComplexF64, 7 ); fkwargs = (; init = rand (ComplexF64) ))
223
- test_rrule (foldl, max, rand (3 ); fkwargs = (; init = 999 ))
231
+ test_rrule (mapfoldl_impl, identity, / , _INIT , 1 .+ rand (3 ,4 ))
232
+ test_rrule (mapfoldl_impl, identity, * , rand (ComplexF64), rand (ComplexF64, 3 , 4 ))
233
+ test_rrule (mapfoldl_impl, identity, + , rand (ComplexF64), rand (ComplexF64, 7 ))
234
+ test_rrule (mapfoldl_impl, identity, max, 999 , rand (3 ))
224
235
end
225
236
VERSION >= v " 1.5" && @testset " foldl(f, ::Tuple)" begin
226
- y1, b1 = rrule (CFG, foldl, * , (1 ,2 ,3 ); init = 1 )
237
+ y1, b1 = rrule (CFG, mapfoldl_impl, identity, * , 1 , (1 ,2 ,3 ))
227
238
@test y1 == 6
228
- b1 (7 ) == ( NoTangent (), NoTangent (), Tangent {NTuple{3,Int}} (42 , 21 , 14 ) )
239
+ @test b1 (7 )[ 5 ] == Tangent {NTuple{3,Int}} (42 , 21 , 14 )
229
240
230
- y2, b2 = rrule (CFG, foldl, * , (1 , 2 , 0 , 4 ))
241
+ y2, b2 = rrule (CFG, mapfoldl_impl, identity, * , _INIT , (1 , 2 , 0 , 4 ))
231
242
@test y2 == 0
232
- b2 (8 ) == ( NoTangent (), NoTangent (), Tangent {NTuple{4,Int}} (0 , 0 , 64 , 0 ) )
243
+ @test b2 (8 )[ 5 ] == Tangent {NTuple{4,Int}} (0 , 0 , 64 , 0 )
233
244
234
245
# Finite differencing
235
- test_rrule (foldl, / , Tuple (1 .+ rand (5 )))
236
- test_rrule (foldl, * , Tuple (rand (ComplexF64, 5 )))
246
+ test_rrule (mapfoldl_impl, identity, / , _INIT , Tuple (1 .+ rand (5 )))
247
+ test_rrule (mapfoldl_impl, identity, * , _INIT , Tuple (rand (ComplexF64, 5 )))
237
248
end
238
249
end
239
250
0 commit comments