Skip to content

Commit 1054b8f

Browse files
authored
Compatibility with StateSpaceSets v2 (#392)
* update compat with statespacesets 2 * Depend on AbstractStateSpaceSet, not deprecated AbstractDataset * sssets v2 * temporary constructor for multiple dataset inputs the current statespacesets v2.3 constructor is way too slow * Fix transfer entropy transfer operator estimation * shuffle order, so utils are loaded first * horizontal concatenation * Use correct name, not the deprecated one * Use local variables * Explicitly test vector + statespaceset input * move agreement test to utils * Add more OCE tests * Fix marginal construction * use local variables * Improve docstrings * Update changelog * Set minimum julia version to LTS release
1 parent 27a2bb8 commit 1054b8f

File tree

21 files changed

+157
-51
lines changed

21 files changed

+157
-51
lines changed

Project.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name = "Associations"
22
uuid = "614afb3a-e278-4863-8805-9959372b9ec2"
33
authors = ["Kristian Agasøster Haaga <[email protected]>", "Tor Einar Møller <[email protected]>", "George Datseris <[email protected]>"]
44
repo = "https://github.com/kahaaga/Associations.jl.git"
5-
version = "4.2.0"
5+
version = "4.3.0"
66

77
[deps]
88
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -32,23 +32,23 @@ TimeseriesSurrogates = "c804724b-8c18-5caa-8579-6025a0767c70"
3232
[compat]
3333
Accessors = "^0.1.28"
3434
Combinatorics = "1"
35-
ComplexityMeasures = "3.6.5"
35+
ComplexityMeasures = "3.7.3"
3636
DSP = "^0.7"
37-
DelayEmbeddings = "2.7"
37+
DelayEmbeddings = "2.8"
3838
Distances = "^0.10"
3939
Distributions = "^0.25"
4040
Graphs = "^1.11"
4141
HypothesisTests = "^0.11"
4242
Neighborhood = "^0.2.4"
4343
ProgressMeter = "1.10"
44-
RecurrenceAnalysis = "2"
44+
RecurrenceAnalysis = "2.1"
4545
Reexport = "1"
4646
Scratch = "1"
4747
SpecialFunctions = "2"
48-
StateSpaceSets = "^1.5"
48+
StateSpaceSets = "2.1"
4949
StaticArrays = "^1"
5050
Statistics = "1"
5151
StatsBase = "^0.34"
5252
StyledStrings = "1"
53-
TimeseriesSurrogates = "2.6"
54-
julia = "^1.10"
53+
TimeseriesSurrogates = "2.7"
54+
julia = "^1.10.6"

changelog.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
From version v4.0 onwards, this package has been renamed to to Associations.jl.
44

5+
# 4.3
6+
7+
- Compatiblity with StateSpaceSets.jl v2.X
8+
- Improved documentation strings for `GaoOhViswanath` and `GaoKannanOhViswanath`
9+
510
# 4.2
611

712
- New association measure: `AzadkiaChatterjeeCoefficient`.

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ TimeseriesSurrogates = "c804724b-8c18-5caa-8579-6025a0767c70"
2626

2727
[compat]
2828
DynamicalSystemsBase = "3"
29-
julia = "^1.6"
29+
julia = "^1.10.6"

src/Associations.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,16 @@ module Associations
1010
using Reexport
1111

1212
using StateSpaceSets
13+
1314
using DelayEmbeddings: embed, genembed
1415
export embed, genembed
1516

1617
import HypothesisTests: pvalue
1718
export trajectory
18-
@reexport using StateSpaceSets
19-
@reexport using ComplexityMeasures
20-
@reexport using TimeseriesSurrogates
2119

22-
include("utils/utils.jl")
20+
2321
include("core.jl")
22+
include("utils/utils.jl")
2423

2524
include("methods/information/information.jl")
2625
include("methods/crossmappings/crossmappings.jl")
@@ -37,6 +36,9 @@ module Associations
3736

3837
include("deprecations/deprecations.jl")
3938

39+
@reexport using StateSpaceSets
40+
@reexport using ComplexityMeasures
41+
@reexport using TimeseriesSurrogates
4042
# Update messages:
4143
using Scratch
4244
display_update = true

src/causal_graphs/oce/OCE.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ function infer_graph(alg::OCE, x; verbose = true)
8989
return select_parents(alg, x; verbose)
9090
end
9191

92-
function infer_graph(alg::OCE, x::AbstractDataset; verbose = true)
92+
function infer_graph(alg::OCE, x::AbstractStateSpaceSet; verbose = true)
9393
return infer_graph(alg, columns(x); verbose)
9494
end
9595

src/methods/information/definitions/mutual_informations/mutual_informations.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,12 @@ end
3232

3333
function marginal_entropies_mi3h_discrete(est::EntropyDecomposition{<:MutualInformation, <:DiscreteInfoEstimator}, x, y)
3434
# Encode marginals to integers based on the outcome space.
35-
eX, eY = codified_marginals(est.discretization, x, y)
36-
eXY = StateSpaceSet(eX, eY)
35+
eX::StateSpaceSet, eY::StateSpaceSet = StateSpaceSet.(codified_marginals(est.discretization, x, y))
36+
eXY::StateSpaceSet = StateSpaceSet(eX, eY)
3737

3838
# The outcome space is no longer relevant from this point on. We're done discretizing,
3939
# so now we can just count (i.e. use `UniqueElements` as the outcome space).
4040
o = UniqueElements()
41-
4241
modified_est = estimator_with_overridden_parameters(est.definition, est.est)
4342
HX = information(modified_est, est.pest, o, eX) # estimates entropy in the X marginal
4443
HY = information(modified_est, est.pest, o, eY) # estimates entropy in the Y marginal

src/methods/information/definitions/transferentropy/transferoperator.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import ComplexityMeasures: TransferOperator, invariantmeasure, InvariantMeasure, Probabilities
32
using ComplexityMeasures.GroupSlices
43
export TransferOperator
@@ -47,19 +46,9 @@ end
4746

4847
function _marginal_encodings(encoder::RectangularBinEncoding, x::VectorOrStateSpaceSet...)
4948
X = StateSpaceSet(StateSpaceSet.(x)...)
50-
bins = [vec(encode_as_tuple(encoder, pt))' for pt in X]
49+
bins = [vec(encode_as_tuple(encoder, pt))' for pt in unique(X.data)]
5150
joint_bins = reduce(vcat, bins)
52-
idxs = size.(x, 2) #each input can have different dimensions
53-
s = 1
54-
encodings = Vector{StateSpaceSet}(undef, length(idxs))
55-
for (i, cidx) in enumerate(idxs)
56-
variable_subset = s:(s + cidx - 1)
57-
s += cidx
58-
y = @views joint_bins[:, variable_subset]
59-
encodings[i] = StateSpaceSet(y)
60-
end
61-
62-
return encodings
51+
return StateSpaceSet(joint_bins)
6352
end
6453

6554
# Only works for `RelativeAmount`, because probabilities are obtained from the
@@ -88,9 +77,11 @@ function h4_marginal_probs(
8877
# marginals, not a single encoding integer for each bin. Otherwise, we can't
8978
# properly subset marginals here and relate them to the approximated invariant measure.
9079
encoding = iv.to.encoding
80+
# Visited bins (absolute coordinates)
9181
visited_bins_coordinates = StateSpaceSet(decode.(Ref(encoding), iv.to.bins))
92-
unique_visited_bins = _marginal_encodings(iv.to.encoding, visited_bins_coordinates)[1]
9382

83+
# Visited bins (coordinates encoded to integers using rectangular encoding)
84+
unique_visited_bins = _marginal_encodings(iv.to.encoding, visited_bins_coordinates)
9485
# # The subset of visited bins with nonzero measure
9586
inds_non0measure = findall(iv.ρ .> 0)
9687
positive_measure_bins = unique_visited_bins[inds_non0measure]

src/methods/information/estimators/mutual_info_estimators/GaoKannanOhViswanath.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export GaoKannanOhViswanath
44

55
"""
66
GaoKannanOhViswanath <: MutualInformationEstimator
7-
GaoKannanOhViswanath(; k = 1, w = 0)
7+
GaoKannanOhViswanath(definition = MIShannon(); k = 1, w = 0)
88
99
The `GaoKannanOhViswanath` (Shannon) estimator is designed for estimating
1010
Shannon mutual information between variables that may be either discrete, continuous or
@@ -14,6 +14,14 @@ a mixture of both [GaoKannanOhViswanath2017](@cite).
1414
1515
- [`MIShannon`](@ref)
1616
17+
## Keyword arguments
18+
19+
- **`k::Int`**: The number of nearest neighbors to consider. Only information about the
20+
`k`-th nearest neighbor is actually used.
21+
- **`w::Int`**: The Theiler window, which determines if temporal neighbors are excluded
22+
during neighbor searches in the joint space. Defaults to `0`, meaning that only the
23+
point itself is excluded.
24+
1725
## Usage
1826
1927
- Use with [`association`](@ref) to compute Shannon mutual information from input data.

src/methods/information/estimators/mutual_info_estimators/GaoOhViswanath.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ export GaoOhViswanath
33

44
"""
55
GaoOhViswanath <: MutualInformationEstimator
6+
GaoOhViswanath(definition = MIShannon(); k = 1, w = 0)
67
78
The `GaoOhViswanath` is a mutual information estimator based on nearest neighbors,
89
and is also called the bias-improved-KSG estimator, or BI-KSG, by [Gao2018](@cite).
@@ -11,6 +12,14 @@ and is also called the bias-improved-KSG estimator, or BI-KSG, by [Gao2018](@cit
1112
1213
- [`MIShannon`](@ref)
1314
15+
## Keyword arguments
16+
17+
- **`k::Int`**: The number of nearest neighbors to consider. Only information about the
18+
`k`-th nearest neighbor is actually used.
19+
- **`w::Int`**: The Theiler window, which determines if temporal neighbors are excluded
20+
during neighbor searches in the joint space. Defaults to `0`, meaning that only the
21+
point itself is excluded.
22+
1423
## Usage
1524
1625
- Use with [`association`](@ref) to compute Shannon mutual information from input data.

src/methods/information/estimators/mutual_info_estimators/KSG1.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,13 @@ const KSG1 = KraskovStögbauerGrassberger1
6868

6969
function association(est::KSG1{<:MIShannon}, x::VectorOrStateSpaceSet...)
7070
verify_number_of_inputs_vars(est.definition, length(x))
71-
7271
(; definition, k, w, metric_joint, metric_marginals) = est
73-
joint = StateSpaceSet(x...)
72+
7473
marginals = map(xᵢ -> StateSpaceSet(xᵢ), x)
74+
# Note: this uses a StateSpaceSet constructor that is overloaded from StateSpaceSets.jl, because the native
75+
# one is extremely slow.
76+
joint::StateSpaceSet = StateSpaceSet(marginals...)
77+
7578
M = length(x)
7679
N = length(joint)
7780

0 commit comments

Comments
 (0)