Skip to content

Commit ef67105

Browse files
authored
Merge pull request #18 from yangyou95/master
Add convert_s function
2 parents f742a4a + 26e69f0 commit ef67105

File tree

2 files changed

+39
-23
lines changed

2 files changed

+39
-23
lines changed

src/LaserTag.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,23 @@ export
1919
LTState,
2020
CMeas,
2121
DMeas,
22-
LaserTagVis,
23-
24-
MoveTowards,
22+
LaserTagVis, MoveTowards,
2523
MoveTowardsSampled,
2624
OptimalMLSolver,
2725
OptimalML,
2826
BestExpectedSolver,
29-
BestExpected,
30-
31-
DESPOTEmu,
32-
33-
gen_lasertag,
27+
BestExpected, DESPOTEmu, gen_lasertag,
3428
cpp_emu_lasertag,
3529
tikz_pic,
3630
n_clear_cells
3731

3832

39-
const Coord = SVector{2, Int}
40-
const CMeas = MVector{8, Float64}
41-
const DMeas = MVector{8, Int}
33+
const Coord = SVector{2,Int}
34+
const CMeas = MVector{8,Float64}
35+
const DMeas = MVector{8,Int}
4236

43-
const C_SAME_LOC = fill!(MVector{8, Float64}(undef), -1.0)
44-
const D_SAME_LOC = fill!(MVector{8, Int64}(undef), -1)
37+
const C_SAME_LOC = fill!(MVector{8,Float64}(undef), -1.0)
38+
const D_SAME_LOC = fill!(MVector{8,Int64}(undef), -1)
4539

4640
@auto_hash_equals struct LTState # XXX auto_hash_equals isn't correct for terminal
4741
robot::Coord
@@ -71,16 +65,16 @@ obs_type(om::ObsModel) = obs_type(typeof(om))
7165

7266
include("distance_cache.jl")
7367

74-
@with_kw struct LaserTagPOMDP{M<:ObsModel, O<:Union{CMeas, DMeas}} <: POMDP{LTState, Int, O}
75-
tag_reward::Float64 = 10.0
76-
step_cost::Float64 = 1.0
77-
discount::Float64 = 0.95
78-
floor::Floor = Floor(7, 11)
79-
obstacles::Set{Coord} = Set{Coord}()
80-
robot_init::Union{Coord, Nothing} = nothing
81-
diag_actions::Bool = false
82-
dcache::LTDistanceCache = LTDistanceCache(floor, obstacles)
83-
obs_model::M = DESPOTEmu(floor, 2.5)
68+
@with_kw struct LaserTagPOMDP{M<:ObsModel,O<:Union{CMeas,DMeas}} <: POMDP{LTState,Int,O}
69+
tag_reward::Float64 = 10.0
70+
step_cost::Float64 = 1.0
71+
discount::Float64 = 0.95
72+
floor::Floor = Floor(7, 11)
73+
obstacles::Set{Coord} = Set{Coord}()
74+
robot_init::Union{Coord,Nothing} = nothing
75+
diag_actions::Bool = false
76+
dcache::LTDistanceCache = LTDistanceCache(floor, obstacles)
77+
obs_model::M = DESPOTEmu(floor, 2.5)
8478
end
8579

8680
ltfloor(m::LaserTagPOMDP) = m.floor
@@ -142,6 +136,16 @@ function POMDPs.reward(p::LaserTagPOMDP, s::LTState, a::Int, sp::LTState)
142136
end
143137
end
144138

139+
function POMDPs.convert_s(T::Type{<:AbstractArray}, s::LTState, p::LaserTagPOMDP)
140+
vals = SVector{5, Float64}(s.robot..., s.opponent..., s.terminal)
141+
return convert(T, vals)
142+
end
143+
144+
function POMDPs.convert_s(T::Type{LTState}, v::AbstractArray, p::LaserTagPOMDP)
145+
return LTState(Coord(v[1], v[2]), Coord(v[3], v[4]), v[5])
146+
end
147+
148+
145149
POMDPs.isterminal(p::LaserTagPOMDP, s::LTState) = s.terminal
146150
POMDPs.discount(p::LaserTagPOMDP) = p.discount
147151

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using POMDPTools
66
using ParticleFilters
77
using POMDPs
88
using StableRNGs
9+
using StaticArrays
910

1011
@time p = gen_lasertag()
1112

@@ -59,6 +60,17 @@ for dir in 1:8
5960
@test total == N
6061
end
6162

63+
@testset "convert_s" begin
64+
s_test = rand(rng, initialstate(p))
65+
66+
for VT in [Vector{Float64}, SVector]
67+
v_s_test = convert_s(VT, s_test, p)
68+
@test v_s_test isa VT
69+
s_back = convert_s(LTState, v_s_test, p)
70+
@test s_back == s_test
71+
end
72+
end
73+
6274
pol = RandomPolicy(p, rng=StableRNG(1))
6375

6476
sim = HistoryRecorder(max_steps=10, rng=StableRNG(2))

0 commit comments

Comments
 (0)