|
358 | 358 | end
|
359 | 359 | end # cumprod
|
360 | 360 |
|
361 |
| - @testset "accumulate(f, ::Array)" begin |
| 361 | + @testset "accumulate(f, ::Vector)" begin |
362 | 362 | # `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`.
|
363 | 363 | # The rule is now attached there, as this is the simplest way to handle `init` keyword.
|
364 |
| - @eval using Base: _accumulate! |
365 | 364 |
|
366 | 365 | # Simple
|
367 | 366 | y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1))
|
|
371 | 370 | @test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}}
|
372 | 371 | @test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented
|
373 | 372 |
|
374 |
| - y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4]) |
375 |
| - @test y2 ≈ accumulate(/, [1 2; 3 4]) |
376 |
| - @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 |
| 373 | + # y2, b2 = rrule(CFG, _accumulate!, /, [0 0; 0 0], [1 2; 3 4], :, nothing) |
| 374 | + # @test y2 ≈ accumulate(/, [1 2; 3 4.0]) |
| 375 | + # @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 |
377 | 376 |
|
378 | 377 | # Test execution order
|
379 | 378 | c3 = Counter()
|
@@ -403,35 +402,11 @@ end
|
403 | 402 | # ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string
|
404 | 403 |
|
405 | 404 | # Finite differencing
|
406 |
| - test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) |
407 |
| - test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
408 |
| - test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
| 405 | + # test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) |
| 406 | + test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(rand())) |
| 407 | + # test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
| 408 | + test_rrule(_accumulate!, /, randn(4) ⊢ NoTangent(), 1 .+ rand(4), nothing, nothing) |
| 409 | + # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
| 410 | + test_rrule(_accumulate!, ^, randn(6) ⊢ NoTangent(), 1 .+ rand(6), nothing, Some(rand())) |
409 | 411 | end
|
410 |
| - @testset "accumulate(f, ::Tuple)" begin |
411 |
| - # Simple |
412 |
| - y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) |
413 |
| - @test y1 == (1, 2, 6, 24) |
414 |
| - @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) |
415 |
| - |
416 |
| - # Finite differencing |
417 |
| - test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) |
418 |
| - test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) |
419 |
| - |
420 |
| - test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, nothing) |
421 |
| - test_rrule(_accumulate!, /, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(1 + rand())) |
422 |
| - # if VERSION >= v"1.5" |
423 |
| - # test_rrule(accumulate, /, 1 .+ rand(3, 4)) |
424 |
| - # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) |
425 |
| - # end |
426 |
| - end |
427 |
| - # VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin |
428 |
| - # # Simple |
429 |
| - # y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) |
430 |
| - # @test y1 == (1, 2, 6, 24) |
431 |
| - # @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) |
432 |
| - |
433 |
| - # # Finite differencing |
434 |
| - # test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) |
435 |
| - # test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) |
436 |
| - # end |
437 | 412 | end
|
0 commit comments