-
Notifications
You must be signed in to change notification settings - Fork 3
Probabilistic Labelling for ProductProcesses #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 9 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
9042914
WIP: ProbabilisticLabelling
JeffersonYeh 7bb56b1
adjust ProductProcess to generalize to probabilistic label mapping ma…
JeffersonYeh 7f36b6e
WIP: bellman for ProductProcess with Probabilistic Labelling
JeffersonYeh c5a64ab
WIP
Zinoex 26661b9
WIP
Zinoex 45f5cf5
restructure Labelling classes and ran formatter
JeffersonYeh 9821670
Fix labelling tests
JeffersonYeh 792fb2a
Merge branch 'stochastic-labelling' of github.com:Zinoex/IntervalMDP.…
Zinoex 128f01a
Merge branch 'fm/fimdp' into stochastic-labelling
Zinoex 6921802
Merge branch 'fm/fimdp' into stochastic-labelling
Zinoex 264b2d1
Merge branch 'fm/fimdp' into stochastic-labelling
Zinoex 48cb07b
minor changes from review
JeffersonYeh a3f1bae
Add tests for ProbabilisticLabelling and minor updates to other label…
JeffersonYeh 62d8800
Update docs for DeterministicLabelling
JeffersonYeh 57a9f39
Merge branch 'fm/fimdp' into stochastic-labelling
JeffersonYeh 6ff4c0b
Add Probabilistically Labelled prodIMDP tests for bellman (over all s…
JeffersonYeh fb97189
Improve strategy cache initialization in product process bellman tests
JeffersonYeh e45b292
Merge branch 'fm/fimdp' into stochastic-labelling
Zinoex 28fe368
Fix show test for ProductProcess
Zinoex 3617936
Add product process bellman tests with optimal and nonoptimal strateg…
JeffersonYeh a0db4b2
Add probabilistic labelling docs
Zinoex da06562
Merge branch 'main' into stochastic-labelling
Zinoex 11bfb4c
Format workspace
Zinoex File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| """ | ||
| struct DeterministicLabelling{ | ||
| T <: Integer, | ||
| AT <: AbstractArray{T} | ||
| } | ||
|
|
||
| A type representing the labelling of IMDP states into DFA inputs. | ||
|
|
||
| Formally, let ``L : S \\to 2^{AP}`` be a labelling function, where | ||
| - ``S`` is the set of IMDP states, and | ||
| - ``2^{AP}`` is the power set of atomic propositions | ||
|
|
||
| Then the ```DeterministicLabelling``` type is defined as vector which stores the mapping. | ||
|
|
||
| ### Fields | ||
| - `map::AT`: mapping function where indices are (factored) IMDP states and stored values are DFA inputs. | ||
| - `num_outputs::Int32`: number of labels accounted for in mapping. | ||
|
|
||
| """ | ||
| struct DeterministicLabelling{T <: Integer, AT <: AbstractArray{T}} <: AbstractLabelling | ||
| map::AT | ||
| num_outputs::Int32 | ||
|
|
||
| function DeterministicLabelling(map::AT) where {T <: Integer, AT <: AbstractArray{T}} | ||
| num_outputs = checklabelling(map) | ||
|
|
||
| return new{T, AT}(map, Int32(num_outputs)) | ||
| end | ||
| end | ||
|
|
||
| function checklabelling(map::AbstractArray{<:Integer}) | ||
| labels = unique(map) | ||
|
|
||
| if any(labels .< 1) | ||
| throw(ArgumentError("Labelled state index cannot be less than 1")) | ||
| end | ||
|
|
||
| # Check that labels are consecutive integers | ||
| sort!(labels) | ||
| if any(diff(labels) .!= 1) | ||
| throw(ArgumentError("Labelled state indices must be consecutive integers")) | ||
| end | ||
|
|
||
| return last(labels) | ||
| end | ||
|
|
||
| """ | ||
| mapping(dl::DeterministicLabelling) | ||
|
|
||
| Return the mapping array of the labelling function. | ||
| """ | ||
| mapping(dl::DeterministicLabelling) = dl.map | ||
|
|
||
| """ | ||
| size(dl::DeterministicLabelling) | ||
|
|
||
| Returns the shape of the input range of the labeling function ``L : S \\to 2^{AP}``, which can be multiple dimensions in case of factored IMDPs. | ||
| """ | ||
| Base.size(dl::DeterministicLabelling) = size(dl.map) | ||
|
|
||
| """ | ||
| num_labels(dl::DeterministicLabelling) | ||
| Return the number of labels (DFA inputs) in the labelling function. | ||
| """ | ||
| num_labels(dl::DeterministicLabelling) = dl.num_outputs | ||
|
|
||
| """ | ||
| state_values(dl::DeterministicLabelling) | ||
| Return a tuple with the number of states for each state variable of the labeling function ``L : S \\to 2^{AP}``, which can be multiple dimensions in case of factored IMDPs. | ||
| """ | ||
| state_values(dl::DeterministicLabelling) = size(dl.map) | ||
|
|
||
| """ | ||
| num_states(dl::DeterministicLabelling) | ||
| Return the number of states of the labeling function ``L : S \\to 2^{AP}`` | ||
| """ | ||
| num_states(dl::DeterministicLabelling) = prod(state_values(dl)) | ||
|
|
||
| """ | ||
| getindex(dl::DeterministicLabelling, s...) | ||
|
|
||
| Return the label for state s. | ||
| """ | ||
| Base.getindex(dl::DeterministicLabelling, s...) = dl.map[s...] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,74 +1,6 @@ | ||
| abstract type AbstractLabelling end | ||
|
|
||
| """ | ||
| struct LabellingFunction{ | ||
| T <: Integer, | ||
| AT <: AbstractArray{T} | ||
| } | ||
|
|
||
| A type representing the labelling of IMDP states into DFA inputs. | ||
|
|
||
| Formally, let ``L : S \\to 2^{AP}`` be a labelling function, where | ||
| - ``S`` is the set of IMDP states, and | ||
| - ``2^{AP}`` is the power set of atomic propositions | ||
|
|
||
| Then the ```LabellingFunction``` type is defined as vector which stores the mapping. | ||
|
|
||
| ### Fields | ||
| - `map::AT`: mapping function where indices are (factored) IMDP states and stored values are DFA inputs. | ||
| - `num_inputs::Int32`: number of IMDP states accounted for in mapping. | ||
|
|
||
| """ | ||
| struct LabellingFunction{T <: Integer, AT <: AbstractArray{T}} <: AbstractLabelling | ||
| map::AT | ||
| num_outputs::Int32 | ||
| end | ||
|
|
||
| function LabellingFunction(map::AT) where {T <: Integer, AT <: AbstractArray{T}} | ||
| num_outputs = checklabelling(map) | ||
|
|
||
| return LabellingFunction(map, Int32(num_outputs)) | ||
| end | ||
|
|
||
| function checklabelling(map::AbstractArray{<:Integer}) | ||
| labels = unique(map) | ||
|
|
||
| if any(labels .< 1) | ||
| throw(ArgumentError("Labelled state index cannot be less than 1")) | ||
| end | ||
|
|
||
| # Check that labels are consecutive integers | ||
| sort!(labels) | ||
| if any(diff(labels) .!= 1) | ||
| throw(ArgumentError("Labelled state indices must be consecutive integers")) | ||
| end | ||
|
|
||
| return last(labels) | ||
| end | ||
|
|
||
| """ | ||
| mapping(labelling_func::LabellingFunction) | ||
|
|
||
| Return the mapping array of the labelling function. | ||
| """ | ||
| mapping(labelling_func::LabellingFunction) = labelling_func.map | ||
|
|
||
| """ | ||
| size(labelling_func::LabellingFunction) | ||
|
|
||
| Returns the shape of the input range of the labeling function ``L : S \\to 2^{AP}``, which can be multiple dimensions in case of factored IMDPs. | ||
| """ | ||
| Base.size(labelling_func::LabellingFunction) = size(labelling_func.map) | ||
|
|
||
| """ | ||
| num_labels(labelling_func::LabellingFunction) | ||
| Return the number of labels (DFA inputs) in the labelling function. | ||
| """ | ||
| num_labels(labelling_func::LabellingFunction) = labelling_func.num_outputs | ||
| AbstractLabelling | ||
|
|
||
| An abstract type for labelling functions. | ||
| """ | ||
| getindex(lf::LabellingFunction, s...) | ||
|
|
||
| Return the label for state s. | ||
| """ | ||
| Base.getindex(lf::LabellingFunction, s...) = lf.map[s...] | ||
| abstract type AbstractLabelling end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| """ | ||
| struct ProbabilisticLabelling{ | ||
| R <: Real, | ||
| MR <: AbstractMatrix{R} | ||
| } | ||
|
|
||
| A type representing the Probabilistic labelling of IMDP states into DFA inputs. Each labelling is assigned a probability. | ||
|
|
||
| Formally, let ``L : S \\times 2^{AP} \\to [0, 1]`` be a labelling function, where | ||
| - ``S`` is the set of IMDP states, and | ||
| - ``2^{AP}`` is the power set of atomic propositions | ||
|
|
||
| Then the ```ProbabilisticLabelling``` type is defined as matrix which stores the mapping. | ||
|
|
||
| ### Fields | ||
| - `map::MT`: mapping function encoded as matrix with labels on the rows, IMDP states on the columns, and valid probability values for the destination. | ||
|
|
||
| The choice to have labels on the rows is due to the column-major storage of matrices in Julia and the fact that we want the inner loop over DFA target states | ||
| in the Bellman operator `bellman!`. | ||
|
|
||
| """ | ||
| struct ProbabilisticLabelling{R <: Real, MR <: AbstractMatrix{R}} <: AbstractLabelling | ||
| map::MR | ||
|
|
||
| function ProbabilisticLabelling(map::MR) where {R <: Real, MR <: AbstractMatrix{R}} | ||
| checklabellingprobs(map) | ||
|
|
||
| return new{R, MR}(map) | ||
| end | ||
| end | ||
|
|
||
| function checklabellingprobs(map::AbstractMatrix{<:Real}) | ||
|
|
||
| # check for each state, all the labels probabilities sum to 1 | ||
| if any(sum(map; dims=1) .!= 1) | ||
| throw( | ||
| ArgumentError( | ||
| "For each IMDP state, probabilities over label states must sum to 1", | ||
| ), | ||
| ) | ||
| end | ||
| end | ||
|
|
||
| """ | ||
| mapping(pl::ProbabilisticLabelling) | ||
|
|
||
| Return the mapping matrix of the probabilistic labelling function. | ||
| """ | ||
| mapping(pl::ProbabilisticLabelling) = pl.map | ||
|
|
||
| Base.size(pl::ProbabilisticLabelling) = size(pl.map) | ||
| Base.size(pl::ProbabilisticLabelling, i) = size(pl.map, i) | ||
|
|
||
| """ | ||
| getindex(pl::ProbabilisticLabelling, s, l) | ||
|
|
||
| Return the probabilities for labelling l from state s. | ||
| """ | ||
| Base.getindex(pl::ProbabilisticLabelling, s, l) = pl.map[l, s] | ||
|
|
||
| """ | ||
| getindex(pl::ProbabilisticLabelling, s) | ||
|
|
||
| Return the probabilities over labels from state s. | ||
| """ | ||
| Base.getindex(pl::ProbabilisticLabelling, s) = @view(pl.map[:, s]) | ||
|
|
||
| """ | ||
| num_labels(pl::ProbabilisticLabelling) | ||
| Return the number of labels (DFA inputs) in the probabilistic labelling function. | ||
| """ | ||
| num_labels(pl::ProbabilisticLabelling) = size(pl.map, 1) | ||
|
|
||
| """ | ||
| state_values(pl::ProbabilisticLabelling) | ||
| Return a tuple with the number of states for each state variable of the labeling function ``L : S \\to 2^{AP}``, which can be multiple dimensions in case of factored IMDPs. | ||
| """ | ||
| state_values(pl::ProbabilisticLabelling) = Base.tail(size(pl.map)) | ||
|
|
||
| """ | ||
| num_states(pl::ProbabilisticLabelling) | ||
| Return the number of states of the labeling function ``L : S \\to 2^{AP}`` | ||
| """ | ||
| num_states(pl::ProbabilisticLabelling) = prod(state_values(pl)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.