Commit 0d30ee1

edit pomdp models
1 parent 276ce43 commit 0d30ee1

File tree: 9 files changed (+206, −57 lines)


docs/src/mdp.md

Lines changed: 3 additions & 52 deletions
@@ -1,6 +1,6 @@
-# Markov Decision Process
+# MDP Usage
 
-## MDP Struct
+## MDP
 The MDP struct gives the following:
 - `γ`: discount factor
 - `𝒮`: state space
@@ -9,53 +9,4 @@ The MDP struct gives the following:
 - `R`: reward function
 - `TR`: function allows us to sample transition and reward
 
-## DiscreteMDP Struct
-The DiscreteMDP struct gives the following objects and methods:
-- `ordered_states(m::DiscreteMDP)`: gives a vector of states
-- `ordered_actions(m::DiscreteMDP)`: gives a vector of actions
-- `T`: Matrix of transition function T(s,a,s′)
-- `transition(m::DiscreteMDP, s::Int, a::Int)`: function that gives a distribution of the transition
-- `generate_s(m::DiscreteMDP, s::Int, a::Int)`: function that samples the state from a transition
-- `R`: Matrix of reward values R(s,a) = ∑_s' R(s,a,s')*T(s,a,s′)
-- `reward(m::DiscreteMDP, s::Int, a::Int)`: function gives the reward of a state and action pair
-- `γ`: Discount factor
-
-## Cart Pole, Mountain Car, Simple LQR
-These problems all have similar usage documentation. To build an instance of one of these problems run
-```julia
-m = Problem()
-mdp = MDP(m)
-```
-where `Problem` is either replaced with `CartPole`, `MountainCar`, or `LqrMDP`. Then `mdp` is a MDP struct so we get access to all of the functions describe in the MDP Struct section.
-
-## Hex World
-For Hex World, you use the DiscreteMDP struct. You can either set up the HexWorld manually by calling
-```julia
-m = HexWorldMDP(hexes, HexWorldRBumpBorder, HexWorldPIntended, special_hex_rewards, HexWorldDiscountFactor)
-```
-where `HexWorldRBumpBorder`, `HexWorldPIntended` and `HexWorldDiscountFactor` are constants, `hexes` is a list of 2 dimensional coordinates and `special_hex_rewards` is a dictionary of all nonzero rewards.
-It is also possible to use one of the preset MDPs:
-```julia
-m = HexWorld
-m = StraightLineHexWorld
-```
-Then running
-```julia
-mdp = m.mdp
-```
-gives an instance of a DiscreteMDP struct.
-
-## Collision Avoidance
-To create an instance of the problem, run
-```julia
-m = CollisionAvoidanceMDP()
-```
-Each of the state are instances of the struct `CollissionAvoidanceMDPState` that have the objects
-- `h`: vertical separation
-- `dh`: rate of change in h
-- `a_prev`: last action
-- `τ`: horizontal time separation
-Then you the CollisionAvoidanceMDP struct has the methods:
-- `transition(𝒫::CollisionAvoidanceMDP, s::CollisionAvoidanceMDPState, a::Float64)`: returns a distribution of states which can be sampled
-- `is_terminal(𝒫::CollisionAvoidanceMDP, s::CollisionAvoidanceMDPState)`: determines if the state is terminal
-- `reward(𝒫::CollisionAvoidanceMDP, s::CollisionAvoidanceMDPState, a::Float64)`: gives the reward
+The function `T` takes in a state `s` and an action `a` and returns a distribution of states which can be sampled. The reward function `R` takes in a state `s` and an action `a` and returns a reward. Finally, `TR` takes in a state `s` and an action `a` and returns a tuple `(s', r)`, where `s'` is the new state sampled from the transition function and `r` is the reward.
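For orientation, here is a minimal usage sketch of the interface described in the new paragraph above. It is a sketch only, assuming the package is loaded and that a problem with an `MDP` constructor (the `HexWorld` preset is used here as an assumed example) is in scope.

```julia
# Sketch only: `HexWorld` and `MDP` are assumed to be in scope once the package is loaded.
mdp = MDP(HexWorld)      # wrap a preset problem as a generic MDP

s = rand(mdp.𝒮)          # pick a state from the state space
a = rand(mdp.𝒜)          # pick an action from the action space

S′ = mdp.T(s, a)         # distribution over successor states
s′ = rand(S′)            # sample a successor
r  = mdp.R(s, a)         # immediate reward
s′, r = mdp.TR(s, a)     # or sample transition and reward together
```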

docs/src/pomdp.md

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# POMDP Usage
+
+## POMDP
+The POMDP struct gives the following:
+- `γ`: discount factor
+- `𝒮`: state space
+- `𝒜`: action space
+- `𝒪`: observation space
+- `T`: transition function
+- `R`: reward function
+- `O`: observation function
+- `TRO`: function that allows us to sample transition, reward, and observation
+
+The function `T` takes in a state `s` and an action `a` and returns a distribution of possible states. The reward function `R` takes in a state `s` and an action `a` and returns a reward. The observation function `O` takes in a state `s` and an action `a` and returns a distribution of possible observations. Finally, `TRO` takes in a state `s` and an action `a` and returns a tuple `(s', r, o)`, where `s'` is the new state sampled from the transition function, `r` is the reward, and `o` is an observation sampled from the observation function.
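Analogously, a minimal usage sketch for the POMDP interface. This is a sketch under assumptions: the `Catch` problem is used as an example, and any problem with a `POMDP` constructor should work the same way.

```julia
# Sketch only: `Catch` and `POMDP` are assumed to be in scope once the package is loaded.
pomdp = POMDP(Catch())

s = rand(pomdp.𝒮)             # a state from the state space
a = rand(pomdp.𝒜)             # an action from the action space

s′ = rand(pomdp.T(s, a))      # sample a successor state
r  = pomdp.R(s, a)            # immediate reward
s′, r, o = pomdp.TRO(s, a)    # sample transition, reward, and observation together
```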

src/mdp/discrete_mdp.jl

Lines changed: 22 additions & 1 deletion
@@ -24,4 +24,25 @@ function generate_s(mdp::DiscreteMDP, s::Int, a::Int)
     end
     return s′
 end
-reward(mdp::DiscreteMDP, s::Int, a::Int) = mdp.R[s,a]
+reward(mdp::DiscreteMDP, s::Int, a::Int) = mdp.R[s,a]
+
+function MDP(mdp::DiscreteMDP; γ::Float64=discount(mdp))
+    return MDP(
+        γ,
+        ordered_states(mdp),
+        ordered_actions(mdp),
+        (s,a, s′=nothing) -> begin
+            S′ = transition(mdp, s, a)
+            if s′ == nothing
+                return S′
+            end
+            return pdf(S′, s′)
+        end,
+        (s,a) -> reward(mdp, s, a),
+        (s, a)->begin
+            s′ = rand(transition(mdp,s,a))
+            r = reward(mdp, s, a)
+            return (s′, r)
+        end
+    )
+end
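The transition closure above serves both as a generative model and as an explicit probability query, so one `MDP` field supports sampling-based and exact algorithms alike. A usage sketch, assuming a `DiscreteMDP` such as the one behind the `HexWorld` preset is available:

```julia
# Sketch only: names are assumed to be in scope once the package is loaded.
dmdp = DiscreteMDP(HexWorld)   # tabular representation of a preset problem
mdp  = MDP(dmdp)               # wrap it with the new constructor

s, a = 1, 1
S′ = mdp.T(s, a)       # s′ omitted: full distribution over successor states
p  = mdp.T(s, a, 1)    # s′ given: the probability pdf(S′, s′)
s′, r = mdp.TR(s, a)   # sampled successor plus its reward
```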

src/mdp/hexworld.jl

Lines changed: 3 additions & 0 deletions
@@ -174,3 +174,6 @@ end
 function DiscreteMDP(mdp::HexWorldMDP)
     return mdp.mdp
 end
+function MDP(mdp::HexWorldMDP)
+    return MDP(mdp.mdp)
+end
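With this wrapper, the hex-world presets can be converted in one call instead of reaching into the `mdp` field. A small sketch, assuming the preset constants are exported:

```julia
# Sketch only: the preset constants are assumed to be in scope.
mdp = MDP(HexWorld)                 # equivalent to MDP(HexWorld.mdp)
mdp2 = MDP(StraightLineHexWorld)
```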

src/pomdp/catch.jl

Lines changed: 42 additions & 0 deletions
@@ -44,4 +44,46 @@ function DiscretePOMDP(mdp::Catch; γ::Float64=mdp.γ)
     return DiscretePOMDP(T, R, O, γ)
 end
 
+function POMDP(mdp::Catch; γ::Float64=mdp.γ)
+    Θ = [20,40,60,80] # proficiencies
+    𝒜 = collect(10:10:100) # throw distances
+
+    nS = length(Θ)
+    nA = length(𝒜)
+    nO = 2 # catch or not
+
+    T = zeros(nS, nA, nS)
+    R = Array{Float64}(undef, nS, nA)
+    O = Array{Float64}(undef, nA, nS, nO)
+
+    o_catch = 1
+    o_drop = 2
+
+    prob_catch(d,θ) = 1 - 1/(1+exp(-(d-θ)/15))
+
+    # Transition dynamics are 100% stationary.
+    for si in 1:nS
+        for ai in 1:nA
+            T[si, ai, si] = 1.0
+        end
+    end
+
+    # Reward equal to distance caught
+    for (si,θ) in enumerate(Θ)
+        for (ai,d) in enumerate(𝒜)
+            R[si,ai] = d*prob_catch(d,θ) # distance caught times prob of catch
+        end
+    end
+
+    # Observation is based on whether we caught or not.
+    for (ai,d) in enumerate(𝒜)
+        for (si′,θ) in enumerate(Θ)
+            O[ai,si′,o_catch] = prob_catch(d,θ)
+            O[ai,si′,o_drop] = 1 - O[ai,si′,o_catch]
+        end
+    end
+
+    return POMDP(DiscretePOMDP(T, R, O, γ))
+end
+
 # const Catch = generate_catch_pomdp(0.9)
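The reward table encodes a tradeoff: longer throws pay more but are caught less often, with the catch probability following the logistic curve defined above. The following standalone sketch reproduces that arithmetic for illustration; it copies `prob_catch` rather than calling the package.

```julia
# Reproduces the catch-probability model from the function above, for illustration only.
prob_catch(d, θ) = 1 - 1 / (1 + exp(-(d - θ) / 15))

distances = 10:10:100
θ = 60                                                 # one of the proficiency levels
expected = [d * prob_catch(d, θ) for d in distances]   # matches R[si, ai] for this θ
best = distances[argmax(expected)]                     # throw distance with highest expected reward
```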

src/pomdp/crying_baby.jl

Lines changed: 43 additions & 0 deletions
@@ -137,3 +137,46 @@ function DiscretePOMDP(pomdp::BabyPOMDP; γ::Float64=pomdp.γ)
 
     return DiscretePOMDP(T, R, O, γ)
 end
+
+function POMDP(pomdp::BabyPOMDP; γ::Float64=pomdp.γ)
+    nS = n_states(pomdp)
+    nA = n_actions(pomdp)
+    nO = n_observations(pomdp)
+
+    T = zeros(nS, nA, nS)
+    R = Array{Float64}(undef, nS, nA)
+    O = Array{Float64}(undef, nA, nS, nO)
+
+    s_s = 1
+    s_h = 2
+
+    a_f = 1
+    a_i = 2
+    a_s = 3
+
+    o_c = 1
+    o_q = 2
+
+    T[s_s, a_f, :] = [1.0, 0.0]
+    T[s_s, a_i, :] = [1.0-pomdp.p_become_hungry, pomdp.p_become_hungry]
+    T[s_s, a_s, :] = [1.0-pomdp.p_become_hungry, pomdp.p_become_hungry]
+    T[s_h, a_f, :] = [1.0, 0.0]
+    T[s_h, a_i, :] = [0.0, 1.0]
+    T[s_h, a_s, :] = [0.0, 1.0]
+
+    R[s_s, a_f, :] = reward(pomdp, s_s, a_f)
+    R[s_s, a_i, :] = reward(pomdp, s_s, a_i)
+    R[s_s, a_s, :] = reward(pomdp, s_s, a_s)
+    R[s_h, a_f, :] = reward(pomdp, s_h, a_f)
+    R[s_h, a_i, :] = reward(pomdp, s_h, a_i)
+    R[s_h, a_s, :] = reward(pomdp, s_h, a_s)
+
+    O[a_f, s_s, :] = [observation(pomdp, a_f, s_s).p, 1 - observation(pomdp, a_f, s_s).p]
+    O[a_f, s_h, :] = [observation(pomdp, a_f, s_h).p, 1 - observation(pomdp, a_f, s_h).p]
+    O[a_i, s_s, :] = [observation(pomdp, a_i, s_s).p, 1 - observation(pomdp, a_i, s_s).p]
+    O[a_i, s_h, :] = [observation(pomdp, a_i, s_h).p, 1 - observation(pomdp, a_i, s_h).p]
+    O[a_s, s_s, :] = [observation(pomdp, a_s, s_s).p, 1 - observation(pomdp, a_s, s_s).p]
+    O[a_s, s_h, :] = [observation(pomdp, a_s, s_h).p, 1 - observation(pomdp, a_s, s_h).p]
+
+    return POMDP(DiscretePOMDP(T, R, O, γ))
+end
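Because the transition table is written out by hand, a quick standalone check that each `(s, a)` row forms a probability distribution can be useful. The sketch below mirrors the table with an assumed value for `p_become_hungry` and does not call the package.

```julia
# Mirrors the hand-written table above; the value of p is an assumption for illustration.
p = 0.1                      # assumed pomdp.p_become_hungry
T = zeros(2, 3, 2)           # T[s, a, s′]; states: 1 = sated, 2 = hungry; actions: feed, ignore, sing
T[1, 1, :] = [1.0, 0.0]      # feeding always leaves the baby sated
T[1, 2, :] = [1.0 - p, p]    # ignoring or singing may let the baby become hungry
T[1, 3, :] = [1.0 - p, p]
T[2, 1, :] = [1.0, 0.0]
T[2, 2, :] = [0.0, 1.0]      # a hungry baby stays hungry unless fed
T[2, 3, :] = [0.0, 1.0]
@assert all(sum(T, dims=3) .≈ 1.0)   # every (s, a) row sums to one
```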

src/pomdp/discrete_pomdp.jl

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ reward(pomdp::DiscretePOMDP, s::Int, a::Int) = pomdp.R[s,a]
 
 reward(pomdp::DiscretePOMDP, b::Vector{Float64}, a::Int) = sum(reward(pomdp,s,a)*b[s] for s in ordered_states(pomdp))
 
-function POMDP(pomdp; γ::Float64=discount(pomdp))
+function POMDP(pomdp::DiscretePOMDP; γ::Float64=discount(pomdp))
     return POMDP(
         γ,
         ordered_states(pomdp),

src/pomdp/machine_replacement.jl

Lines changed: 72 additions & 0 deletions
@@ -90,6 +90,78 @@ function DiscretePOMDP(pomdp::MachineReplacement; γ::Float64=pomdp.γ)
     return DiscretePOMDP(T, R, O, γ)
 end
 
+function POMDP(pomdp::MachineReplacement; γ::Float64=pomdp.γ)
+    T = Array{Float64}(undef, 3, 4, 3)
+    R = Array{Float64}(undef, 3, 4)
+    O = Array{Float64}(undef, 4, 3, 2)
+
+    s_0 = 1 # none broken
+    s_1 = 2 # one broken
+    s_2 = 3 # two broken
+
+    a_m = 1 # manufacture
+    a_e = 2 # manufacture + examine
+    a_i = 3 # interrupt the line, inspect components, replace failed components
+    a_r = 4 # interrupt the line, replace both components
+
+    o_n = 1 # nondefective
+    o_d = 2 # defective
+
+    T[s_0, a_m, :] = [0.81, 0.18, 0.01] # 10% independent chance of part breaking
+    T[s_0, a_e, :] = [0.81, 0.18, 0.01]
+    T[s_0, a_i, :] = [1.00, 0.00, 0.00]
+    T[s_0, a_r, :] = [1.00, 0.00, 0.00]
+    T[s_1, a_m, :] = [0.00, 0.90, 0.10] # 10% chance of remaining part breaking
+    T[s_1, a_e, :] = [0.00, 0.90, 0.10]
+    T[s_1, a_i, :] = [1.00, 0.00, 0.00]
+    T[s_1, a_r, :] = [1.00, 0.00, 0.00]
+    T[s_2, a_m, :] = [0.00, 0.00, 1.00] # stay broken
+    T[s_2, a_e, :] = [0.00, 0.00, 1.00]
+    T[s_2, a_i, :] = [1.00, 0.00, 0.00]
+    T[s_2, a_r, :] = [1.00, 0.00, 0.00]
+
+    # There is a profit of 1 for producing a nondefective product.
+    # Thus, the expected profit for beginning with 0, 1, or 2 defective parts is
+    # 0.9025, 0.475, and 0.25, respectively.
+    # Examining the finished product costs 0.25.
+    # The inspect action incurs a 0.5 penalty plus replacement cost for each unit of 1.
+    # The straight-up replacement action has no inspection cost but does incur a 2 unit cost.
+    r_examine = -0.25
+    r_inspect = -0.5
+    r_replace = -2.0
+
+    R[s_0, a_m] = 0.9025
+    R[s_1, a_m] = 0.475
+    R[s_2, a_m] = 0.25
+    R[s_0, a_e] = 0.9025 + r_examine
+    R[s_1, a_e] = 0.475 + r_examine
+    R[s_2, a_e] = 0.25 + r_examine
+    R[s_0, a_i] = r_inspect
+    R[s_1, a_i] = r_inspect - 1.0 # replace 1
+    R[s_2, a_i] = r_inspect - 2.0 # replace 2
+    R[s_0, a_r] = r_replace
+    R[s_1, a_r] = r_replace
+    R[s_2, a_r] = r_replace
+
+    # Probabilities of observing a nondefective product are 1.0, 0.5, and 0.25 if
+    # there are 0, 1, or 2 faulty internal components.
+    # If we don't examine, we always observe nondefective.
+    O[a_m, s_0, :] = [1.00, 0.00]
+    O[a_m, s_1, :] = [1.00, 0.00]
+    O[a_m, s_2, :] = [1.00, 0.00]
+    O[a_e, s_0, :] = [1.00, 0.00]
+    O[a_e, s_1, :] = [0.50, 0.50]
+    O[a_e, s_2, :] = [0.25, 0.75]
+    O[a_i, s_0, :] = [1.00, 0.00]
+    O[a_i, s_1, :] = [1.00, 0.00]
+    O[a_i, s_2, :] = [1.00, 0.00]
+    O[a_r, s_0, :] = [1.00, 0.00]
+    O[a_r, s_1, :] = [1.00, 0.00]
+    O[a_r, s_2, :] = [1.00, 0.00]
+
+    return POMDP(DiscretePOMDP(T, R, O, γ))
+end
+
 # MachineReplacement = generate_machine_replacement_pomdp(1.0)
 
 MACHINE_REPLACEMENT_ACTION_COLORS = Dict(
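A short usage sketch of the new constructor, loosely mirroring the updated tests (constructor and field names are assumed to be in scope once the package is loaded):

```julia
# Sketch only: assumes the package exports these names.
pomdp = POMDP(MachineReplacement())

s = 1                        # both components working
a = 2                        # manufacture and examine the product
s′, r, o = pomdp.TRO(s, a)   # per the tables above, r = 0.9025 - 0.25; o is 1 (nondefective) or 2 (defective)
```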

test/runtests.jl

Lines changed: 6 additions & 3 deletions
@@ -55,6 +55,7 @@ end
     @test p.generate_sr(m, state, action)[1] in p.ordered_states(m) && p.generate_sr(m, state, action)[2] <= 10
     @test p.generate_start_state(m) in p.ordered_states(m)
     @test p.hex_distance(rand(hexes), rand(hexes)) >= 0
+    mdp = p.DiscreteMDP(m)
 end
 @testset "simple_lqr.jl" begin
     m = p.LqrMDP()
@@ -80,7 +81,8 @@ end
 
 
 @testset "crying_baby.jl" begin
-    m = p.BabyPOMDP(-10.0, -5.0, -0.5, 0.1, 0.8, 0.1, 0.9, 0.9)
+    # m = p.BabyPOMDP(-10.0, -5.0, -0.5, 0.1, 0.8, 0.1, 0.9, 0.9)
+    m = p.BabyPOMDP()
     @test p.n_states(m) == 2 && p.ordered_states(m) == [1, 2]
     @test p.n_actions(m) == 3 && p.ordered_actions(m) == [1, 2, 3]
     @test p.n_observations(m) == 2 && p.ordered_observations(m) == [true, false]
@@ -89,12 +91,13 @@ end
     @test 0 <= p.observation(m, rand(1:3), rand(1:2)).p <= 1
     @test p.reward(m, rand(1:2), rand(1:3)) <= 0
     @test p.reward(m, [0.1, 0.9], rand(1:3)) <= 0
+    pomdp = p.POMDP(m)
 end
 
 @testset "machine_replacement.jl" begin
     # m = p.generate_machine_replacement_pomdp(1.0)
     mdp = p.MachineReplacement()
-    m = p.MachineReplacement(mdp)
+    m = p.DiscretePOMDP(mdp)
     @test p.n_states(m) == 3 && p.ordered_states(m) == 1:3
     @test p.n_actions(m) == 4 && p.ordered_actions(m) == 1:4
     @test p.n_observations(m) == 2 && p.ordered_observations(m) == 1:2
@@ -108,7 +111,7 @@ end
 @testset "catch.jl" begin
     # m = p.generate_catch_pomdp(0.9)
     mdp = p.Catch()
-    m = p.DiscreteMDP(mdp)
+    m = p.DiscretePOMDP(mdp)
     @test p.n_states(m) == 4 && p.ordered_states(m) == 1:4
     @test p.n_actions(m) == 10 && p.ordered_actions(m) == 1:10
     @test p.n_observations(m) == 2 && p.ordered_observations(m) == 1:2
