Skip to content

Commit 3cc2d27

Browse files
committed
add unit tests for parameter optimization primitives
1 parent 65a41a9 commit 3cc2d27

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

test/optimization.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
# Unit tests for the parameter optimization primitives.
# Reconstructed from a mangled diff scrape: interleaved diff line-number
# artifacts removed and conventional indentation restored; all test
# statements are otherwise unchanged.
@testset "optimization" begin

    @testset "in_place_add!" begin
        # TODO
    end

    @testset "Accumulator" begin
        # TODO
    end

    @testset "Julia parameter store" begin

        store = JuliaParameterStore()

        # A generative function with one scalar and one vector trainable
        # parameter; both are registered with the store-facing API.
        @gen function foo()
            @param theta::Float64
            @param phi::Vector{Float64}
        end
        register_parameters!(foo, [:theta, :phi])

        # Before the parameters are initialized in the store: the local
        # parameter dict is empty and every accessor raises KeyError.
        @test Gen.get_local_parameters(store, foo) == Dict{Symbol,Any}()

        @test_throws KeyError get_gradient((foo, :theta), store)
        @test_throws KeyError get_parameter_value((foo, :theta), store)
        @test_throws KeyError increment_gradient!((foo, :theta), 1.0, store)
        @test_throws KeyError reset_gradient!((foo, :theta), store)
        @test_throws KeyError Gen.set_parameter_value!((foo, :theta), 1.0, store)
        @test_throws KeyError Gen.get_gradient_accumulator((foo, :theta), store)

        @test_throws KeyError get_gradient((foo, :phi), store)
        @test_throws KeyError get_parameter_value((foo, :phi), store)
        @test_throws KeyError increment_gradient!((foo, :phi), [1.0, 1.0], store)
        @test_throws KeyError reset_gradient!((foo, :phi), store)
        @test_throws KeyError Gen.set_parameter_value!((foo, :phi), [1.0, 1.0], store)
        @test_throws KeyError Gen.get_gradient_accumulator((foo, :phi), store)

        # After the parameters are initialized in the store.
        init_parameter!((foo, :theta), 1.0, store)
        init_parameter!((foo, :phi), [1.0, 2.0], store)

        dict = Gen.get_local_parameters(store, foo)
        @test length(dict) == 2
        @test dict[:theta] == 1.0
        @test dict[:phi] == [1.0, 2.0]

        # Scalar parameter: initial zero gradient, plain and scaled gradient
        # increments (the 4-arg form adds scale_factor * increment), reset,
        # and value assignment (which leaves the accumulator at zero).
        @test get_gradient((foo, :theta), store) == 0.0
        @test get_parameter_value((foo, :theta), store) == 1.0
        increment_gradient!((foo, :theta), 1.1, store)
        @test get_gradient((foo, :theta), store) == 1.1
        increment_gradient!((foo, :theta), 1.1, 2.0, store)
        @test get_gradient((foo, :theta), store) == (1.1 + 2.2)
        reset_gradient!((foo, :theta), store)
        @test get_gradient((foo, :theta), store) == 0.0
        Gen.set_parameter_value!((foo, :theta), 2.0, store)
        @test get_parameter_value((foo, :theta), store) == 2.0
        @test get_value(Gen.get_gradient_accumulator((foo, :theta), store)) == 0.0

        # Vector parameter: the same operations, applied element-wise.
        @test get_gradient((foo, :phi), store) == [0.0, 0.0]
        @test get_parameter_value((foo, :phi), store) == [1.0, 2.0]
        increment_gradient!((foo, :phi), [1.1, 1.2], store)
        @test get_gradient((foo, :phi), store) == [1.1, 1.2]
        increment_gradient!((foo, :phi), [1.1, 1.2], 2.0, store)
        @test get_gradient((foo, :phi), store) == ([1.1, 1.2] .+ (2.0 * [1.1, 1.2]))
        reset_gradient!((foo, :phi), store)
        @test get_gradient((foo, :phi), store) == [0.0, 0.0]
        Gen.set_parameter_value!((foo, :phi), [2.0, 3.0], store)
        @test get_parameter_value((foo, :phi), store) == [2.0, 3.0]
        @test Gen.get_value(Gen.get_gradient_accumulator((foo, :phi), store)) == [0.0, 0.0]

        # FixedStepGradientDescent

        # DecayStepGradientDescent

        # init_optimizer and apply_update! for FixedStepGradientDescent and DecayStepGradientDescent
        # default_parameter_context and default_julia_parameter_store
    end

    @testset "composite optimizer" begin
    end

end

0 commit comments

Comments
 (0)