Skip to content

Commit 4da8598

Browse files
authored
Merge pull request #95 from Zinoex:fm/terminal_dfa
Skip terminal states of DFA properties in bellman!
2 parents 08edf7a + d106e0c commit 4da8598

File tree

5 files changed

+50
-10
lines changed

5 files changed

+50
-10
lines changed

src/bellman.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ function bellman(
7878
alg::BellmanAlgorithm = default_bellman_algorithm(model);
7979
upper_bound = false,
8080
maximize = true,
81+
prop = nothing,
8182
)
8283
Vres = similar(V, source_shape(model))
8384

@@ -89,6 +90,7 @@ function bellman(
8990
alg;
9091
upper_bound = upper_bound,
9192
maximize = maximize,
93+
prop = prop,
9294
)
9395
end
9496

@@ -179,6 +181,7 @@ function bellman!(
179181
alg::BellmanAlgorithm = default_bellman_algorithm(model);
180182
upper_bound = false,
181183
maximize = true,
184+
prop = nothing,
182185
)
183186
workspace = construct_workspace(model, alg)
184187
strategy_cache = construct_strategy_cache(model)
@@ -192,6 +195,7 @@ function bellman!(
192195
avail_act;
193196
upper_bound = upper_bound,
194197
maximize = maximize,
198+
prop = prop,
195199
)
196200
end
197201

@@ -204,6 +208,7 @@ function bellman!(
204208
avail_act::AbstractAvailableActions = available_actions(model);
205209
upper_bound = false,
206210
maximize = true,
211+
prop = nothing,
207212
)
208213
return _bellman_helper!(
209214
workspace,
@@ -226,6 +231,7 @@ function bellman!(
226231
avail_act::AbstractAvailableActions = available_actions(model);
227232
upper_bound = false,
228233
maximize = true,
234+
prop = nothing,
229235
)
230236
mp = markov_process(model)
231237
lf = labelling_function(model)
@@ -242,6 +248,7 @@ function bellman!(
242248
avail_act;
243249
upper_bound = upper_bound,
244250
maximize = maximize,
251+
prop = prop,
245252
)
246253
end
247254

@@ -256,10 +263,16 @@ function _bellman_helper!(
256263
avail_act;
257264
upper_bound = false,
258265
maximize = true,
266+
prop = nothing,
259267
)
260268
W = workspace.intermediate_values
261269

262270
@inbounds for state in dfa
271+
# If a DFA property is given, skip terminal states
272+
if !isnothing(prop) && state ∈ terminal(prop)
273+
continue
274+
end
275+
263276
local_strategy_cache = localize_strategy_cache(strategy_cache, state)
264277

265278
# Select the value function for the current DFA state
@@ -296,10 +309,16 @@ function _bellman_helper!(
296309
avail_act;
297310
upper_bound = false,
298311
maximize = true,
312+
prop = nothing,
299313
) where {R}
300314
W = workspace.intermediate_values
301315

302316
@inbounds for state in dfa
317+
# If a DFA property is given, skip terminal states
318+
if !isnothing(prop) && state ∈ terminal(prop)
319+
continue
320+
end
321+
303322
local_strategy_cache = localize_strategy_cache(strategy_cache, state)
304323

305324
# Select the value function for the current DFA state

src/problem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct VerificationProblem{
3131
strategy::C,
3232
) where {S <: StochasticProcess, F <: Specification, C <: AbstractStrategy}
3333
checkspecification(spec, system, strategy)
34-
checkstrategy(strategy, system)
34+
checkstrategy(strategy, system, system_property(spec))
3535
return new{S, F, C}(system, spec, strategy)
3636
end
3737
end

src/robust_value_iteration.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ function _value_iteration!(problem::AbstractIntervalMDPProblem, alg; callback =
190190
mp;
191191
upper_bound = upper_bound,
192192
maximize = maximize,
193+
prop = system_property(spec),
193194
)
194195
step_postprocess_value_function!(value_function, spec)
195196
step_postprocess_strategy_cache!(strategy_cache)
@@ -210,6 +211,7 @@ function _value_iteration!(problem::AbstractIntervalMDPProblem, alg; callback =
210211
mp;
211212
upper_bound = upper_bound,
212213
maximize = maximize,
214+
prop = system_property(spec),
213215
)
214216
step_postprocess_value_function!(value_function, spec)
215217
step_postprocess_strategy_cache!(strategy_cache)
@@ -255,7 +257,16 @@ function nextiteration!(V)
255257
return V
256258
end
257259

258-
function step!(workspace, strategy_cache, value_function, k, mp; upper_bound, maximize)
260+
function step!(
261+
workspace,
262+
strategy_cache,
263+
value_function,
264+
k,
265+
mp;
266+
upper_bound,
267+
maximize,
268+
prop,
269+
)
259270
bellman!(
260271
workspace,
261272
select_strategy_cache(strategy_cache, k),
@@ -265,6 +276,7 @@ function step!(workspace, strategy_cache, value_function, k, mp; upper_bound, ma
265276
select_available_actions(available_actions(mp), k);
266277
upper_bound = upper_bound,
267278
maximize = maximize,
279+
prop = prop,
268280
)
269281
end
270282

src/specification.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ time_horizon(prop::FiniteTimeDFAReachability) = prop.time_horizon
144144
Return the set of DFA states with respect to which to compute reachability for a finite time DFA reachability property.
145145
"""
146146
reach(prop::FiniteTimeDFAReachability) = prop.reach
147+
terminal(prop::FiniteTimeDFAReachability) = reach(prop)
147148

148149
function showproperty(io::IO, first_prefix, prefix, prop::FiniteTimeDFAReachability)
149150
println(io, first_prefix, styled"{code:FiniteTimeDFAReachability}")
@@ -193,6 +194,7 @@ convergence_eps(prop::InfiniteTimeDFAReachability) = prop.convergence_eps
193194
Return the set of DFA states with respect to which to compute reachability for an infinite time DFA reachability property.
194195
"""
195196
reach(prop::InfiniteTimeDFAReachability) = prop.reach
197+
terminal(prop::InfiniteTimeDFAReachability) = reach(prop)
196198

197199
function showproperty(io::IO, first_prefix, prefix, prop::InfiniteTimeDFAReachability)
198200
println(io, first_prefix, styled"{code:InfiniteTimeDFAReachability}")
@@ -277,6 +279,7 @@ time_horizon(prop::FiniteTimeDFASafety) = prop.time_horizon
277279
Return the set of DFA states with respect to which to compute safety for a finite time DFA safety property.
278280
"""
279281
avoid(prop::FiniteTimeDFASafety) = prop.avoid
282+
terminal(prop::FiniteTimeDFASafety) = avoid(prop)
280283

281284
function showproperty(io::IO, first_prefix, prefix, prop::FiniteTimeDFASafety)
282285
println(io, first_prefix, styled"{code:FiniteTimeDFASafety}")
@@ -325,6 +328,7 @@ convergence_eps(prop::InfiniteTimeDFASafety) = prop.convergence_eps
325328
Return the set of DFA states with respect to which to compute safety for an infinite time DFA safety property.
326329
"""
327330
avoid(prop::InfiniteTimeDFASafety) = prop.avoid
331+
terminal(prop::InfiniteTimeDFASafety) = avoid(prop)
328332

329333
function showproperty(io::IO, first_prefix, prefix, prop::InfiniteTimeDFASafety)
330334
println(io, first_prefix, styled"{code:InfiniteTimeDFASafety}")

src/strategy.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
abstract type AbstractStrategy end
22

33
struct NoStrategy <: AbstractStrategy end
4-
checkstrategy(::NoStrategy, system) = nothing
4+
checkstrategy(::NoStrategy, system, prop = nothing) = nothing
55

66
"""
77
StationaryStrategy
@@ -14,20 +14,25 @@ end
1414
Base.getindex(strategy::StationaryStrategy, k) = strategy.strategy
1515
time_length(::StationaryStrategy) = typemax(Int64)
1616

17-
function checkstrategy(strategy::StationaryStrategy, system)
18-
checkstrategy(strategy.strategy, system)
17+
function checkstrategy(strategy::StationaryStrategy, system, prop = nothing)
18+
checkstrategy(strategy.strategy, system, prop)
1919
end
2020

21-
function checkstrategy(strategy::AbstractArray, system::ProductProcess)
21+
function checkstrategy(strategy::AbstractArray, system::ProductProcess, prop = nothing)
2222
mp = markov_process(system)
2323
dfa = automaton(system)
2424

2525
for state in dfa
26-
checkstrategy(selectdim(strategy, ndims(strategy), state), mp)
26+
# If a DFA property is given, skip terminal states
27+
if !isnothing(prop) && state ∈ terminal(prop)
28+
continue
29+
end
30+
31+
checkstrategy(selectdim(strategy, ndims(strategy), state), mp, prop)
2732
end
2833
end
2934

30-
function checkstrategy(strategy::AbstractArray, system::FactoredRMDP)
35+
function checkstrategy(strategy::AbstractArray, system::FactoredRMDP, prop = nothing)
3136
if size(strategy) != source_shape(system)
3237
throw(
3338
DimensionMismatch(
@@ -77,9 +82,9 @@ end
7782
Base.getindex(strategy::TimeVaryingStrategy, k) = strategy.strategy[k]
7883
time_length(strategy::TimeVaryingStrategy) = length(strategy.strategy)
7984

80-
function checkstrategy(strategy::TimeVaryingStrategy, system)
85+
function checkstrategy(strategy::TimeVaryingStrategy, system, prop = nothing)
8186
for strategy_step in strategy.strategy
82-
checkstrategy(strategy_step, system)
87+
checkstrategy(strategy_step, system, prop)
8388
end
8489
end
8590

0 commit comments

Comments
 (0)