11
22# This file has integration tests for some rules defined in ChainRules.jl,
33# especially those which aim to support higher derivatives, as properly
4- # testing those is difficult.
4+ # testing those is difficult. Organised according to the files in CR.jl.
5+
6+ using Diffractor, ForwardDiff, ChainRulesCore
7+ using Test, LinearAlgebra
8+
9+ using Test: Threw, eval_test
510
6- using Diffractor, ChainRulesCore, ForwardDiff
711
812# ####
913# #### Base/array.jl
@@ -13,7 +17,6 @@ using Diffractor, ChainRulesCore, ForwardDiff
1317
1418
1519
16-
1720# ####
1821# #### Base/arraymath.jl
1922# ####
@@ -33,21 +36,58 @@ using Diffractor, ChainRulesCore, ForwardDiff
3336# #### Base/indexing.jl
3437# ####
3538
39+ @testset " getindex, first" begin
40+ @test_broken gradient (x -> sum (abs2, gradient (first, x)[1 ]), [1 ,2 ,3 ])[1 ] == [0 , 0 , 0 ] # MethodError: no method matching +(::Tuple{ZeroTangent, ZeroTangent}, ::Tuple{ZeroTangent, ZeroTangent})
41+ @test_broken gradient (x -> sum (abs2, gradient (sqrt∘ first, x)[1 ]), [1 ,2 ,3 ])[1 ] ≈ [- 0.25 , 0 , 0 ] # error() in perform_optic_transform(ff::Type{Diffractor.∂⃖recurse{2}}, args::Any)
42+ @test gradient (x -> sum (abs2, gradient (x -> x[1 ]^ 2 , x)[1 ]), [1 ,2 ,3 ])[1 ] == [8 , 0 , 0 ]
43+ @test_broken gradient (x -> sum (abs2, gradient (x -> sum (x[1 : 2 ])^ 2 , x)[1 ]), [1 ,2 ,3 ])[1 ] == [48 , 0 , 0 ] # MethodError: no method matching +(::Tuple{ZeroTangent, ZeroTangent}, ::Tuple{ZeroTangent, ZeroTangent})
44+ end
3645
37-
46+ @testset " eachcol etc" begin
47+ @test gradient (m -> sum (prod, eachcol (m)), [1 2 3 ; 4 5 6 ])[1 ] == [4 5 6 ; 1 2 3 ]
48+ @test gradient (m -> sum (first, eachcol (m)), [1 2 3 ; 4 5 6 ])[1 ] == [1 1 1 ; 0 0 0 ]
49+ @test gradient (m -> sum (first (eachcol (m))), [1 2 3 ; 4 5 6 ])[1 ] == [1 0 0 ; 1 0 0 ]
50+ @test_skip gradient (x -> sum (sin, gradient (m -> sum (first (eachcol (m))), x)[1 ]), [1 2 3 ; 4 5 6 ])[1 ] # MethodError: no method matching one(::Base.OneTo{Int64}), unzip_broadcast, split_bc_forwards
51+ end
3852
3953# ####
4054# #### Base/mapreduce.jl
4155# ####
4256
57+ @testset " sum" begin
58+ @test gradient (x -> sum (abs2, gradient (sum, x)[1 ]), [1 ,2 ,3 ])[1 ] == [0 ,0 ,0 ]
59+ @test gradient (x -> sum (abs2, gradient (x -> sum (abs2, x), x)[1 ]), [1 ,2 ,3 ])[1 ] == [8 ,16 ,24 ]
60+
61+ @test gradient (x -> sum (abs2, gradient (sum, x .^ 2 )[1 ]), [1 ,2 ,3 ])[1 ] == [0 ,0 ,0 ]
62+ @test gradient (x -> sum (abs2, gradient (sum, x .^ 3 )[1 ]), [1 ,2 ,3 ])[1 ] == [0 ,0 ,0 ]
63+ end
4364
65+ @testset " foldl" begin
66+
67+ @test gradient (x -> foldl (* , x), [1 ,2 ,3 ,4 ])[1 ] == [24.0 , 12.0 , 8.0 , 6.0 ]
68+ @test gradient (x -> foldl (* , x; init= 5 ), [1 ,2 ,3 ,4 ])[1 ] == [120.0 , 60.0 , 40.0 , 30.0 ]
69+ @test gradient (x -> foldr (* , x), [1 ,2 ,3 ,4 ])[1 ] == [24 , 12 , 8 , 6 ]
70+
71+ @test gradient (x -> foldl (* , x), (1 ,2 ,3 ,4 ))[1 ] == Tangent {NTuple{4,Int}} (24.0 , 12.0 , 8.0 , 6.0 )
72+ @test_broken gradient (x -> foldl (* , x; init= 5 ), (1 ,2 ,3 ,4 ))[1 ] == Tangent {NTuple{4,Int}} (120.0 , 60.0 , 40.0 , 30.0 ) # does not return a Tangent
73+ @test gradient (x -> foldl (* , x; init= 5 ), (1 ,2 ,3 ,4 )) |> only |> Tuple == (120.0 , 60.0 , 40.0 , 30.0 )
74+ @test_broken gradient (x -> foldr (* , x), (1 ,2 ,3 ,4 ))[1 ] == Tangent {NTuple{4,Int}} (24 , 12 , 8 , 6 )
75+ @test gradient (x -> foldr (* , x), (1 ,2 ,3 ,4 )) |> only |> Tuple == (24 , 12 , 8 , 6 )
76+
77+ end
4478
4579
4680# ####
4781# #### LinearAlgebra/dense.jl
4882# ####
4983
5084
85+ @testset " dot" begin
86+
87+ @test gradient (x -> dot (x, [1 ,2 ,3 ])^ 2 , [4 ,5 ,6 ])[1 ] == [64 ,128 ,192 ]
88+ @test_broken gradient (x -> sum (gradient (x -> dot (x, [1 ,2 ,3 ])^ 2 , x)[1 ]), [4 ,5 ,6 ])[1 ] == [12 ,24 ,36 ] # MethodError: no method matching +(::Tuple{Tangent{ChainRules.var
89+
90+ end
5191
5292
5393# ####
0 commit comments