Skip to content

Commit f5dacf8

Browse files
committed
missed staging for git
1 parent 0713c74 commit f5dacf8

14 files changed

+357
-278
lines changed

ext/FluxFactorsExt.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
1+
module FluxFactorsExt
12

2-
@info "IncrementalInference is adding Flux related functionality."
33

4-
# the factor definitions
5-
export FluxModelsDistribution
6-
export MixtureFluxModels
4+
@info "IncrementalInference is loading extension functionality related to Flux.jl"
75

86
# Required packages
9-
using .Flux
7+
using Flux
108
using DataStructures: OrderedDict
11-
using Random, Statistics
9+
using LinearAlgebra
10+
using Base64
11+
12+
import Base: convert
1213

1314
# import Base: convert
15+
using Random, Statistics
1416
import Random: rand
1517

18+
using IncrementalInference
19+
import IncrementalInference: samplePoint, sampleTangent
20+
21+
# the factor definitions
22+
# export FluxModelsDistribution
23+
export MixtureFluxModels
24+
1625
const _IIFListTypes = Union{<:AbstractVector, <:Tuple, <:NTuple, <:NamedTuple}
1726

1827
function Random.rand(nfb::FluxModelsDistribution, N::Integer = 1)
@@ -164,3 +173,8 @@ function MixtureFluxModels(::Type{F}, w...; kw...) where {F <: AbstractFactor}
164173
end
165174

166175
#
176+
177+
include("FluxModelsSerialization.jl")
178+
179+
180+
end # module

ext/FluxModelsSerialization.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
# @info "IncrementalInference is adding Flux/BSON serialization functionality."
44

5-
using Base64
6-
7-
import Base: convert
85

96
function _serializeFluxModelBase64(model::Flux.Chain)
107
io = IOBuffer()

ext/HeatmapSampler.jl

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,110 @@
11
# heatmap sampler (experimental)
22

3+
4+
(hmd::HeatmapGridDensity)(w...; kw...) = hmd.densityFnc(w...; kw...)
5+
6+
function sampleTangent(M::AbstractManifold, hms::HeatmapGridDensity)
7+
return sampleTangent(M, hms.densityFnc)
8+
end
9+
10+
function Base.show(io::IO, x::HeatmapGridDensity{T, H, B}) where {T, H, B}
11+
printstyled(io, "HeatmapGridDensity{"; bold = true, color = :blue)
12+
println(io)
13+
printstyled(io, " T"; color = :magenta, bold = true)
14+
println(io, " = ", T)
15+
printstyled(io, " H"; color = :magenta, bold = true)
16+
println(io, "`int = ", H)
17+
printstyled(io, " B"; color = :magenta, bold = true)
18+
println(io, " = ", B)
19+
printstyled(io, " }"; color = :blue, bold = true)
20+
println(io, "(")
21+
println(io, " data: ", size(x.data))
22+
println(
23+
io,
24+
" min/max: ",
25+
round(minimum(x.data); digits = 5),
26+
" / ",
27+
round(maximum(x.data); digits = 5),
28+
)
29+
println(io, " domain: ", size(x.domain[1]), ", ", size(x.domain[2]))
30+
println(
31+
io,
32+
" min/max: ",
33+
round(minimum(x.domain[1]); digits = 5),
34+
" / ",
35+
round(maximum(x.domain[1]); digits = 5),
36+
)
37+
println(
38+
io,
39+
" min/max: ",
40+
round(minimum(x.domain[2]); digits = 5),
41+
" / ",
42+
round(maximum(x.domain[2]); digits = 5),
43+
)
44+
println(io, " bw_factor: ", x.bw_factor)
45+
print(io, " ")
46+
show(io, x.densityFnc)
47+
return nothing
48+
end
49+
50+
Base.show(io::IO, ::MIME"text/plain", x::HeatmapGridDensity) = show(io, x)
51+
Base.show(io::IO, ::MIME"application/prs.juno.inline", x::HeatmapGridDensity) = show(io, x)
52+
53+
"""
54+
$SIGNATURES
55+
56+
Internal function for updating HGD.
57+
58+
Notes
59+
- Likely to be used for [unstashing packed factors](@ref section_stash_unstash) via [`preambleCache`](@ref).
60+
- Counterpart to `AMP._update!` function for stashing of either MKD or HGD.
61+
"""
62+
function _update!(
63+
dst::HeatmapGridDensity{T, H, B},
64+
src::HeatmapGridDensity{T, H, B},
65+
) where {T, H, B}
66+
@assert size(dst.data) == size(src.data) "Updating HeatmapDensityGrid can only be done for data of the same size"
67+
dst.data .= src.data
68+
if !isapprox(dst.domain[1], src.domain[1])
69+
dst.domain[1] .= src.domain[1]
70+
end
71+
if !isapprox(dst.domain[2], src.domain[2])
72+
dst.domain[2] .= src.domain[2]
73+
end
74+
AMP._update!(dst.densityFnc, src.densityFnc)
75+
return dst
76+
end
77+
78+
79+
##
80+
81+
(lsg::LevelSetGridNormal)(w...; kw...) = lsg.densityFnc(w...; kw...)
82+
83+
function sampleTangent(M::AbstractManifold, lsg::LevelSetGridNormal)
84+
return sampleTangent(M, lsg.heatmap.densityFnc)
85+
end
86+
87+
function Base.show(io::IO, x::LevelSetGridNormal{T, H}) where {T, H}
88+
printstyled(io, "LevelSetGridNormal{"; bold = true, color = :blue)
89+
println(io)
90+
printstyled(io, " T"; color = :magenta, bold = true)
91+
println(io, " = ", T)
92+
printstyled(io, " H"; color = :magenta, bold = true)
93+
println(io, "`int = ", H)
94+
printstyled(io, " }"; color = :blue, bold = true)
95+
println(io, "(")
96+
println(io, " level: ", x.level)
97+
println(io, " sigma: ", x.sigma)
98+
println(io, " sig.scale: ", x.sigma_scale)
99+
println(io, " heatmap: ")
100+
show(io, x.heatmap)
101+
return nothing
102+
end
103+
104+
Base.show(io::IO, ::MIME"text/plain", x::LevelSetGridNormal) = show(io, x)
105+
Base.show(io::IO, ::MIME"application/prs.juno.inline", x::LevelSetGridNormal) = show(io, x)
106+
107+
3108
##
4109

5110
getManifold(hgd::HeatmapGridDensity) = getManifold(hgd.densityFnc)

ext/InterpolationsExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ using ApproxManifoldProducts
1111
import ApproxManifoldProducts: sample
1212
const AMP = ApproxManifoldProducts
1313

14-
import IncrementalInference: getManifold
14+
import Base: show
15+
import IncrementalInference: getManifold, sampleTangent
1516
import IncrementalInference: HeatmapGridDensity, PackedHeatmapGridDensity
1617
import IncrementalInference: LevelSetGridNormal, PackedLevelSetGridNormal
1718

ext/WeakDepsPrototypes.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,4 @@
11

2-
"""
3-
$TYPEDEF
4-
5-
Build a full ODE solution into a relative factor to condense possible sensor data into a relative transformation,
6-
but keeping the parameter estimation process fluid. Assumes first and second variable in order
7-
are of same dimension and compatible manifolds, such that ODE runs from Xi to Xi+1 on all
8-
dimensions. Internal state vector can be decoupled onto different domain as needed.
9-
10-
Notes
11-
- Based on DifferentialEquations.jl
12-
- `getSample` step does the `solve(ODEProblem)` step.
13-
- `tspan` is taken from variables only once at object construction -- i.e. won't detect changed timestamps.
14-
- Regular factor evaluation is done as full dimension `AbstractRelativeRoots`, and is basic linear difference.
15-
16-
DevNotes
17-
- FIXME see 1025, `multihypo=` will not yet work.
18-
- FIXME Lots of consolidation and standardization to do, see RoME.jl #244 regarding Manifolds.jl.
19-
- TODO does not yet handle case where a factor spans across two timezones.
20-
"""
21-
struct DERelative{T <: InferenceVariable, P, D} <: AbstractRelativeMinimize
22-
domain::Type{T}
23-
forwardProblem::P
24-
backwardProblem::P
25-
# second element of this data tuple is additional variables that will be passed down as a parameter
26-
data::D
27-
specialSampler::Function
28-
end
292

303
# InteractiveUtils.jl
314
function getCurrentWorkspaceFactors end

src/ExportAPI.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ export PackedUniform, PackedNormal
310310
export PackedZeroMeanDiagNormal,
311311
PackedZeroMeanFullNormal, PackedDiagNormal, PackedFullNormal
312312
export PackedManifoldKernelDensity
313-
export PackedAliasingScalarSampler, PackedHeatmapGridDensity, PackedLevelSetGridNormal
313+
export PackedAliasingScalarSampler
314314
export PackedRayleigh
315315

316316
export Mixture, PackedMixture
@@ -358,13 +358,23 @@ export makeSolverData!
358358

359359
export MetaPrior
360360

361+
362+
# weakdeps on Interpolations.jl
363+
export HeatmapGridDensity, LevelSetGridNormal
364+
export PackedHeatmapGridDensity, PackedLevelSetGridNormal
365+
366+
# weakdeps on DifferentialEquations.jl
361367
export DERelative
362368

369+
# weakdeps on Flux.jl
370+
export FluxModelsDistribution, PackedFluxModelsDistribution
371+
363372
# weakdeps on InteractiveUtils.jl
364373
export getCurrentWorkspaceFactors, getCurrentWorkspaceVariables
365374
export listTypeTree
366375

367376
# weakdeps on Gadfly.jl
368377
export exportimg, spyCliqMat
369378

370-
#
379+
380+
#

src/IncrementalInference.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ include("entities/FactorOperationalMemory.jl")
130130
include("Factors/GenericMarginal.jl")
131131
# Special belief types for sampling as a distribution
132132
include("entities/AliasScalarSampling.jl")
133-
include("entities/OptionalDensities.jl")
133+
include("entities/OptionalDensities.jl") # used in BeliefTypes.jl::SamplableBeliefs
134134
include("entities/BeliefTypes.jl")
135135

136136
include("services/HypoRecipe.jl")
@@ -234,6 +234,10 @@ include("services/SolverAPI.jl")
234234
# Symbolic tree analysis files.
235235
include("services/AnalysisTools.jl")
236236

237+
# optional densities on weakdeps
238+
include("Serialization/entities/SerializingOptionalDensities.jl")
239+
include("Serialization/services/SerializingOptionalDensities.jl")
240+
237241
include("../ext/WeakDepsPrototypes.jl")
238242

239243
# deprecation legacy support
@@ -253,12 +257,11 @@ function __init__()
253257
# "services/HeatmapSampler.jl",
254258
# )
255259

256-
# combining neural networks natively into the non-Gaussian factor graph object
257-
@require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin
258-
include("Flux/FluxModelsDistribution.jl")
259-
include("Serialization/entities/FluxModelsSerialization.jl")
260-
include("Serialization/services/FluxModelsSerialization.jl") # uses BSON
261-
end
260+
# # combining neural networks natively into the non-Gaussian factor graph object
261+
# @require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin
262+
# # include("Flux/FluxModelsDistribution.jl")
263+
# include("Serialization/services/FluxModelsSerialization.jl") # uses BSON
264+
# end
262265
end
263266

264267
@compile_workload begin

src/Serialization/entities/AdditionalDensities.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,3 @@ Base.@kwdef struct PackedAliasingScalarSampler <: PackedSamplableBelief
1313
domain::Vector{Float64} = [0; 1.0]
1414
weights::Vector{Float64} = [0.5; 0.5]
1515
end
16-
17-
Base.@kwdef mutable struct PackedHeatmapGridDensity <: PackedSamplableBelief
18-
_type::String = "IncrementalInference.PackedHeatmapGridDensity"
19-
data::Vector{Vector{Float64}}
20-
domain::Tuple{Vector{Float64}, Vector{Float64}}
21-
hint_callback::String
22-
bw_factor::Float64
23-
N::Int
24-
# _densityFnc::String = "" # only use if storing parched belief data entry label/id
25-
end
26-
27-
Base.@kwdef mutable struct PackedLevelSetGridNormal <: PackedSamplableBelief
28-
_type::String = "IncrementalInference.PackedLevelSetGridNormal"
29-
level::Float64
30-
sigma::Float64
31-
sigma_scale::Float64
32-
# make sure the JSON nested packing works with the serialization overlords
33-
heatmap::PackedHeatmapGridDensity
34-
end

src/Serialization/entities/SerializingOptionalDensities.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
11

2-
export PackedFluxModelsDistribution
2+
3+
4+
Base.@kwdef mutable struct PackedHeatmapGridDensity <: PackedSamplableBelief
5+
_type::String = "IncrementalInference.PackedHeatmapGridDensity"
6+
data::Vector{Vector{Float64}}
7+
domain::Tuple{Vector{Float64}, Vector{Float64}}
8+
hint_callback::String
9+
bw_factor::Float64
10+
N::Int
11+
# _densityFnc::String = "" # only use if storing parched belief data entry label/id
12+
end
13+
14+
15+
Base.@kwdef mutable struct PackedLevelSetGridNormal <: PackedSamplableBelief
16+
_type::String = "IncrementalInference.PackedLevelSetGridNormal"
17+
level::Float64
18+
sigma::Float64
19+
sigma_scale::Float64
20+
# make sure the JSON nested packing works with the serialization overlords
21+
heatmap::PackedHeatmapGridDensity
22+
end
23+
324

425
Base.@kwdef mutable struct PackedFluxModelsDistribution <: PackedSamplableBelief
526
# standardized _type field

0 commit comments

Comments
 (0)