1
1
using DecisionMakingProblems
2
- # using PGFPlots
3
2
using Test
4
3
using Random
5
4
using LinearAlgebra
6
5
using GridInterpolations
7
6
8
- # @assert success(`lualatex -v`)
9
- # using NBInclude
10
- # @nbinclude(joinpath(dirname(@__FILE__), "..", "doc", "PGFPlots.ipynb"))
11
7
const p = DecisionMakingProblems
12
8
13
9
# MDP
14
10
15
11
@testset " 2048.jl" begin
16
- m = p . TwentyFortyEight ()
17
- mdp = p . MDP (m)
12
+ m = TwentyFortyEight ()
13
+ mdp = MDP (m)
18
14
@test length (mdp. 𝒜) == 4
19
15
@test mdp. γ == 1.0
20
16
init_state = p. initial_board ()
@@ -24,8 +20,8 @@ const p = DecisionMakingProblems
24
20
end
25
21
26
22
@testset " cart_pole.jl" begin
27
- # m = p. CartPole(1.0, 10.0, 1.0, 1.0, 0.1, 9.8, 0.02, 4.8, deg2rad(24))
28
- m = p . CartPole ()
23
+ # m = CartPole(1.0, 10.0, 1.0, 1.0, 0.1, 9.8, 0.02, 4.8, deg2rad(24))
24
+ m = CartPole ()
29
25
@test p. n_actions (m) == 2
30
26
@test p. discount (m) == 1.0
31
27
@test p. ordered_actions (m) == 1 : 2
36
32
@test ! p. is_terminal (m, state)
37
33
@test min_state <= p. vec (p. cart_pole_transition (m, state, rand (1 : 2 ))) <= max_state
38
34
@test p. reward (m, state, rand (1 : 2 )) in [0.0 , 1.0 ]
39
- mdp = p . MDP (m)
35
+ mdp = MDP (m)
40
36
end
41
37
42
38
@testset " collision_avoidance.jl" begin
43
- m = p . CollisionAvoidanceMDP ()
39
+ m = CollisionAvoidance ()
44
40
distrib = p. CollisionAvoidanceStateDistribution ()
45
41
s = p. rand (distrib)
46
42
simple_pol = p. SimpleCollisionAvoidancePolicy ()
49
45
@test p. is_terminal (m, s) == (p. vec (s)[4 ] < 0.0 )
50
46
@test p. reward (m, rand (p. transition (m, s, optimal_pol (s))), rand (m. 𝒜)) <= 0
51
47
policy = p. CollisionAvoidanceValueFunction (m, simple_pol)
52
- mdp = p . MDP (m)
48
+ mdp = MDP (m)
53
49
end
54
50
55
51
@testset " hexworld.jl" begin
56
- m = p . HexWorld ()
52
+ m = HexWorld ()
57
53
hexes = m. hexes
58
54
@test p. n_states (m) == length (hexes) + 1 && p. ordered_states (m) == 1 : length (hexes) + 1
59
55
@test p. n_actions (m) == 6 && p. ordered_actions (m) == 1 : 6
69
65
@test p. generate_sr (m, state, action)[1 ] in p. ordered_states (m) && p. generate_sr (m, state, action)[2 ] <= 10
70
66
@test p. generate_start_state (m) in p. ordered_states (m)
71
67
@test p. hex_distance (rand (hexes), rand (hexes)) >= 0
72
- mdp = p . MDP (m)
68
+ mdp = MDP (m)
73
69
end
74
70
@testset " simple_lqr.jl" begin
75
- m = p . LQR ()
71
+ m = LQR ()
76
72
@test p. discount (m) == 1.0
77
73
state = p. generate_start_state (m)
78
74
@test - 10 <= rand (p. transition (m, state, rand ())) <= 10
79
75
@test p. reward (m, state, rand ()) <= 0
80
- mdp = p . MDP (m)
76
+ mdp = MDP (m)
81
77
end
82
78
83
79
@testset " mountain_car.jl" begin
84
- m = p . MountainCar ()
80
+ m = MountainCar ()
85
81
@test p. n_actions (m) == 3 && p. ordered_actions (m) == [1 , 2 , 3 ]
86
82
@test p. discount (m) == 1.0
87
83
state_min = [- 1.2 , - 0.07 ]
90
86
@test all (state_min <= start_state <= state_max)
91
87
@test all (state_min <= p. mountain_car_transition (start_state, 1 ) <= state_max)
92
88
@test p. reward (m, start_state, 1 ) <= 0
93
- mdp = p . MDP (m)
89
+ mdp = MDP (m)
94
90
end
95
91
96
92
97
93
# POMDP
98
94
99
95
@testset " crying_baby.jl" begin
100
- # m = p.CryingBaby(-10.0, -5.0, -0.5, 0.1, 0.8, 0.1, 0.9, 0.9)
101
- m = p. CryingBaby ()
96
+ m = CryingBaby ()
102
97
@test p. n_states (m) == 2 && p. ordered_states (m) == [1 , 2 ]
103
98
@test p. n_actions (m) == 3 && p. ordered_actions (m) == [1 , 2 , 3 ]
104
99
@test p. n_observations (m) == 2 && p. ordered_observations (m) == [true , false ]
111
106
end
112
107
113
108
@testset " machine_replacement.jl" begin
114
- # m = p.generate_machine_replacement_pomdp(1.0)
115
- mdp = p. MachineReplacement ()
116
- m = p. DiscretePOMDP (mdp)
109
+ mdp = MachineReplacement ()
110
+ m = DiscretePOMDP (mdp)
117
111
@test p. n_states (m) == 3 && p. ordered_states (m) == 1 : 3
118
112
@test p. n_actions (m) == 4 && p. ordered_actions (m) == 1 : 4
119
113
@test p. n_observations (m) == 2 && p. ordered_observations (m) == 1 : 2
125
119
end
126
120
127
121
@testset " catch.jl" begin
128
- # m = p.generate_catch_pomdp(0.9)
129
- mdp = p. Catch ()
130
- m = p. DiscretePOMDP (mdp)
122
+ mdp = Catch ()
123
+ m = DiscretePOMDP (mdp)
131
124
@test p. n_states (m) == 4 && p. ordered_states (m) == 1 : 4
132
125
@test p. n_actions (m) == 10 && p. ordered_actions (m) == 1 : 10
133
126
@test p. n_observations (m) == 2 && p. ordered_observations (m) == 1 : 2
142
135
# Simple Game
143
136
144
137
@testset " prisoners_dilemma.jl" begin
145
- m = p . PrisonersDilemma ()
138
+ m = PrisonersDilemma ()
146
139
@test p. n_agents (m) == 2
147
140
@test length (p. ordered_actions (m, rand (1 : 2 ))) == 2 && length (p. ordered_joint_actions (m)) == 4
148
141
@test p. n_actions (m, rand (1 : 2 )) == 2 && p. n_joint_actions (m) == 4
152
145
end
153
146
154
147
@testset " rock_paper_scissors.jl" begin
155
- m = p . RockPaperScissors ()
148
+ m = RockPaperScissors ()
156
149
@test p. n_agents (m) == 2
157
150
@test length (p. ordered_actions (m, rand (1 : 2 ))) == 3 && length (p. ordered_joint_actions (m)) == 9
158
151
@test p. n_actions (m, rand (1 : 2 )) == 3 && p. n_joint_actions (m) == 9
162
155
end
163
156
164
157
@testset " travelers.jl" begin
165
- m = p . Travelers ()
158
+ m = Travelers ()
166
159
@test p. n_agents (m) == 2
167
160
@test length (p. ordered_actions (m, rand (1 : 2 ))) == 99 && length (p. ordered_joint_actions (m)) == 99 ^ 2
168
161
@test p. n_actions (m, rand (1 : 2 )) == 99 && p. n_joint_actions (m) == 99 ^ 2
175
168
# Markov Game
176
169
177
170
@testset " predator_prey.jl" begin
178
- m = p . PredatorPreyHexWorld ()
171
+ m = PredatorPreyHexWorld ()
179
172
hexes = m. hexes
180
173
@test p. n_agents (m) == 2
181
174
@test length (p. ordered_states (m, rand (1 : 2 ))) == length (hexes) && length (p. ordered_states (m)) == length (hexes)^ 2
191
184
# POMG
192
185
193
186
@testset " multicaregiver.jl" begin
194
- m = p . MultiCaregiverCryingBaby ()
187
+ m = MultiCaregiverCryingBaby ()
195
188
@test p. n_agents (m) == 2
196
189
@test length (p. ordered_states (m)) == 2
197
190
@test length (p. ordered_actions (m, rand (1 : 2 ))) == 3 && length (p. ordered_joint_actions (m)) == 9
210
203
# DecPOMDP
211
204
212
205
@testset " collab_predator_prey.jl" begin
213
- m = p . CollaborativePredatorPreyHexWorld ()
206
+ m = CollaborativePredatorPreyHexWorld ()
214
207
hexes = m. hexes
215
208
@test p. n_agents (m) == 2
216
209
@test length (p. ordered_states (m, rand (1 : 2 ))) == length (hexes) && length (p. ordered_states (m)) == length (hexes)^ 3
0 commit comments