Skip to content

Commit ab63b31

Browse files
authored
Merge pull request #1749 from JuliaRobotics/23Q3/ext/flux
weakdeps FluxFactorsExt, drop Requires
2 parents 4918b57 + f17921c commit ab63b31

17 files changed

+388
-421
lines changed

Project.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
3737
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3838
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
3939
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
40-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
4140
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4241
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
4342
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -51,12 +50,14 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
5150

5251
[weakdeps]
5352
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
53+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
5454
Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
5555
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
5656
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
5757

5858
[extensions]
5959
DiffEqFactorExt = "DifferentialEquations"
60+
FluxFactorsExt = "Flux"
6061
GadflyExt = "Gadfly"
6162
InteractiveUtilsExt = "InteractiveUtils"
6263
InterpolationsExt = "Interpolations"
@@ -88,17 +89,15 @@ PrecompileTools = "1"
8889
ProgressMeter = "1"
8990
RecursiveArrayTools = "2.31.1"
9091
Reexport = "1"
91-
Requires = "1"
9292
SparseDiffTools = "2"
9393
StaticArrays = "1"
9494
StatsBase = "0.32, 0.33, 0.34"
9595
StructTypes = "1"
9696
TensorCast = "0.3.3, 0.4"
9797
TimeZones = "1.3.1"
98-
julia = "1.8"
98+
julia = "1.9"
9999

100100
[extras]
101-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
102101
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
103102
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
104103
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"

ext/DiffEqFactorExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ function _solveFactorODE!(measArr, prob, u0pts, Xtra...)
111111
return sol
112112
end
113113

114-
getSample(cf::CalcFactor{<:DERelative}) = error("getSample(::CalcFactor{<:DERelative}) not implemented yet")
114+
getSample(cf::CalcFactor{<:DERelative}) = error("getSample(::CalcFactor{<:DERelative}) must still be implemented in new IIF design")
115115

116116
# FIXME see #1025, `multihypo=` will not work properly yet
117117
function sampleFactor(cf::CalcFactor{<:DERelative}, N::Int = 1)

src/Flux/FluxModelsDistribution.jl renamed to ext/FluxFactorsExt.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,30 @@
1+
module FluxFactorsExt
12

2-
@info "IncrementalInference is adding Flux related functionality."
33

4-
# the factor definitions
5-
export FluxModelsDistribution
6-
export MixtureFluxModels
4+
@info "IncrementalInference is loading extension functionality related to Flux.jl"
75

86
# Required packages
9-
using .Flux
7+
using Flux
108
using DataStructures: OrderedDict
11-
using Random, Statistics
9+
using LinearAlgebra
10+
using Base64
11+
using Manifolds
12+
using DocStringExtensions
13+
using BSON
14+
15+
import Base: convert
1216

1317
# import Base: convert
18+
using Random, Statistics
1419
import Random: rand
1520

21+
using IncrementalInference
22+
import IncrementalInference: samplePoint, sampleTangent, MixtureFluxModels, getSample
23+
24+
# the factor definitions
25+
# export FluxModelsDistribution
26+
export MixtureFluxModels
27+
1628
const _IIFListTypes = Union{<:AbstractVector, <:Tuple, <:NTuple, <:NamedTuple}
1729

1830
function Random.rand(nfb::FluxModelsDistribution, N::Integer = 1)
@@ -164,3 +176,8 @@ function MixtureFluxModels(::Type{F}, w...; kw...) where {F <: AbstractFactor}
164176
end
165177

166178
#
179+
180+
include("FluxModelsSerialization.jl")
181+
182+
183+
end # module

src/Serialization/services/FluxModelsSerialization.jl renamed to ext/FluxModelsSerialization.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
# @info "IncrementalInference is adding Flux/BSON serialization functionality."
44

5-
using Base64
6-
7-
import Base: convert
85

96
function _serializeFluxModelBase64(model::Flux.Chain)
107
io = IOBuffer()

ext/HeatmapSampler.jl

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,110 @@
11
# heatmap sampler (experimental)
22

3+
4+
(hmd::HeatmapGridDensity)(w...; kw...) = hmd.densityFnc(w...; kw...)
5+
6+
function sampleTangent(M::AbstractManifold, hms::HeatmapGridDensity)
7+
return sampleTangent(M, hms.densityFnc)
8+
end
9+
10+
function Base.show(io::IO, x::HeatmapGridDensity{T, H, B}) where {T, H, B}
11+
printstyled(io, "HeatmapGridDensity{"; bold = true, color = :blue)
12+
println(io)
13+
printstyled(io, " T"; color = :magenta, bold = true)
14+
println(io, " = ", T)
15+
printstyled(io, " H"; color = :magenta, bold = true)
16+
println(io, "`int = ", H)
17+
printstyled(io, " B"; color = :magenta, bold = true)
18+
println(io, " = ", B)
19+
printstyled(io, " }"; color = :blue, bold = true)
20+
println(io, "(")
21+
println(io, " data: ", size(x.data))
22+
println(
23+
io,
24+
" min/max: ",
25+
round(minimum(x.data); digits = 5),
26+
" / ",
27+
round(maximum(x.data); digits = 5),
28+
)
29+
println(io, " domain: ", size(x.domain[1]), ", ", size(x.domain[2]))
30+
println(
31+
io,
32+
" min/max: ",
33+
round(minimum(x.domain[1]); digits = 5),
34+
" / ",
35+
round(maximum(x.domain[1]); digits = 5),
36+
)
37+
println(
38+
io,
39+
" min/max: ",
40+
round(minimum(x.domain[2]); digits = 5),
41+
" / ",
42+
round(maximum(x.domain[2]); digits = 5),
43+
)
44+
println(io, " bw_factor: ", x.bw_factor)
45+
print(io, " ")
46+
show(io, x.densityFnc)
47+
return nothing
48+
end
49+
50+
Base.show(io::IO, ::MIME"text/plain", x::HeatmapGridDensity) = show(io, x)
51+
Base.show(io::IO, ::MIME"application/prs.juno.inline", x::HeatmapGridDensity) = show(io, x)
52+
53+
"""
54+
$SIGNATURES
55+
56+
Internal function for updating HGD.
57+
58+
Notes
59+
- Likely to be used for [unstashing packed factors](@ref section_stash_unstash) via [`preambleCache`](@ref).
60+
- Counterpart to `AMP._update!` function for stashing of either MKD or HGD.
61+
"""
62+
function _update!(
63+
dst::HeatmapGridDensity{T, H, B},
64+
src::HeatmapGridDensity{T, H, B},
65+
) where {T, H, B}
66+
@assert size(dst.data) == size(src.data) "Updating HeatmapDensityGrid can only be done for data of the same size"
67+
dst.data .= src.data
68+
if !isapprox(dst.domain[1], src.domain[1])
69+
dst.domain[1] .= src.domain[1]
70+
end
71+
if !isapprox(dst.domain[2], src.domain[2])
72+
dst.domain[2] .= src.domain[2]
73+
end
74+
AMP._update!(dst.densityFnc, src.densityFnc)
75+
return dst
76+
end
77+
78+
79+
##
80+
81+
(lsg::LevelSetGridNormal)(w...; kw...) = lsg.densityFnc(w...; kw...)
82+
83+
function sampleTangent(M::AbstractManifold, lsg::LevelSetGridNormal)
84+
return sampleTangent(M, lsg.heatmap.densityFnc)
85+
end
86+
87+
function Base.show(io::IO, x::LevelSetGridNormal{T, H}) where {T, H}
88+
printstyled(io, "LevelSetGridNormal{"; bold = true, color = :blue)
89+
println(io)
90+
printstyled(io, " T"; color = :magenta, bold = true)
91+
println(io, " = ", T)
92+
printstyled(io, " H"; color = :magenta, bold = true)
93+
println(io, "`int = ", H)
94+
printstyled(io, " }"; color = :blue, bold = true)
95+
println(io, "(")
96+
println(io, " level: ", x.level)
97+
println(io, " sigma: ", x.sigma)
98+
println(io, " sig.scale: ", x.sigma_scale)
99+
println(io, " heatmap: ")
100+
show(io, x.heatmap)
101+
return nothing
102+
end
103+
104+
Base.show(io::IO, ::MIME"text/plain", x::LevelSetGridNormal) = show(io, x)
105+
Base.show(io::IO, ::MIME"application/prs.juno.inline", x::LevelSetGridNormal) = show(io, x)
106+
107+
3108
##
4109

5110
getManifold(hgd::HeatmapGridDensity) = getManifold(hgd.densityFnc)

ext/InterpolationsExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ using ApproxManifoldProducts
1111
import ApproxManifoldProducts: sample
1212
const AMP = ApproxManifoldProducts
1313

14-
import IncrementalInference: getManifold
14+
import Base: show
15+
import IncrementalInference: getManifold, sampleTangent
1516
import IncrementalInference: HeatmapGridDensity, PackedHeatmapGridDensity
1617
import IncrementalInference: LevelSetGridNormal, PackedLevelSetGridNormal
1718

ext/WeakDepsPrototypes.jl

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,6 @@
11

2-
"""
3-
$TYPEDEF
4-
5-
Build a full ODE solution into a relative factor to condense possible sensor data into a relative transformation,
6-
but keeping the parameter estimation process fluid. Assumes first and second variable in order
7-
are of same dimension and compatible manifolds, such that ODE runs from Xi to Xi+1 on all
8-
dimensions. Internal state vector can be decoupled onto different domain as needed.
9-
10-
Notes
11-
- Based on DifferentialEquations.jl
12-
- `getSample` step does the `solve(ODEProblem)` step.
13-
- `tspan` is taken from variables only once at object construction -- i.e. won't detect changed timestamps.
14-
- Regular factor evaluation is done as full dimension `AbstractRelativeRoots`, and is basic linear difference.
15-
16-
DevNotes
17-
- FIXME see 1025, `multihypo=` will not yet work.
18-
- FIXME Lots of consolidation and standardization to do, see RoME.jl #244 regarding Manifolds.jl.
19-
- TODO does not yet handle case where a factor spans across two timezones.
20-
"""
21-
struct DERelative{T <: InferenceVariable, P, D} <: AbstractRelativeMinimize
22-
domain::Type{T}
23-
forwardProblem::P
24-
backwardProblem::P
25-
# second element of this data tuple is additional variables that will be passed down as a parameter
26-
data::D
27-
specialSampler::Function
28-
end
2+
# Flux.jl
3+
function MixtureFluxModels end
294

305
# InteractiveUtils.jl
316
function getCurrentWorkspaceFactors end

0 commit comments

Comments
 (0)