Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/IntervalMDP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ export OMaximization, LPMcCormickRelaxation, VertexEnumeration
export RobustValueIteration
export default_algorithm, default_bellman_algorithm, bellman_algorithm

include("value.jl")
include("utils.jl")
include("threading.jl")
include("workspace.jl")
include("strategy_cache.jl")
include("bellman.jl")
include("value.jl")
include("state_sampling.jl")

include("robust_value_iteration.jl")
Expand Down
19 changes: 18 additions & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,25 @@ bellman_algorithm(alg::RobustValueIteration) = alg.bellman_alg
############################

# TODO: Provide implementation for this algorithm. When provided, consider changing the default algorithm.
struct IntervalValueIteration <: ModelCheckingAlgorithm end
# Algorithm selector for interval value iteration.
# It only carries the Bellman algorithm; `bellman_algorithm` exposes it.
struct IntervalValueIteration{B <: BellmanAlgorithm} <: ModelCheckingAlgorithm
    bellman_alg::B
end

function bellman_algorithm(alg::IntervalValueIteration)
    return alg.bellman_alg
end

# Algorithm selector for bounded real-time dynamic programming.
# It only carries the Bellman algorithm; `bellman_algorithm` exposes it.
struct BoundedRealTimeDynamicProgramming{B <: BellmanAlgorithm} <: ModelCheckingAlgorithm
    bellman_alg::B
end

function bellman_algorithm(alg::BoundedRealTimeDynamicProgramming)
    return alg.bellman_alg
end


# Algorithm selector for generalized sampling-based robust dynamic programming.
# It only carries the Bellman algorithm; `bellman_algorithm` exposes it.
struct GeneralizedSamplingbasedRobustDynamicProgramming{B <: BellmanAlgorithm} <: ModelCheckingAlgorithm
    bellman_alg::B
end

function bellman_algorithm(alg::GeneralizedSamplingbasedRobustDynamicProgramming)
    return alg.bellman_alg
end

###############################
# Topological Value Iteration #
###############################
# TODO: Consider topological value iteration as an alternative algorithm (infinite time only).

##### Default algorithm for solving Interval MDP problems
Expand Down
26 changes: 15 additions & 11 deletions src/robust_value_iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ end
termination_criteria(prop, finitetime::Val{false}) =
CovergenceCriteria(convergence_eps(prop))

# Forward to the bound-specific initializer, selected by whether the value
# function is flagged as an upper bound (`isupper`).
function initialize!(value_function::ValueFunction, prop::AbstractReachability)
    if isupper(value_function)
        return initialize!(value_function, prop, Val(true))
    else
        return initialize!(value_function, prop, Val(false))
    end
end

# Robust value iteration adds no criteria of its own; it defers to the
# specification-only method.
function termination_criteria(::RobustValueIteration, spec)
    return termination_criteria(spec)
end

"""
solve(problem::AbstractIntervalMDPProblem, alg::RobustValueIteration; callback=nothing)

Expand Down Expand Up @@ -167,21 +173,22 @@ function solve(problem::ControlSynthesisProblem, alg::RobustValueIteration; kwar
return ControlSynthesisSolution(strategy, V, res, k)
end

function _value_iteration!(problem::AbstractIntervalMDPProblem, alg; callback = nothing)
function _value_iteration!(problem::AbstractIntervalMDPProblem, alg::RobustValueIteration; callback = nothing)
mp = system(problem)
spec = specification(problem)
term_criteria = termination_criteria(spec)
term_criteria = termination_criteria(alg, spec)

# It is more efficient to allocate once up front and reuse across iterations
workspace = construct_workspace(mp, bellman_algorithm(alg))
strategy_cache = construct_strategy_cache(problem)
sampling_strat = sampling_strategy(alg)

value_function = StateValueFunction(problem)
value_function = construct_value_function(alg, problem)
initialize!(value_function, spec)
nextiteration!(value_function)

bellman_update!(workspace, strategy_cache, sampling_strat, value_function, 0, mp, spec)
update_sequence = sample(sampling_strat, mp, select_strategy_cache(strategy_cache, 0))
bellman_update!(workspace, strategy_cache, update_sequence, value_function, 0, mp, spec)
k = 1

if !isnothing(callback)
Expand All @@ -191,7 +198,8 @@ function _value_iteration!(problem::AbstractIntervalMDPProblem, alg; callback =
while !term_criteria(value_function.current, k, lastdiff!(value_function))
nextiteration!(value_function)

bellman_update!(workspace, strategy_cache, sampling_strat, value_function, k, mp, spec)
update_sequence = sample(sampling_strat, mp, select_strategy_cache(strategy_cache, k))
bellman_update!(workspace, strategy_cache, update_sequence, value_function, k, mp, spec)
k += 1

if !isnothing(callback)
Expand All @@ -206,9 +214,7 @@ function _value_iteration!(problem::AbstractIntervalMDPProblem, alg; callback =
return value_function.current, k, value_function.previous, strategy_cache
end

function bellman_update!(workspace, strategy_cache, sampling_strat::SamplingStrategy, value_function::StateValueFunction, k, mp, spec)

update_sequence = sample(sampling_strat, mp, select_strategy_cache(strategy_cache, k))
function bellman_update!(workspace, strategy_cache, update_sequence, value_function::StateValueFunction, k, mp, spec)

# 1. compute expectation for Q(s, a)
expectation!(
Expand Down Expand Up @@ -237,9 +243,7 @@ function bellman_update!(workspace, strategy_cache, sampling_strat::SamplingStra
step_postprocess_strategy_cache!(strategy_cache)
end

function bellman_update!(workspace, strategy_cache::NonOptimizingStrategyCache, sampling_strat::SamplingStrategy, value_function::StateValueFunction, k, mp, spec)

update_sequence = sample(sampling_strat, mp, select_strategy_cache(strategy_cache, k))
function bellman_update!(workspace, strategy_cache::NonOptimizingStrategyCache, update_sequence, value_function::StateValueFunction, k, mp, spec)

expectation!(
workspace,
Expand Down
16 changes: 16 additions & 0 deletions src/specification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,15 @@ function initialize!(value_function, prop::AbstractReachability)
@inbounds value_function.current[reach(prop)] .= 1.0
end

# Lower-bound initialization: only the reach states are written (set to 1.0);
# every other entry of `current` is left untouched.
function initialize!(value_function, prop::AbstractReachability, upper::Val{false})
    V = value_function.current
    @inbounds V[reach(prop)] .= 1.0
end

# Upper-bound initialization: every entry of `current` starts at 1.0, and the
# reach states are then (re)written to 1.0 as well.
function initialize!(value_function, prop::AbstractReachability, upper::Val{true})
    V = value_function.current
    V .= 1.0
    @inbounds V[reach(prop)] .= 1.0
end

# Re-pin the reach states of `current` to 1.0.
function step_postprocess_value_function!(value_function, prop::AbstractReachability)
    V = value_function.current
    @inbounds V[reach(prop)] .= 1.0
end
Expand Down Expand Up @@ -566,6 +575,13 @@ function checkdisjoint(reach, avoid)
end
end


# Upper-bound initialization for reach-avoid properties: every entry of
# `current` starts at 1.0, reach states are written to 1.0, and avoid states
# are written to 0.0.
function initialize!(value_function, prop::AbstractReachAvoid, upper::Val{true})
    V = value_function.current
    V .= 1.0
    @inbounds V[reach(prop)] .= 1.0
    # Previously -1.0, which lies below the 0.0 / 1.0 values every other
    # initializer writes. NOTE(review): assumes the values are satisfaction
    # probabilities, so the tightest bound for an avoid state is 0.0 — confirm.
    @inbounds V[avoid(prop)] .= 0.0
end

"""
FiniteTimeReachAvoid{VT <: Vector{Union{<:Integer, <:Tuple, <:CartesianIndex}}}, T <: Integer}

Expand Down
85 changes: 82 additions & 3 deletions src/value.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# Which side of an interval a value function represents.
@enum IntervalMode begin
    Upper
    Lower
end

# Predicates on the mode itself.
isupper(mode::IntervalMode) = mode === Upper
islower(mode::IntervalMode) = mode === Lower

# Common supertype of all value-function containers.
abstract type ValueFunction end

# Per-state value storage.
#   previous / current               : two same-typed value arrays
#   intermediate_state_action_value  : second buffer, built with shape
#                                      (actions..., states...) by the constructors
#   interval                         : whether this is a lower or upper bound
struct StateValueFunction{R, A1 <: AbstractArray{R}, A2 <: AbstractArray{R}} <: ValueFunction
    previous::A1
    current::A1
    intermediate_state_action_value::A2
    interval::IntervalMode
end

function StateValueFunction(problem::AbstractIntervalMDPProblem)
Expand All @@ -19,7 +24,23 @@ function StateValueFunction(problem::AbstractIntervalMDPProblem)
intermediate_state_action_value = arrayfactory(mp, valuetype(mp), dim)
intermediate_state_action_value .= zero(valuetype(mp))

return StateValueFunction(previous, current, intermediate_state_action_value)
return StateValueFunction(previous, current, intermediate_state_action_value, Lower)
end

# Build a zero-initialized state value function for `problem`, tagged with the
# given bound `mode`.
function StateValueFunction(problem::AbstractIntervalMDPProblem, mode::IntervalMode)
    mp = system(problem)
    T = valuetype(mp)

    previous = arrayfactory(mp, T, state_values(mp))
    previous .= zero(T)
    current = copy(previous)

    # Concatenating gives shape (a1, a2), (s1, s2) => (a1, a2, s1, s2).
    # Actions come first so they are the fast index under column-major layout.
    # TODO: works for IMDP, need to check for fIMDP
    qdims = (action_values(mp)..., state_values(mp)...)
    state_action_buffer = arrayfactory(mp, T, qdims)
    state_action_buffer .= zero(T)

    return StateValueFunction(previous, current, state_action_buffer, mode)
end

function lastdiff!(V::StateValueFunction{R}) where {R}
Expand All @@ -36,11 +57,15 @@ function nextiteration!(V::StateValueFunction)
return V
end

# Bound-mode predicates, read from the stored `interval` field.
function islower(V::StateValueFunction)
    return islower(V.interval)
end

function isupper(V::StateValueFunction)
    return isupper(V.interval)
end


# Per-(action, state) value storage.
#   previous / current        : two same-typed value arrays
#   intermediate_state_value  : second buffer, built with the state shape by
#                               the constructors
#   interval                  : whether this is a lower or upper bound
struct StateActionValueFunction{R, A1 <: AbstractArray{R}, A2 <: AbstractArray{R}} <: ValueFunction
    previous::A1
    current::A1
    intermediate_state_value::A2
    interval::IntervalMode
end

function StateActionValueFunction(problem::AbstractIntervalMDPProblem)
Expand All @@ -54,7 +79,21 @@ function StateActionValueFunction(problem::AbstractIntervalMDPProblem)
intermediate_state_value = arrayfactory(mp, valuetype(mp), state_values(mp))
intermediate_state_value .= zero(valuetype(mp))

return StateActionValueFunction(previous, current, intermediate_state_value)
return StateActionValueFunction(previous, current, intermediate_state_value, Lower)
end

# Build a zero-initialized state-action value function for `problem`, tagged
# with the given bound `mode`.
function StateActionValueFunction(problem::AbstractIntervalMDPProblem, mode::IntervalMode)
    mp = system(problem)
    T = valuetype(mp)

    # Shape is (actions..., states...).
    # TODO: works for IMDP, need to check for fIMDP
    qdims = (action_values(mp)..., state_values(mp)...)
    previous = arrayfactory(mp, T, qdims)
    previous .= zero(T)
    current = copy(previous)

    state_buffer = arrayfactory(mp, T, state_values(mp))
    state_buffer .= zero(T)

    return StateActionValueFunction(previous, current, state_buffer, mode)
end


Expand All @@ -70,4 +109,44 @@ function nextiteration!(V::StateActionValueFunction)
copy!(V.previous, V.current)

return V
end
end

# Bound-mode predicates, read from the stored `interval` field.
function islower(V::StateActionValueFunction)
    return islower(V.interval)
end

function isupper(V::StateActionValueFunction)
    return isupper(V.interval)
end


# A pair of value functions of the same concrete type `V`, one per bound.
# Only the default positional constructor `IntervalValueFunction(lower, upper)`
# is defined here.
struct IntervalValueFunction{V <: ValueFunction} <: ValueFunction
    lower::V
    upper::V
end

# Field accessors for the two bounds.
function lower(V::IntervalValueFunction)
    return V.lower
end

function upper(V::IntervalValueFunction)
    return V.upper
end

# Apply `lastdiff!` to the lower bound, then the upper bound, and return both
# results as a `(lower, upper)` tuple.
function lastdiff!(V::IntervalValueFunction)
    lower_diff = lastdiff!(V.lower)
    upper_diff = lastdiff!(V.upper)
    return (lower_diff, upper_diff)
end

# Advance both bounds (lower first, then upper) and return the pair itself.
function nextiteration!(V::IntervalValueFunction)
    for bound in (V.lower, V.upper)
        nextiteration!(bound)
    end

    return V
end

# Element-wise absolute difference between the current lower and upper values.
# Allocates and returns a new array.
function gap(V::IntervalValueFunction)
    lo = V.lower.current
    hi = V.upper.current
    return abs.(lo .- hi)
end

# Initialize each bound with its own initializer: `Val(false)` for the lower
# bound, `Val(true)` for the upper bound.
function initialize!(value_function::IntervalValueFunction, prop::AbstractReachability)
    lo = value_function.lower
    hi = value_function.upper
    initialize!(lo, prop, Val(false))
    initialize!(hi, prop, Val(true))
end


#################
# Algorithms #
#################
# Select the value-function container for each algorithm.

# Robust value iteration tracks a single state value function (the
# one-argument constructor tags it as `Lower`).
construct_value_function(::RobustValueIteration, problem) = StateValueFunction(problem)

# Interval-based algorithms track a (lower, upper) pair.
# `IntervalValueFunction` is a plain struct, so it is built positionally; a
# keyword call such as `IntervalValueFunction(lower=..., upper=...)` has no
# matching method. Each bound gets its mode explicitly, so the upper bound
# reports `isupper == true` instead of inheriting the `Lower` default.
function construct_value_function(::IntervalValueIteration, problem)
    return IntervalValueFunction(
        StateValueFunction(problem, Lower),
        StateValueFunction(problem, Upper),
    )
end

function construct_value_function(
    ::GeneralizedSamplingbasedRobustDynamicProgramming,
    problem,
)
    return IntervalValueFunction(
        StateActionValueFunction(problem, Lower),
        StateActionValueFunction(problem, Upper),
    )
end
Loading