Skip to content

Commit 52c30d8

Browse files
committed
change crying_baby
1 parent 0d30ee1 commit 52c30d8

File tree

2 files changed

+9
-47
lines changed

2 files changed

+9
-47
lines changed

src/pomdp/crying_baby.jl

Lines changed: 8 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,12 @@ function DiscretePOMDP(pomdp::BabyPOMDP; γ::Float64=pomdp.γ)
121121
T[s_h, a_i, :] = [0.0, 1.0]
122122
T[s_h, a_s, :] = [0.0, 1.0]
123123

124-
R[s_s, a_f, :] = reward(pomdp, s_s, a_f)
125-
R[s_s, a_i, :] = reward(pomdp, s_s, a_i)
126-
R[s_s, a_s, :] = reward(pomdp, s_s, a_s)
127-
R[s_h, a_f, :] = reward(pomdp, s_h, a_f)
128-
R[s_h, a_i, :] = reward(pomdp, s_h, a_i)
129-
R[s_h, a_s, :] = reward(pomdp, s_h, a_s)
124+
R[s_s, a_f] = reward(pomdp, s_s, a_f)
125+
R[s_s, a_i] = reward(pomdp, s_s, a_i)
126+
R[s_s, a_s] = reward(pomdp, s_s, a_s)
127+
R[s_h, a_f] = reward(pomdp, s_h, a_f)
128+
R[s_h, a_i] = reward(pomdp, s_h, a_i)
129+
R[s_h, a_s] = reward(pomdp, s_h, a_s)
130130

131131
O[a_f, s_s, :] = [observation(pomdp, a_f, s_s).p, 1 - observation(pomdp, a_f, s_s).p]
132132
O[a_f, s_h, :] = [observation(pomdp, a_f, s_h).p, 1 - observation(pomdp, a_f, s_h).p]
@@ -139,44 +139,6 @@ function DiscretePOMDP(pomdp::BabyPOMDP; γ::Float64=pomdp.γ)
139139
end
140140

141141
function POMDP(pomdp::BabyPOMDP; γ::Float64=pomdp.γ)
142-
nS = n_states(pomdp)
143-
nA = n_actions(pomdp)
144-
nO = n_observations(pomdp)
145-
146-
T = zeros(nS, nA, nS)
147-
R = Array{Float64}(undef, nS, nA)
148-
O = Array{Float64}(undef, nA, nS, nO)
149-
150-
s_s = 1
151-
s_h = 2
152-
153-
a_f = 1
154-
a_i = 2
155-
a_s = 3
156-
157-
o_c = 1
158-
o_q = 2
159-
160-
T[s_s, a_f, :] = [1.0, 0.0]
161-
T[s_s, a_i, :] = [1.0-pomdp.p_become_hungry, pomdp.p_become_hungry]
162-
T[s_s, a_s, :] = [1.0-pomdp.p_become_hungry, pomdp.p_become_hungry]
163-
T[s_h, a_f, :] = [1.0, 0.0]
164-
T[s_h, a_i, :] = [0.0, 1.0]
165-
T[s_h, a_s, :] = [0.0, 1.0]
166-
167-
R[s_s, a_f, :] = reward(pomdp, s_s, a_f)
168-
R[s_s, a_i, :] = reward(pomdp, s_s, a_i)
169-
R[s_s, a_s, :] = reward(pomdp, s_s, a_s)
170-
R[s_h, a_f, :] = reward(pomdp, s_h, a_f)
171-
R[s_h, a_i, :] = reward(pomdp, s_h, a_i)
172-
R[s_h, a_s, :] = reward(pomdp, s_h, a_s)
173-
174-
O[a_f, s_s, :] = [observation(pomdp, a_f, s_s).p, 1 - observation(pomdp, a_f, s_s).p]
175-
O[a_f, s_h, :] = [observation(pomdp, a_f, s_h).p, 1 - observation(pomdp, a_f, s_h).p]
176-
O[a_i, s_s, :] = [observation(pomdp, a_i, s_s).p, 1 - observation(pomdp, a_i, s_s).p]
177-
O[a_i, s_h, :] = [observation(pomdp, a_i, s_h).p, 1 - observation(pomdp, a_i, s_h).p]
178-
O[a_s, s_s, :] = [observation(pomdp, a_s, s_s).p, 1 - observation(pomdp, a_s, s_s).p]
179-
O[a_s, s_h, :] = [observation(pomdp, a_s, s_h).p, 1 - observation(pomdp, a_s, s_h).p]
180-
181-
return POMDP(DiscretePOMDP(T, R, O, γ))
142+
disc_pomdp = DiscretePOMDP(pomdp)
143+
return POMDP(disc_pomdp)
182144
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ end
9191
@test 0 <= p.observation(m, rand(1:3), rand(1:2)).p <= 1
9292
@test p.reward(m, rand(1:2), rand(1:3)) <= 0
9393
@test p.reward(m, [0.1, 0.9], rand(1:3)) <= 0
94-
pomdp = p.POMDP(m)
94+
pomdp = p.DiscretePOMDP(m)
9595
end
9696

9797
@testset "machine_replacement.jl" begin

0 commit comments

Comments
 (0)