2
2
m1 = collect (1 : 3.0 )
3
3
m2 = (collect (1 : 3.0 ), collect (4 : 6.0 ))
4
4
m3 = (x = m1, y = sin, z = collect (4 : 6.0 ))
5
+
5
6
m4 = (x = m1, y = m1, z = collect (4 : 6.0 )) # tied
6
7
m5 = (a = (m3, true ), b = (m1, false ), c = (m4, true ))
7
8
m6 = (a = m1, b = [4.0 + im], c = m1)
9
+
8
10
m7 = TwoThirds ((sin, collect (1 : 3.0 )), (cos, collect (4 : 6.0 )), (tan, collect (7 : 9.0 )))
9
11
m8 = [Foo (m1, m1), (a = true , b = Foo ([4.0 ], false ), c = ()), [[5.0 ]]]
10
12
13
+ mat = Float32[4 6 ; 5 7 ]
14
+ m9 = (a = m1, b = mat, c = [mat, m1])
15
+
11
16
@testset " flatten & rebuild" begin
12
17
@test destructure (m1)[1 ] isa Vector{Float64}
13
18
@test destructure (m1)[1 ] == 1 : 3
@@ -16,6 +21,7 @@ m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
16
21
@test destructure (m4)[1 ] == 1 : 6
17
22
@test destructure (m5)[1 ] == vcat (1 : 6 , 4 : 6 )
18
23
@test destructure (m6)[1 ] == vcat (1 : 3 , 4 + im)
24
+ @test destructure (m9)[1 ] == 1 : 7
19
25
20
26
@test destructure (m1)[2 ](7 : 9 ) == [7 ,8 ,9 ]
21
27
@test destructure (m2)[2 ](4 : 9 ) == ([4 ,5 ,6 ], [7 ,8 ,9 ])
@@ -45,6 +51,10 @@ m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
45
51
@test m8′[2 ]. b. y === false
46
52
@test m8′[3 ][1 ] == [5.0 ]
47
53
54
+ m9′ = destructure (m9)[2 ](10 : 10 : 70 )
55
+ @test m9′. b === m9′. c[1 ]
56
+ @test m9′. b isa Matrix{Float32}
57
+
48
58
# errors
49
59
@test_throws Exception destructure (m7)[2 ]([10 ,20 ])
50
60
@test_throws Exception destructure (m7)[2 ]([10 ,20 ,30 ,40 ])
71
81
@test g8[2 ]. b. x == [8 ]
72
82
@test g8[3 ] == [[10.0 ]]
73
83
84
+ g9 = gradient (m -> sum (sqrt, destructure (m)[1 ]), m9)[1 ]
85
+ @test g9. c === nothing
86
+
74
87
@testset " second derivative" begin
75
88
@test gradient ([1 ,2 ,3.0 ]) do v
76
89
sum (abs2, gradient (m -> sum (abs2, destructure (m)[1 ]), (v, [4 ,5 ,6.0 ]))[1 ][1 ])
119
132
@test gradient (x -> sum (abs2, re8 (x)[1 ]. y), v8)[1 ] == [2 ,4 ,6 ,0 ,0 ]
120
133
@test gradient (x -> only (sum (re8 (x)[3 ]))^ 2 , v8)[1 ] == [0 ,0 ,0 ,0 ,10 ]
121
134
135
+ re9 = destructure (m9)[2 ]
136
+ @test gradient (x -> sum (abs2, re9 (x). c[1 ]), 1 : 7 )[1 ] == [0 ,0 ,0 , 8 ,10 ,12 ,14 ]
137
+
122
138
@testset " second derivative" begin
123
139
@test_broken gradient (collect (1 : 6.0 )) do y
124
140
sum (abs2, gradient (x -> sum (abs2, re2 (x)[1 ]), y)[1 ])
0 commit comments