Commit f3697ac

Merge pull request #88 from Zinoex/fm/commonsolve
Use CommonSolve standardized interface
2 parents 0a12524 + 0c67cfe commit f3697ac

47 files changed: +1484 -1289 lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -1,9 +1,10 @@
 name = "IntervalMDP"
 uuid = "051c988a-e73c-45a4-90ec-875cac0402c7"
 authors = ["Frederik Baymler Mathiesen <frederik@baymler.com> and contributors"]
-version = "0.5.0"
+version = "0.6.0"
 
 [deps]
+CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
@@ -21,6 +22,7 @@ IntervalMDPCudaExt = ["Adapt", "CUDA", "GPUArrays", "LLVM"]
 [compat]
 Adapt = "4"
 CUDA = "5.1"
+CommonSolve = "0.2.4"
 GPUArrays = "10, 11"
 JSON = "0.21.4"
 LLVM = "7, 8, 9"

README.md

Lines changed: 2 additions & 2 deletions
@@ -63,10 +63,10 @@ imc = IntervalMarkovChain(prob, initial_states)
 target_set = [3]
 prop = FiniteTimeReachability(target_set, 10) # Time steps
 spec = Specification(prop, Pessimistic, Maximize)
-problem = Problem(imc, spec)
+problem = VerificationProblem(imc, spec)
 
 # Solve
-V, k, residual = value_iteration(problem)
+V, k, residual = solve(problem)
 ```
 
 See [Usage](https://www.baymler.com/IntervalMDP.jl/dev/usage/) for more information about different specifications, using sparse matrices, and CUDA.

docs/make.jl

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ makedocs(;
 "Reference" => Any[
 "Systems" => "reference/systems.md",
 "Specifications" => "reference/specifications.md",
-"Value Iteration" => "reference/value_iteration.md",
+"Solve Interface" => "reference/solve.md",
 "Data Storage" => "reference/data.md",
 ],
 "Index" => "api.md",

docs/src/reference/solve.md

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# Solve Interface
+
+```@docs
+solve
+residual
+num_iterations
+value_function
+StationaryStrategy
+TimeVaryingStrategy
+```
+
+## VI-like Algorithms
+
+```@docs
+RobustValueIteration
+```

docs/src/reference/specifications.md

Lines changed: 2 additions & 13 deletions
@@ -1,7 +1,8 @@
 # Problem
 
 ```@docs
-Problem
+VerificationProblem
+ControlSynthesisProblem
 system
 specification
 strategy
@@ -19,8 +20,6 @@ StrategyMode
 ## DFA Reachability
 
 ```@docs
-AbstractDFAReachability
-
 FiniteTimeDFAReachability
 isfinitetime(prop::FiniteTimeDFAReachability)
 terminal_states(prop::FiniteTimeDFAReachability)
@@ -37,8 +36,6 @@ convergence_eps(prop::InfiniteTimeDFAReachability)
 ## Reachability
 
 ```@docs
-AbstractReachability
-
 FiniteTimeReachability
 isfinitetime(prop::FiniteTimeReachability)
 terminal_states(prop::FiniteTimeReachability)
@@ -61,8 +58,6 @@ time_horizon(prop::ExactTimeReachability)
 ## Reach-avoid
 
 ```@docs
-AbstractReachAvoid
-
 FiniteTimeReachAvoid
 isfinitetime(prop::FiniteTimeReachAvoid)
 terminal_states(prop::FiniteTimeReachAvoid)
@@ -88,8 +83,6 @@ time_horizon(prop::ExactTimeReachAvoid)
 ## Safety
 
 ```@docs
-AbstractSafety
-
 FiniteTimeSafety
 isfinitetime(prop::FiniteTimeSafety)
 terminal_states(prop::FiniteTimeSafety)
@@ -106,8 +99,6 @@ convergence_eps(prop::InfiniteTimeSafety)
 ## Reward specification
 
 ```@docs
-AbstractReward
-
 FiniteTimeReward
 isfinitetime(prop::FiniteTimeReward)
 reward(prop::FiniteTimeReward)
@@ -124,8 +115,6 @@ convergence_eps(prop::InfiniteTimeReward)
 ## Hitting time
 
 ```@docs
-AbstractHittingTime
-
 ExpectedExitTime
 isfinitetime(prop::ExpectedExitTime)
 terminal_states(prop::ExpectedExitTime)

docs/src/reference/value_iteration.md

Lines changed: 0 additions & 21 deletions
This file was deleted.

docs/src/usage.md

Lines changed: 8 additions & 3 deletions
@@ -114,14 +114,19 @@ spec = Specification(prop, Optimistic, Maximize)
 spec = Specification(prop, Optimistic, Minimize)
 
 ## Combine system and specification in a Problem
-problem = Problem(imdp_or_imc, spec)
+problem = VerificationProblem(imdp_or_imc, spec)
 ```
 
-Finally, we call `value_iteration` to solve the specification. `value_iteration` returns the value function for all states in addition to the number of iterations performed and the last Bellman residual.
+Finally, we call [`solve`](@ref) to solve the specification. `solve` returns the value function for all states in addition to the number of iterations performed and the last Bellman residual, wrapped in a solution object.
 
 ```julia
-V, k, residual = value_iteration(problem)
+sol = solve(problem) # or solve(problem, RobustValueIteration())
+V, k, res = sol
+
+# or alternatively
+V, k, res = value_function(sol), num_iterations(sol), residual(sol)
 ```
+For now, only [`RobustValueIteration`](@ref) is supported, but more algorithms are planned.
 
 !!! note
 To use multi-threading for parallelization, you need to either start julia with `julia --threads <n|auto>` where `n` is a positive integer or to set the environment variable `JULIA_NUM_THREADS` to the number of threads you want to use. For more information, see [Multi-threading](https://docs.julialang.org/en/v1/manual/multi-threading/).
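
Note on the usage.md change: the new text says `solve` returns a solution object that can either be destructured (`V, k, res = sol`) or queried through accessors. The visible part of this commit does not show how IntervalMDP implements that object, so the following is only a minimal, self-contained Julia sketch of one way a type could support both styles. `SketchSolution` and its fields are hypothetical; the accessor names merely mirror those in the diff and are defined locally here, not imported from the package.

```julia
# Hypothetical stand-in for a solution object; NOT IntervalMDP's actual type.
struct SketchSolution{R}
    value_function::Vector{R}
    residual::Vector{R}
    num_iterations::Int
end

# Accessors named after those in the diff, defined locally for this sketch.
value_function(s::SketchSolution) = s.value_function
residual(s::SketchSolution) = s.residual
num_iterations(s::SketchSolution) = s.num_iterations

# Implementing Julia's iteration protocol lets `V, k, res = sol` work.
# The order (value function, iterations, residual) follows the usage text.
function Base.iterate(s::SketchSolution, state = 1)
    state == 1 && return (value_function(s), 2)
    state == 2 && return (num_iterations(s), 3)
    state == 3 && return (residual(s), 4)
    return nothing
end

# Made-up numbers purely for illustration.
sol = SketchSolution([0.2, 0.9, 1.0], [0.0, 0.01, 0.0], 10)

# Style 1: destructuring.
V, k, res = sol

# Style 2: accessors; both styles yield the same objects.
V2, k2, res2 = value_function(sol), num_iterations(sol), residual(sol)
```

Under this assumption, code written against the old tuple return (`V, k, residual = value_iteration(problem)`) would only need the function name changed, which matches the one-line README change in this commit.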

ext/IntervalMDPCudaExt.jl

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ IntervalMDP.arrayfactory(
 ::MR,
 T,
 num_states,
-) where {R, MR <: Union{CuSparseMatrixCSC{R}, CuMatrix{R}}} = CUDA.zeros(T, num_states)
+) where {R, MR <: Union{CuSparseMatrixCSC{R}, CuArray{R}}} = CUDA.zeros(T, num_states)
 
 include("cuda/utils.jl")
 include("cuda/array.jl")

ext/cuda/specification.jl

Lines changed: 2 additions & 1 deletion
@@ -2,4 +2,5 @@
 Adapt.@adapt_structure FiniteTimeReward
 Adapt.@adapt_structure InfiniteTimeReward
 Adapt.@adapt_structure Specification
-Adapt.@adapt_structure Problem
+Adapt.@adapt_structure VerificationProblem
+Adapt.@adapt_structure ControlSynthesisProblem

src/Data/bmdp-tool.jl

Lines changed: 7 additions & 4 deletions
@@ -110,15 +110,15 @@ function read_bmdp_tool_file(path)
 end
 
 """
-write_bmdp_tool_file(path, problem::Problem)
+write_bmdp_tool_file(path, problem::IntervalMDP.AbstractIntervalMDPProblem)
 
 Write a bmdp-tool transition probability file for the given an IMDP and a reachability specification.
 The file will not contain enough information to specify a reachability specification. The remaining
 parameters are rather command line arguments.
 
 See [Data storage formats](@ref) for more information on the file format.
 """
-write_bmdp_tool_file(path, problem::Problem) =
+write_bmdp_tool_file(path, problem::IntervalMDP.AbstractIntervalMDPProblem) =
 write_bmdp_tool_file(path, system(problem), specification(problem))
 
 """
@@ -130,8 +130,11 @@ write_bmdp_tool_file(path, mdp::IntervalMarkovProcess, spec::Specification) =
 """
 write_bmdp_tool_file(path, mdp::IntervalMarkovProcess, prop::AbstractReachability)
 """
-write_bmdp_tool_file(path, mdp::IntervalMarkovProcess, prop::AbstractReachability) =
-write_bmdp_tool_file(path, mdp, reach(prop))
+write_bmdp_tool_file(
+path,
+mdp::IntervalMarkovProcess,
+prop::IntervalMDP.AbstractReachability,
+) = write_bmdp_tool_file(path, mdp, reach(prop))
 
 """
 write_bmdp_tool_file(path, mdp::IntervalMarkovProcess, terminal_states::Vector{T})
