|
359 | 359 | end
|
360 | 360 | end # cumprod
|
361 | 361 |
|
362 |
| - @testset "accumulate(f, ::Array)" begin |
| 362 | + @testset "accumulate(f, ::Vector)" begin |
363 | 363 | # `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`.
|
364 | 364 | # The rule is now attached there, as this is the simplest way to handle `init` keyword.
|
365 |
| - @eval using Base: _accumulate! |
366 | 365 |
|
367 | 366 | # Simple
|
368 | 367 | y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1))
|
|
372 | 371 | @test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}}
|
373 | 372 | @test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented
|
374 | 373 |
|
375 |
| - y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4]) |
376 |
| - @test y2 ≈ accumulate(/, [1 2; 3 4]) |
377 |
| - @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 |
| 374 | + # y2, b2 = rrule(CFG, _accumulate!, /, [0 0; 0 0], [1 2; 3 4], :, nothing) |
| 375 | + # @test y2 ≈ accumulate(/, [1 2; 3 4.0]) |
| 376 | + # @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 |
378 | 377 |
|
379 | 378 | # Test execution order
|
380 | 379 | c3 = Counter()
|
@@ -404,35 +403,11 @@ end
|
404 | 403 | # ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string
|
405 | 404 |
|
406 | 405 | # Finite differencing
|
407 |
| - test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) |
408 |
| - test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
409 |
| - test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
| 406 | + # test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) |
| 407 | + test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(rand())) |
| 408 | + # test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
| 409 | + test_rrule(_accumulate!, /, randn(4) ⊢ NoTangent(), 1 .+ rand(4), nothing, nothing) |
| 410 | + # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
| 411 | + test_rrule(_accumulate!, ^, randn(6) ⊢ NoTangent(), 1 .+ rand(6), nothing, Some(rand())) |
410 | 412 | end
|
411 |
| - @testset "accumulate(f, ::Tuple)" begin |
412 |
| - # Simple |
413 |
| - y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) |
414 |
| - @test y1 == (1, 2, 6, 24) |
415 |
| - @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) |
416 |
| - |
417 |
| - # Finite differencing |
418 |
| - test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) |
419 |
| - test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) |
420 |
| - |
421 |
| - test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, nothing) |
422 |
| - test_rrule(_accumulate!, /, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(1 + rand())) |
423 |
| - # if VERSION >= v"1.5" |
424 |
| - # test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
425 |
| - # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
426 |
| - # end |
427 |
| - end |
428 |
| - # VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin |
429 |
| - # # Simple |
430 |
| - # y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) |
431 |
| - # @test y1 == (1, 2, 6, 24) |
432 |
| - # @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) |
433 |
| - |
434 |
| - # # Finite differencing |
435 |
| - # test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) |
436 |
| - # test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) |
437 |
| - # end |
438 | 413 | end
|
0 commit comments