Skip to content

Commit 4c6433a

Browse files
committed
fix accumulate tests
1 parent fff84b5 commit 4c6433a

File tree

1 file changed

+10
-35
lines changed

1 file changed

+10
-35
lines changed

test/rulesets/Base/mapreduce.jl

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,9 @@ end
358358
end
359359
end # cumprod
360360

361-
@testset "accumulate(f, ::Array)" begin
361+
@testset "accumulate(f, ::Vector)" begin
362362
# `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`.
363363
# The rule is now attached there, as this is the simplest way to handle `init` keyword.
364-
@eval using Base: _accumulate!
365364

366365
# Simple
367366
y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1))
@@ -371,9 +370,9 @@ end
371370
@test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}}
372371
@test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented
373372

374-
y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4])
375-
@test y2 accumulate(/, [1 2; 3 4])
376-
@test b2(ones(2, 2))[3] [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
373+
# y2, b2 = rrule(CFG, _accumulate!, /, [0 0; 0 0], [1 2; 3 4], :, nothing)
374+
# @test y2 ≈ accumulate(/, [1 2; 3 4.0])
375+
# @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6
377376

378377
# Test execution order
379378
c3 = Counter()
@@ -403,35 +402,11 @@ end
403402
# ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string
404403

405404
# Finite differencing
406-
test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
407-
test_rrule(accumulate, /, 1 .+ rand(3, 4))
408-
test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
405+
# test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand()))
406+
test_rrule(_accumulate!, *, randn(5) NoTangent(), randn(5), nothing, Some(rand()))
407+
# test_rrule(accumulate, /, 1 .+ rand(3, 4))
408+
test_rrule(_accumulate!, /, randn(4) NoTangent(), 1 .+ rand(4), nothing, nothing)
409+
# test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
410+
test_rrule(_accumulate!, ^, randn(6) NoTangent(), 1 .+ rand(6), nothing, Some(rand()))
409411
end
410-
@testset "accumulate(f, ::Tuple)" begin
411-
# Simple
412-
y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
413-
@test y1 == (1, 2, 6, 24)
414-
@test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))
415-
416-
# Finite differencing
417-
test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
418-
test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
419-
420-
test_rrule(_accumulate!, *, randn(5) NoTangent(), randn(5), nothing, nothing)
421-
test_rrule(_accumulate!, /, randn(5) NoTangent(), randn(5), nothing, Some(1 + rand()))
422-
# if VERSION >= v"1.5"
423-
# test_rrule(accumulate, /, 1 .+ rand(3, 4))
424-
# test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand()))
425-
# end
426-
end
427-
# VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin
428-
# # Simple
429-
# y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1)
430-
# @test y1 == (1, 2, 6, 24)
431-
# @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6))
432-
433-
# # Finite differencing
434-
# test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand()))
435-
# test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false)
436-
# end
437412
end

0 commit comments

Comments
 (0)