Skip to content

Commit 2b5877c

Browse files
committed
fix accumulate tests
1 parent c7e7f13 commit 2b5877c

File tree

1 file changed

+10
-35
lines changed

1 file changed

+10
-35
lines changed

test/rulesets/Base/mapreduce.jl

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,9 @@ end
359359
end
360360
end # cumprod
361361

362-
@testset "accumulate(f, ::Array)" begin
362+
@testset "accumulate(f, ::Vector)" begin
363363
# `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`.
364364
# The rule is now attached there, as this is the simplest way to handle `init` keyword.
365-
@eval using Base: _accumulate!
366365

367366
# Simple
368367
y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1))
@@ -372,9 +371,9 @@ end
372371
@test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}}
373372
@test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented
374373

375-
y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4])
376-
@test y2 accumulate(/, [1 2; 3 4])
377-
@test b2(ones(2, 2))[3] [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
374+
# y2, b2 = rrule(CFG, _accumulate!, /, [0 0; 0 0], [1 2; 3 4], :, nothing)
375+
# @test y2 ≈ accumulate(/, [1 2; 3 4.0])
376+
# @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
378377

379378
# Test execution order
380379
c3 = Counter()
@@ -404,35 +403,11 @@ end
404403
# ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string
405404

406405
# Finite differencing
407-
test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
408-
test_rrule(accumulate, /, 1 .+ rand(3, 4))
409-
test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
406+
# test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
407+
test_rrule(_accumulate!, *, randn(5) NoTangent(), randn(5), nothing, Some(rand()))
408+
# test_rrule(accumulate, /, 1 .+ rand(3, 4))
409+
test_rrule(_accumulate!, /, randn(4) NoTangent(), 1 .+ rand(4), nothing, nothing)
410+
# test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
411+
test_rrule(_accumulate!, ^, randn(6) NoTangent(), 1 .+ rand(6), nothing, Some(rand()))
410412
end
411-
@testset "accumulate(f, ::Tuple)" begin
412-
# Simple
413-
y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
414-
@test y1 == (1, 2, 6, 24)
415-
@test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))
416-
417-
# Finite differencing
418-
test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
419-
test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
420-
421-
test_rrule(_accumulate!, *, randn(5) NoTangent(), randn(5), nothing, nothing)
422-
test_rrule(_accumulate!, /, randn(5) NoTangent(), randn(5), nothing, Some(1 + rand()))
423-
# if VERSION >= v"1.5"
424-
# test_rrule(accumulate, /, 1 .+ rand(3, 4))
425-
# test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
426-
# end
427-
end
428-
# VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin
429-
# # Simple
430-
# y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
431-
# @test y1 == (1, 2, 6, 24)
432-
# @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))
433-
434-
# # Finite differencing
435-
# test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
436-
# test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
437-
# end
438413
end

0 commit comments

Comments
 (0)