Commit 49b8440 (parent: 024723b)

cleanup, exports
4 files changed: +50 −58 lines

src/DecisionMakingProblems.jl

Lines changed: 8 additions & 9 deletions
@@ -9,28 +9,29 @@ using Parameters
 using Statistics
 using Printf

+export
+    MDP, HexWorld, StraightLineHexWorld, TwentyFortyEight, CartPole, MountainCar, LQR, CollisionAvoidance,
+    POMDP, DiscretePOMDP, CryingBaby, MachineReplacement, Catch,
+    SimpleGame, PrisonersDilemma, RockPaperScissors, Travelers,
+    MG, PredatorPreyHexWorld, CirclePredatorPreyHexWorld,
+    POMG, MultiCaregiverCryingBaby,
+    DecPOMDP, CollaborativePredatorPreyHexWorld, SimpleCollaborativePredatorPreyHexWorld, CircleCollaborativePredatorPreyHexWorld
+
 import Base: <, ==, rand, vec

 include("support_code.jl")

-# include("search/search.jl")
-# include("search/hex_world.jl")
-
 include("mdp/mdp.jl")
 include("mdp/discrete_mdp.jl")
-# include("mdp/sliding_tile_puzzle.jl")
# include("mdp/gridworld.jl
 include("mdp/2048.jl")
 include("mdp/hexworld.jl")
-# include("mdp/you_get_what_you_bet.jl")
 include("mdp/simple_lqr.jl")
 include("mdp/cart_pole.jl")
 include("mdp/mountain_car.jl")
 include("mdp/collision_avoidance.jl")

 include("pomdp/pomdp.jl")
 include("pomdp/discrete_pomdp.jl")
-# include("pomdp/spelunker_joe.jl")
 include("pomdp/crying_baby.jl")
 include("pomdp/machine_replacement.jl")
 include("pomdp/catch.jl")
@@ -49,6 +50,4 @@ include("pomg/multicaregiver.jl")
 include("decpomdp/decpomdp.jl")
 include("decpomdp/collab_predator_prey.jl")

-
-
 end # module
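
The substance of this file's change is the new export list: problem constructors and wrapper types can now be used without the module prefix. A minimal sketch of what that enables, using only names taken from the export list above:

    using DecisionMakingProblems

    m = CartPole()            # previously required DecisionMakingProblems.CartPole()
    mdp = MDP(m)              # generic MDP wrapper, also exported
    game = PrisonersDilemma() # simple-game constructors are exported too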

src/mdp/collision_avoidance.jl

Lines changed: 17 additions & 17 deletions
@@ -1,4 +1,4 @@
-@with_kw struct CollisionAvoidanceMDP
+@with_kw struct CollisionAvoidance
     ddh_max::Float64 = 1.0 # vertical acceleration limit [m/s²]
     collision_threshold::Float64 = 50.0 # collision threshold [m]
     reward_collision::Float64 = -1.0 # reward obtained if collision occurs
@@ -9,16 +9,16 @@
     ::SetCategorical{Float64} = SetCategorical([2.0, 0.0, -2.0], [0.25, 0.5, 0.25])
 end

-struct CollisionAvoidanceMDPState
+struct CollisionAvoidanceState
     h::Float64 # vertical separation [m]
     dh::Float64 # rate of change of h [m/s]
     a_prev::Float64 # previous acceleration [m/s²]
     τ::Float64 # horizontal separation time [s]
 end

-Base.vec(s::CollisionAvoidanceMDPState) = [s.h, s.dh, s.a_prev, s.τ]
+Base.vec(s::CollisionAvoidanceState) = [s.h, s.dh, s.a_prev, s.τ]

-function transition(𝒫::CollisionAvoidanceMDP, s::CollisionAvoidanceMDPState, a::Float64)
+function transition(𝒫::CollisionAvoidance, s::CollisionAvoidanceState, a::Float64)
     h = s.h + s.dh
     dh = s.dh
     if a != 0.0
@@ -31,14 +31,14 @@ function transition(𝒫::CollisionAvoidanceMDP, s::CollisionAvoidanceMDPState,
     a_prev = a
     τ = max(s.τ - 1.0, -1.0)
     states = [
-        CollisionAvoidanceMDPState(h, dh + ν, a_prev, τ) for ν in 𝒫..elements
+        CollisionAvoidanceState(h, dh + ν, a_prev, τ) for ν in 𝒫..elements
     ]
     return SetCategorical(states, 𝒫..distr.p)
 end

-is_terminal(𝒫::CollisionAvoidanceMDP, s::CollisionAvoidanceMDPState) = s.τ < 0.0
+is_terminal(𝒫::CollisionAvoidance, s::CollisionAvoidanceState) = s.τ < 0.0

-function reward(𝒫::CollisionAvoidanceMDP, s::CollisionAvoidanceMDPState, a::Float64)
+function reward(𝒫::CollisionAvoidance, s::CollisionAvoidanceState, a::Float64)
     r = 0.0
     if abs(s.h) < 𝒫.collision_threshold && abs(s.τ) < eps()
         # We collided
@@ -59,7 +59,7 @@ end
 end

 function rand(b::CollisionAvoidanceStateDistribution)
-    return CollisionAvoidanceMDPState(Distributions.rand(b.h), Distributions.rand(b.dh), b.a_prev, b.tau)
+    return CollisionAvoidanceState(Distributions.rand(b.h), Distributions.rand(b.dh), b.a_prev, b.tau)
 end

 @with_kw struct SimpleCollisionAvoidancePolicy
@@ -74,7 +74,7 @@ struct OptimalCollisionAvoidancePolicy
     Q
 end

-function OptimalCollisionAvoidancePolicy(mdp = CollisionAvoidanceMDP())
+function OptimalCollisionAvoidancePolicy(mdp = CollisionAvoidance())
     𝒜 = mdp.𝒜

     hs = range(-200, 200, length=21) # discretization of h in m
@@ -85,7 +85,7 @@ function OptimalCollisionAvoidancePolicy(mdp = CollisionAvoidanceMDP())
     grid = GridInterpolations.RectangleGrid(hs, dhs, 𝒜, τs)

     # State space
-    𝒮 = [CollisionAvoidanceMDPState(h, dh, a_prev, τ) for h in hs, dh in dhs, a_prev in mdp.𝒜, τ in τs]
+    𝒮 = [CollisionAvoidanceState(h, dh, a_prev, τ) for h in hs, dh in dhs, a_prev in mdp.𝒜, τ in τs]

     # State value function
     U = zeros(length(𝒮))
@@ -105,7 +105,7 @@ function OptimalCollisionAvoidancePolicy(mdp = CollisionAvoidanceMDP())
     return OptimalCollisionAvoidancePolicy(mdp.𝒜, grid, Q)
 end

-function action(policy::OptimalCollisionAvoidancePolicy, s::CollisionAvoidanceMDPState)
+function action(policy::OptimalCollisionAvoidancePolicy, s::CollisionAvoidanceState)
     vec_s = vec(s)
     a_best = first(policy.𝒜)
     q_best = -Inf
@@ -118,18 +118,18 @@ function action(policy::OptimalCollisionAvoidancePolicy, s::CollisionAvoidanceMD
     return a_best
 end

-function (policy::OptimalCollisionAvoidancePolicy)(s::CollisionAvoidanceMDPState)
+function (policy::OptimalCollisionAvoidancePolicy)(s::CollisionAvoidanceState)
     return action(policy, s)
 end

-function action(policy::SimpleCollisionAvoidancePolicy, s::CollisionAvoidanceMDPState)
+function action(policy::SimpleCollisionAvoidancePolicy, s::CollisionAvoidanceState)
     if abs(s.h) < policy.thresh_h && s.τ < policy.thresh_τ
         return (s.h > 0.0) ? policy.𝒜.up : policy.𝒜.down
     end
     return policy.𝒜.noalert
 end

-function (policy::SimpleCollisionAvoidancePolicy)(s::CollisionAvoidanceMDPState)
+function (policy::SimpleCollisionAvoidancePolicy)(s::CollisionAvoidanceState)
     return action(policy, s)
 end

@@ -139,7 +139,7 @@ struct CollisionAvoidanceValueFunction
     U
 end

-function CollisionAvoidanceValueFunction(𝒫::CollisionAvoidanceMDP, policy)
+function CollisionAvoidanceValueFunction(𝒫::CollisionAvoidance, policy)
     𝒜 = 𝒫.𝒜

     hs = range(-200, 200, length=21) # discretization of h in m
@@ -150,7 +150,7 @@ function CollisionAvoidanceValueFunction(𝒫::CollisionAvoidanceMDP, policy)
     grid = GridInterpolations.RectangleGrid(hs, dhs, 𝒜, τs)

     # State space
-    𝒮 = [CollisionAvoidanceMDPState(h, dh, a_prev, τ) for h in hs, dh in dhs, a_prev in 𝒫.𝒜, τ in τs]
+    𝒮 = [CollisionAvoidanceState(h, dh, a_prev, τ) for h in hs, dh in dhs, a_prev in 𝒫.𝒜, τ in τs]

     # State value function
     U = zeros(length(𝒮))
@@ -170,7 +170,7 @@ function (U::CollisionAvoidanceValueFunction)(s)
     return GridInterpolations.interpolate(U.grid, U.U, vec(s))
 end

-function MDP(mdp::CollisionAvoidanceMDP; γ::Float64=1.0)
+function MDP(mdp::CollisionAvoidance; γ::Float64=1.0)
     return MDP(
         γ,
         nothing, # no ordered states
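
Every rename in this file is mechanical (CollisionAvoidanceMDP → CollisionAvoidance, CollisionAvoidanceMDPState → CollisionAvoidanceState), so downstream code only needs the shorter names. A minimal rollout sketch against the renamed API, keeping the `p.` prefix for functions that are not in the new export list, as the test suite does:

    using DecisionMakingProblems
    const p = DecisionMakingProblems

    m = CollisionAvoidance()                    # formerly CollisionAvoidanceMDP
    policy = p.SimpleCollisionAvoidancePolicy() # heuristic climb/descend policy

    # simulate one encounter; transition returns a SetCategorical over next states
    let s = rand(p.CollisionAvoidanceStateDistribution()), total = 0.0
        while !p.is_terminal(m, s)              # terminal once τ drops below zero
            a = policy(s)                       # policies are callable on states
            total += p.reward(m, s, a)
            s = rand(p.transition(m, s, a))
        end
        @show total
    end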

src/mdp/simple_lqr.jl

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ generate_start_state(mdp::LQR) = rand(Normal(0.3,0.1))

 function transition(mdp::LQR, s::Float64, a::Float64)
     # NOTE: Truncated to prevent going off to infinity with poor policies
-    return Truncated(Normal(s + a, 0.1), -10.0, 10.0)
+    return truncated(Normal(s + a, 0.1), -10.0, 10.0)
 end
 reward(mdp::LQR, s::Float64, a::Float64) = -s^2

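This one-line fix tracks the Distributions.jl API, which moved from calling the capitalized `Truncated` type directly to the lowercase `truncated` function. A self-contained sketch of the replacement call:

    using Distributions

    d = truncated(Normal(0.0, 0.1), -10.0, 10.0) # preferred over Truncated(...)
    x = rand(d)
    @assert -10.0 <= x <= 10.0                   # samples stay inside the truncation bounds
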
test/runtests.jl

Lines changed: 24 additions & 31 deletions
@@ -1,20 +1,16 @@
 using DecisionMakingProblems
-# using PGFPlots
 using Test
 using Random
 using LinearAlgebra
 using GridInterpolations

-# @assert success(`lualatex -v`)
-# using NBInclude
-# @nbinclude(joinpath(dirname(@__FILE__), "..", "doc", "PGFPlots.ipynb"))
 const p = DecisionMakingProblems

 # MDP

 @testset "2048.jl" begin
-    m = p.TwentyFortyEight()
-    mdp = p.MDP(m)
+    m = TwentyFortyEight()
+    mdp = MDP(m)
     @test length(mdp.𝒜) == 4
     @test mdp.γ == 1.0
     init_state = p.initial_board()
@@ -24,8 +20,8 @@ const p = DecisionMakingProblems
 end

 @testset "cart_pole.jl" begin
-    # m = p.CartPole(1.0, 10.0, 1.0, 1.0, 0.1, 9.8, 0.02, 4.8, deg2rad(24))
-    m = p.CartPole()
+    # m = CartPole(1.0, 10.0, 1.0, 1.0, 0.1, 9.8, 0.02, 4.8, deg2rad(24))
+    m = CartPole()
     @test p.n_actions(m) == 2
     @test p.discount(m) == 1.0
     @test p.ordered_actions(m) == 1:2
@@ -36,11 +32,11 @@ end
     @test !p.is_terminal(m, state)
     @test min_state <= p.vec(p.cart_pole_transition(m, state, rand(1:2))) <= max_state
     @test p.reward(m, state, rand(1:2)) in [0.0, 1.0]
-    mdp = p.MDP(m)
+    mdp = MDP(m)
 end

 @testset "collision_avoidance.jl" begin
-    m = p.CollisionAvoidanceMDP()
+    m = CollisionAvoidance()
     distrib = p.CollisionAvoidanceStateDistribution()
     s = p.rand(distrib)
     simple_pol = p.SimpleCollisionAvoidancePolicy()
@@ -49,11 +45,11 @@ end
     @test p.is_terminal(m, s) == (p.vec(s)[4] < 0.0)
     @test p.reward(m, rand(p.transition(m, s, optimal_pol(s))), rand(m.𝒜)) <= 0
     policy = p.CollisionAvoidanceValueFunction(m, simple_pol)
-    mdp = p.MDP(m)
+    mdp = MDP(m)
 end

 @testset "hexworld.jl" begin
-    m = p.HexWorld()
+    m = HexWorld()
     hexes = m.hexes
     @test p.n_states(m) == length(hexes) + 1 && p.ordered_states(m) == 1:length(hexes) + 1
     @test p.n_actions(m) == 6 && p.ordered_actions(m) == 1:6
@@ -69,19 +65,19 @@ end
     @test p.generate_sr(m, state, action)[1] in p.ordered_states(m) && p.generate_sr(m, state, action)[2] <= 10
     @test p.generate_start_state(m) in p.ordered_states(m)
     @test p.hex_distance(rand(hexes), rand(hexes)) >= 0
-    mdp = p.MDP(m)
+    mdp = MDP(m)
 end
 @testset "simple_lqr.jl" begin
-    m = p.LQR()
+    m = LQR()
     @test p.discount(m) == 1.0
     state = p.generate_start_state(m)
     @test -10 <= rand(p.transition(m, state, rand())) <= 10
     @test p.reward(m, state, rand()) <= 0
-    mdp = p.MDP(m)
+    mdp = MDP(m)
 end

 @testset "mountain_car.jl" begin
-    m = p.MountainCar()
+    m = MountainCar()
     @test p.n_actions(m) == 3 && p.ordered_actions(m) == [1, 2, 3]
     @test p.discount(m) == 1.0
     state_min = [-1.2, -0.07]
@@ -90,15 +86,14 @@ end
     @test all(state_min <= start_state <= state_max)
     @test all(state_min <= p.mountain_car_transition(start_state, 1) <= state_max)
     @test p.reward(m, start_state, 1) <= 0
-    mdp = p.MDP(m)
+    mdp = MDP(m)
 end


 # POMDP

 @testset "crying_baby.jl" begin
-    # m = p.CryingBaby(-10.0, -5.0, -0.5, 0.1, 0.8, 0.1, 0.9, 0.9)
-    m = p.CryingBaby()
+    m = CryingBaby()
     @test p.n_states(m) == 2 && p.ordered_states(m) == [1, 2]
     @test p.n_actions(m) == 3 && p.ordered_actions(m) == [1, 2, 3]
     @test p.n_observations(m) == 2 && p.ordered_observations(m) == [true, false]
@@ -111,9 +106,8 @@ end
 end

 @testset "machine_replacement.jl" begin
-    # m = p.generate_machine_replacement_pomdp(1.0)
-    mdp = p.MachineReplacement()
-    m = p.DiscretePOMDP(mdp)
+    mdp = MachineReplacement()
+    m = DiscretePOMDP(mdp)
     @test p.n_states(m) == 3 && p.ordered_states(m) == 1:3
     @test p.n_actions(m) == 4 && p.ordered_actions(m) == 1:4
     @test p.n_observations(m) == 2 && p.ordered_observations(m) == 1:2
@@ -125,9 +119,8 @@ end
 end

 @testset "catch.jl" begin
-    # m = p.generate_catch_pomdp(0.9)
-    mdp = p.Catch()
-    m = p.DiscretePOMDP(mdp)
+    mdp = Catch()
+    m = DiscretePOMDP(mdp)
     @test p.n_states(m) == 4 && p.ordered_states(m) == 1:4
     @test p.n_actions(m) == 10 && p.ordered_actions(m) == 1:10
     @test p.n_observations(m) == 2 && p.ordered_observations(m) == 1:2
@@ -142,7 +135,7 @@ end
 # Simple Game

 @testset "prisoners_dilemma.jl" begin
-    m = p.PrisonersDilemma()
+    m = PrisonersDilemma()
     @test p.n_agents(m) == 2
     @test length(p.ordered_actions(m, rand(1:2))) == 2 && length(p.ordered_joint_actions(m)) == 4
     @test p.n_actions(m, rand(1:2)) == 2 && p.n_joint_actions(m) == 4
@@ -152,7 +145,7 @@ end
 end

 @testset "rock_paper_scissors.jl" begin
-    m = p.RockPaperScissors()
+    m = RockPaperScissors()
     @test p.n_agents(m) == 2
     @test length(p.ordered_actions(m, rand(1:2))) == 3 && length(p.ordered_joint_actions(m)) == 9
     @test p.n_actions(m, rand(1:2)) == 3 && p.n_joint_actions(m) == 9
@@ -162,7 +155,7 @@ end
 end

 @testset "travelers.jl" begin
-    m = p.Travelers()
+    m = Travelers()
     @test p.n_agents(m) == 2
     @test length(p.ordered_actions(m, rand(1:2))) == 99 && length(p.ordered_joint_actions(m)) == 99^2
     @test p.n_actions(m, rand(1:2)) == 99 && p.n_joint_actions(m) == 99^2
@@ -175,7 +168,7 @@ end
 # Markov Game

 @testset "predator_prey.jl" begin
-    m = p.PredatorPreyHexWorld()
+    m = PredatorPreyHexWorld()
     hexes = m.hexes
     @test p.n_agents(m) == 2
     @test length(p.ordered_states(m, rand(1:2))) == length(hexes) && length(p.ordered_states(m)) == length(hexes)^2
@@ -191,7 +184,7 @@ end
 # POMG

 @testset "multicaregiver.jl" begin
-    m = p.MultiCaregiverCryingBaby()
+    m = MultiCaregiverCryingBaby()
     @test p.n_agents(m) == 2
     @test length(p.ordered_states(m)) == 2
     @test length(p.ordered_actions(m, rand(1:2))) == 3 && length(p.ordered_joint_actions(m)) == 9
@@ -210,7 +203,7 @@ end
 # DecPOMDP

 @testset "collab_predator_prey.jl" begin
-    m = p.CollaborativePredatorPreyHexWorld()
+    m = CollaborativePredatorPreyHexWorld()
     hexes = m.hexes
     @test p.n_agents(m) == 2
     @test length(p.ordered_states(m, rand(1:2))) == length(hexes) && length(p.ordered_states(m)) == length(hexes)^3
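
The test edits drop the commented-out PGFPlots/NBInclude remnants and switch constructors over to the new exports, while keeping the `p.` prefix for non-exported query functions. The suite runs the standard way:

    using Pkg
    Pkg.test("DecisionMakingProblems")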
