1
- using Optimisers, Test
2
- using Zygote
3
- using Statistics, Random, LinearAlgebra
4
- Random. seed! (1 )
1
+ using Optimisers, Functors, Zygote
2
+ using LinearAlgebra, Statistics, Test, Random
5
3
using Optimisers: @. .
6
4
5
+ Random. seed! (1 )
6
+
7
# Test fixture: a two-field struct whose `trainable` returns an old-style
# (un-named) tuple, listing BOTH fields but in reversed order (y before x).
struct Foo
    x
    y
end
Functors.@functor Foo
Optimisers.trainable(x::Foo) = (x.y, x.x)
10
+
11
# Test fixture: three fields, of which only (a, c) are functor children,
# and of those only `a` is marked trainable (via a NamedTuple).
struct TwoThirds
    a
    b
    c
end
Functors.@functor TwoThirds (a, c)
Optimisers.trainable(x::TwoThirds) = (a = x.a,)
14
+
7
15
@testset verbose= true " Optimisers.jl" begin
8
16
9
17
@testset " very basics" begin
@@ -23,7 +31,7 @@ using Optimisers: @..
23
31
@test m3[1 ] ≈ [1 ,2 ] .- 0.1 .* [25 , 33 ]
24
32
end
25
33
26
- @testset " $(first (string (o), 42 )) " for o in (
34
+ @testset " rule: $(first (string (o), 42 )) " for o in (
27
35
Descent (), ADAM (), Momentum (), Nesterov (), RMSProp (),
28
36
ADAGrad (), AdaMax (), ADADelta (), AMSGrad (), NADAM (),
29
37
ADAMW (), RADAM (), OADAM (), AdaBelief ()
@@ -99,6 +107,38 @@ using Optimisers: @..
99
107
@test isnan (m3n. γ[3 ])
100
108
end
101
109
110
@testset "trainable subset" begin
    # Foo declares an old-style tuple `trainable` covering both fields,
    # so a Descent step should touch `y` but leave `x` alone only where
    # the gradient is `nothing`.
    mf = Foo([1, 2], (a = sin, b = [3, 4], c = 5))
    sf = Optimisers.setup(Descent(0.1), mf)
    gf = (x = nothing, y = (a = nothing, b = [1, 1], c = 1))
    _, mf2 = Optimisers.update(sf, mf, gf)
    @test mf2.x == [1, 2]
    @test mf2.y == (a = sin, b = [2.9, 3.9], c = 5)

    # TwoThirds: functor children are (a, c), but only `a` is trainable.
    # fmap walks the functor children (a and c), never b:
    mt = TwoThirds(Float32[1, 2], Float32[3, 4], Float32[5, 6])
    mt10 = fmap(x -> 10x, mt)
    @test mt10.a == [10, 20]
    @test mt10.b == [3, 4]
    @test mt10.c == [50, 60]
    # ...while the optimiser updates only the trainable field `a`,
    # even though the loss also depends on `b`:
    st = Optimisers.setup(Momentum(0.1, 0.9), mt)
    gt = gradient(m -> sum(abs2, m.a) + 100sum(abs2, m.b), mt)
    _, mtup = Optimisers.update(st, mt, gt...)
    @test mtup.a ≈ [0.8, 1.6]
    @test mtup.b == [3, 4]
    @test mtup.c == [5, 6]

    # Nest the fixtures so the tree mixes trainable / non-trainable /
    # non-functor branches, and check update! still walks it cleanly:
    m = Foo(
        TwoThirds(Foo(1.0, Float32[2, 3, 4]), 5.0, Float32[6, 7]),
        TwoThirds((p = Float32[1, 2, 3],), sin, (q = 4.0, r = cos)),
    )
    s = Optimisers.setup(Momentum(0.1, 0.9), m)
    g = gradient(m -> sum(abs2, m.x.a.y) + m.x.b^2 + log(m.y.c.q), m)
    @test Optimisers.update!(s, m, g...)[2] isa Foo
end
141
+
102
142
@testset " broadcasting macro" begin
103
143
x = [1.0 , 2.0 ]; y = [3 ,4 ]; z = [5 ,6 ]
104
144
@test (@. . x + y * z) isa Broadcast. Broadcasted
0 commit comments