@@ -183,7 +183,7 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
183
183
# `foldl(op, itr; init)` goes to `mapfoldl_impl(identity, op, init, itr)`. The rule is
184
184
# now attached there, as this is the simplest way to handle the `init` keyword.
185
185
@eval using Base: mapfoldl_impl
186
- @eval _INIT = VERSION >= v " 1.5" ? Base. _InitialValue () : NamedTuple ()
186
+ _INIT = VERSION >= v " 1.5" ? Base. _InitialValue () : NamedTuple ()
187
187
188
188
# Simple
189
189
y1, b1 = rrule (CFG, mapfoldl_impl, identity, * , 1 , [1 , 2 , 3 ])
@@ -286,57 +286,67 @@ end
286
286
end # cumprod
287
287
288
288
@testset " accumulate(f, ::Array)" begin
289
+ # `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Union{Nothing,Some})`.
290
+ # The rule is now attached there, as this is the simplest way to handle the `init` keyword.
291
+ @eval using Base: _accumulate!
292
+
289
293
# Simple
290
- y1, b1 = rrule (CFG, accumulate , * , [1 , 2 , 3 , 4 ]; init = 1 )
294
+ y1, b1 = rrule (CFG, _accumulate! , * , [0 , 0 , 0 , 0 ], [ 1 , 2 , 3 , 4 ], nothing , Some ( 1 ) )
291
295
@test y1 == [1 , 2 , 6 , 24 ]
292
- @test b1 ([1 , 1 , 1 , 1 ]) == (NoTangent (), NoTangent (), [33 , 16 , 10 , 6 ])
296
+ @test b1 ([1 , 1 , 1 , 1 ])[3 ] isa ChainRulesCore. NotImplemented
297
+ @test b1 ([1 , 1 , 1 , 1 ])[4 ] == [33 , 16 , 10 , 6 ]
298
+ @test b1 ([1 , 1 , 1 , 1 ])[6 ] isa Tangent{Some{Int64}}
299
+ @test b1 ([1 , 1 , 1 , 1 ])[6 ]. value isa ChainRulesCore. NotImplemented
293
300
294
- if VERSION >= v " 1.5"
295
- y2, b2 = rrule (CFG, accumulate , / , [1 2 ; 3 4 ])
296
- @test y2 ≈ accumulate (/ , [1 2 ; 3 4 ])
297
- @test b2 (ones (2 , 2 ))[3 ] ≈ [1.5416666 - 0.104166664 ; - 0.18055555 - 0.010416667 ] atol= 1e-6
298
- end
301
+ # if VERSION >= v"1.5"
302
+ # y2, b2 = rrule(CFG, _accumulate! , /, [1 2; 3 4])
303
+ # @test y2 ≈ accumulate(/, [1 2; 3 4])
304
+ # @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
305
+ # end
299
306
300
307
# Test execution order
301
308
c3 = Counter ()
302
- y3, b3 = rrule (CFG, accumulate , c3, [5 , 7 , 11 ]; init = 3 )
309
+ y3, b3 = rrule (CFG, _accumulate! , c3, [0 , 0 , 0 ], [ 5 , 7 , 11 ], nothing , Some ( 3 ) )
303
310
@test c3 == Counter (3 )
304
311
@test y3 == [8 , 30 , 123 ] == accumulate (Counter (), [5 , 7 , 11 ]; init= 3 )
305
- @test b3 ([1 , 1 , 1 ]) == ( NoTangent (), NoTangent (), [29169 , 602 , 23 ]) # the 23 is clear!
312
+ @test b3 ([1 , 1 , 1 ])[ 4 ] == [29169 , 602 , 23 ] # the 23 is clear!
306
313
307
314
c4 = Counter ()
308
- y4, b4 = rrule (CFG, accumulate , c4, [5 , 7 , 11 ])
315
+ y4, b4 = rrule (CFG, _accumulate! , c4, [0 , 0 , 0 ], [ 5 , 7 , 11 ], nothing , nothing )
309
316
@test c4 == Counter (2 )
310
317
@test y4 == [5 , (5 + 7 )* 1 , ((5 + 7 )* 1 + 11 )* 2 ] == accumulate (Counter (), [5 , 7 , 11 ])
311
- @test b4 ([1 , 1 , 1 ]) == ( NoTangent (), NoTangent (), [417 , 42 * (1 + 12 ), 22 ])
318
+ @test b4 ([1 , 1 , 1 ])[ 4 ] == [417 , 42 * (1 + 12 ), 22 ]
312
319
313
320
# Test gradient of function
314
- y7, b7 = rrule (CFG, accumulate , Multiplier (3 ), [5 , 7 , 11 ])
321
+ y7, b7 = rrule (CFG, _accumulate! , Multiplier (3 ), [0 , 0 , 0 ], [ 5 , 7 , 11 ], nothing , nothing )
315
322
@test y7 == accumulate ((x,y)-> x* y* 3 , [5 , 7 , 11 ])
316
- @test b7 ([1 , 1 , 1 ]) == (NoTangent (), Tangent {Multiplier{Int}} (x = 2345 ,), [715 , 510 , 315 ])
323
+ @test b7 ([1 , 1 , 1 ])[2 ] == Tangent {Multiplier{Int}} (; x = 2345 ,)
324
+ @test b7 ([1 , 1 , 1 ])[4 ] == [715 , 510 , 315 ]
317
325
318
- y8, b8 = rrule (CFG, accumulate , Multiplier (13 ), [5 , 7 , 11 ], init = 3 )
326
+ y8, b8 = rrule (CFG, _accumulate! , Multiplier (13 ), [0 , 0 , 0 ], [ 5 , 7 , 11 ], nothing , Some ( 3 ) )
319
327
@test y8 == [195 , 17745 , 2537535 ] == accumulate ((x,y)-> x* y* 13 , [5 , 7 , 11 ], init= 3 )
320
- @test b8 ([1 , 1 , 1 ]) == (NoTangent (), Tangent {Multiplier{Int}} (x = 588330 ,), [511095 , 365040 , 230685 ])
328
+ @test b8 ([1 , 1 , 1 ])[2 ] == Tangent {Multiplier{Int}} (; x = 588330 ,)
329
+ @test b8 ([1 , 1 , 1 ])[4 ] == [511095 , 365040 , 230685 ]
321
330
# To find these numbers:
322
331
# ForwardDiff.derivative(z -> sum(accumulate((x,y)->x*y*z, [5,7,11], init=3)), 13)
323
332
# ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string
324
333
325
334
# Finite differencing
326
- test_rrule (accumulate, * , randn (5 ); fkwargs= (; init= rand ()))
327
- if VERSION >= v " 1.5"
328
- test_rrule (accumulate, / , 1 .+ rand (3 , 4 ))
329
- test_rrule (accumulate, ^ , 1 .+ rand (2 , 3 ); fkwargs= (; init= rand ()))
330
- end
331
- end
332
- VERSION >= v " 1.5" && @testset " accumulate(f, ::Tuple)" begin
333
- # Simple
334
- y1, b1 = rrule (CFG, accumulate, * , (1 , 2 , 3 , 4 ); init= 1 )
335
- @test y1 == (1 , 2 , 6 , 24 )
336
- @test b1 ((1 , 1 , 1 , 1 )) == (NoTangent (), NoTangent (), Tangent {NTuple{4,Int}} (33 , 16 , 10 , 6 ))
337
-
338
- # Finite differencing
339
- test_rrule (accumulate, * , Tuple (randn (5 )); fkwargs= (; init= rand ()))
340
- test_rrule (accumulate, / , Tuple (1 .+ rand (5 )); check_inferred= false )
335
+ test_rrule (_accumulate!, * , randn (5 ) ⊢ NoTangent (), randn (5 ), nothing , nothing )
336
+ test_rrule (_accumulate!, / , randn (5 ) ⊢ NoTangent (), randn (5 ), nothing , Some (1 + rand ()))
337
+ # if VERSION >= v"1.5"
338
+ # test_rrule(accumulate, /, 1 .+ rand(3, 4))
339
+ # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
340
+ # end
341
341
end
342
+ # VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin
343
+ # # Simple
344
+ # y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
345
+ # @test y1 == (1, 2, 6, 24)
346
+ # @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))
347
+
348
+ # # Finite differencing
349
+ # test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
350
+ # test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
351
+ # end
342
352
end
0 commit comments