2 changes: 2 additions & 0 deletions Project.toml
@@ -23,6 +23,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -60,6 +61,7 @@ StableRNGs = "1.0.2"
StaticArrays = "1.9.13"
StatsBase = "0.34.4"
StatsFuns = "1.3.2"
Test = "1.10"
julia = "1.10"

[workspace]
407 changes: 185 additions & 222 deletions dev/doubleMM.jl

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions ext/HybridVariationalInferenceCUDAExt.jl
@@ -77,10 +77,11 @@ function uutri2vec_gpu!(v::Union{CUDA.CuVector,CUDA.CuDeviceVector}, X::Abstract
return nothing # important
end

function HVI._create_random(rng, ::CUDA.CuVector{T}, dims...) where {T}
function HVI._create_randn(rng, v::CUDA.CuVector{T,M}, dims...) where {T,M}
# ignores rng
# https://discourse.julialang.org/t/help-using-cuda-zygote-and-random-numbers/123458/4?u=bgctw
ChainRulesCore.@ignore_derivatives CUDA.randn(dims...)
res = ChainRulesCore.@ignore_derivatives CUDA.randn(dims...)
res::CUDA.CuArray{T, length(dims),M}
end


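Illustrative sketch, not part of the diff (assumes the package loads as HybridVariationalInference, a CUDA device is available, and `_create_randn` behaves as shown above): the added return-type assertion lets `Test.@inferred` verify a concrete result type even though the random draw is wrapped in `@ignore_derivatives`.

using CUDA, Random, Test
import HybridVariationalInference as HVI

v = CUDA.zeros(Float32, 3)                 # hypothetical template vector on the GPU
r = Test.@inferred HVI._create_randn(Random.default_rng(), v, 4, 2)
@assert r isa CUDA.CuMatrix{Float32} && size(r) == (4, 2)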
7 changes: 4 additions & 3 deletions ext/HybridVariationalInferenceFluxExt.jl
@@ -18,7 +18,8 @@ end

function HVI.apply_model(app::FluxApplicator, x, ϕ)
m = app.rebuild(ϕ)
m(x)
res = m(x)
res
end

# struct FluxGPUDataHandler <: AbstractGPUDataHandler end
@@ -38,15 +39,15 @@ end

function HVI.construct_3layer_MLApplicator(
rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:Flux};
scenario::NTuple = ())
scenario::Val{scen}) where scen
(;θM) = get_hybridproblem_par_templates(prob; scenario)
n_out = length(θM)
n_covar = get_hybridproblem_n_covar(prob; scenario)
n_pbm_covars = length(get_hybridproblem_pbmpar_covars(prob; scenario))
n_input = n_covar + n_pbm_covars
#(; n_covar, n_θM) = get_hybridproblem_sizes(prob; scenario)
float_type = get_hybridproblem_float_type(prob; scenario)
is_using_dropout = :use_dropout ∈ scenario
is_using_dropout = :use_dropout ∈ scen
is_using_dropout && error("dropout scenario not supported with Flux yet.")
g_chain = Flux.Chain(
# dense layer with bias that maps to 8 outputs and applies `tanh` activation
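Editor's sketch of the scenario-as-value-type pattern this hunk switches to (only `Val` is from Base; the helper name is invented): carrying the flag tuple in the type domain lets checks such as `:use_dropout ∈ scen` be resolved at compile time.

# scenario flags carried as a value type
scenario = Val((:use_dropout, :use_gpu))

# the `where scen` clause unwraps the tuple, as in construct_3layer_MLApplicator above
uses_dropout(::Val{scen}) where {scen} = :use_dropout ∈ scen

@assert uses_dropout(scenario)
@assert !uses_dropout(Val(()))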
4 changes: 2 additions & 2 deletions ext/HybridVariationalInferenceSimpleChainsExt.jl
@@ -19,14 +19,14 @@ HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)

function HVI.construct_3layer_MLApplicator(
rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:SimpleChains};
scenario::NTuple = ())
scenario::Val{scen}) where scen
n_covar = get_hybridproblem_n_covar(prob; scenario)
n_pbm_covars = length(get_hybridproblem_pbmpar_covars(prob; scenario))
n_input = n_covar + n_pbm_covars
FloatType = get_hybridproblem_float_type(prob; scenario)
(;θM) = get_hybridproblem_par_templates(prob; scenario)
n_out = length(θM)
is_using_dropout = :use_dropout ∈ scenario
is_using_dropout = :use_dropout ∈ scen
g_chain = if is_using_dropout
SimpleChain(
static(n_input), # input dimension (optional)
11 changes: 6 additions & 5 deletions src/AbstractHybridProblem.jl
@@ -40,7 +40,8 @@ returns a Tuple of
"""
function get_hybridproblem_MLapplicator end

function get_hybridproblem_MLapplicator(prob::AbstractHybridProblem; scenario = ())
function get_hybridproblem_MLapplicator(
prob::AbstractHybridProblem; scenario::Val{scen} = Val(())) where scen
get_hybridproblem_MLapplicator(Random.default_rng(), prob; scenario)
end

@@ -202,13 +203,13 @@ end
Put relevant parts of the DataLoader to gpu, depending on scenario.
"""
function gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader;
scenario = (),
scenario::Val{scen} = Val(()),
gdev = gpu_device(),
gdev_M = :use_gpu ∈ scenario ? gdev : identity,
gdev_P = :f_on_gpu ∈ scenario ? gdev : identity,
gdev_M = :use_gpu ∈ _val_value(scenario) ? gdev : identity,
gdev_P = :f_on_gpu ∈ _val_value(scenario) ? gdev : identity,
batchsize = dataloader.batchsize,
partial = dataloader.partial
)
) where scen
xM, xP, y_o, y_unc, i_sites = dataloader.data
xM_dev = gdev_M(xM)
xP_dev, y_o_dev, y_unc_dev = (gdev_P(xP), gdev_P(y_o), gdev_P(y_unc))
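Hedged usage sketch for the updated keyword signature (array contents and sizes are invented; requires the package plus a GPU backend so that `gpu_device()` returns a real device): under `:use_gpu` only the ML covariates `xM` move to the GPU, while the process-model arrays stay on the CPU unless `:f_on_gpu` is also in the scenario.

using MLUtils

xM = rand(Float32, 5, 100)          # ML covariates
xP = rand(Float32, 3, 100)          # process-model drivers
y_o = rand(Float32, 2, 100)         # observations
y_unc = fill(0.1f0, 2, 100)         # observation uncertainty
i_sites = collect(1:100)

dl = MLUtils.DataLoader((xM, xP, y_o, y_unc, i_sites); batchsize = 32)
dl_gpu = gdev_hybridproblem_dataloader(dl; scenario = Val((:use_gpu,)))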
167 changes: 131 additions & 36 deletions src/ComponentArrayInterpreter.jl
@@ -21,12 +21,13 @@ Returns a ComponentArray with underlying data `v`.
"""
function as_ca end

function Base.length(cai::AbstractComponentArrayInterpreter)
function Base.length(cai::AbstractComponentArrayInterpreter)
prod(_axis_length.(CA.getaxes(cai)))
end


(interpreter::AbstractComponentArrayInterpreter)(v::AbstractArray) = as_ca(v, interpreter)
function (interpreter::AbstractComponentArrayInterpreter)(v::AbstractArray{ET}) where ET
as_ca(v, interpreter)::CA.ComponentArray{ET}
end

"""
Concrete version of `AbstractComponentArrayInterpreter` that stores an axis
@@ -39,11 +40,35 @@ Use `get_concrete(cai::ComponentArrayInterpreter)` to pass a concrete version to
performance-critical functions.
"""
struct StaticComponentArrayInterpreter{AX} <: AbstractComponentArrayInterpreter end
function as_ca(v::AbstractArray, ::StaticComponentArrayInterpreter{AX}) where {AX}
function as_ca(v::AbstractArray, ::StaticComponentArrayInterpreter{AX}) where {AX}
vr = reshape(v, _axis_length.(AX))
CA.ComponentArray(vr, AX)
CA.ComponentArray(vr, AX)::CA.ComponentArray{eltype(v)}
end

function StaticComponentArrayInterpreter(component_shapes::NamedTuple)
axs = map(component_shapes) do valx
x = _val_value(valx)
ax = x isa Integer ? CA.Shaped1DAxis((x,)) : CA.ShapedAxis(x)
(ax,)
end
axc = compose_axes(axs)
StaticComponentArrayInterpreter{(axc,)}()
end
function StaticComponentArrayInterpreter(ca::CA.ComponentArray)
ax = CA.getaxes(ca)
StaticComponentArrayInterpreter{ax}()
end

# concatenate from several other ArrayInterpreters, keep static
# did not manage to get it inferred, better use get_concrete(ComponentArrayInterpreter)
# also does not save allocations
# function StaticComponentArrayInterpreter(; kwargs...)
# ints = values(kwargs)
# axc = compose_axes(ints)
# intc = StaticComponentArrayInterpreter{(axc,)}()
# return(intc)
# end

# function Base.length(::StaticComponentArrayInterpreter{AX}) where {AX}
# #sum(length, typeof(AX).parameters[1])
# prod(_axis_length.(AX))
@@ -55,7 +80,6 @@ end

get_concrete(cai::StaticComponentArrayInterpreter) = cai


"""
Non-Concrete version of `AbstractComponentArrayInterpreter` that avoids storing
additional type parameters.
Expand All @@ -66,23 +90,21 @@ not allow compiler-inferred `length` to construct StaticArrays.
Use `get_concrete(cai::ComponentArrayInterpreter)` to pass a concrete version to
performance-critical functions.
"""
struct ComponentArrayInterpreter <: AbstractComponentArrayInterpreter
struct ComponentArrayInterpreter <: AbstractComponentArrayInterpreter
axes::Tuple #{T, <:CA.AbstractAxis}
end

function as_ca(v::AbstractArray, cai::ComponentArrayInterpreter)
vr = reshape(v, _axis_length.(cai.axes))
CA.ComponentArray(vr, cai.axes)
function as_ca(v::AbstractArray, cai::ComponentArrayInterpreter)
vr = reshape(CA.getdata(v), _axis_length.(cai.axes))
CA.ComponentArray(vr, cai.axes)::CA.ComponentArray{eltype(v)}
end

function CA.getaxes(cai::ComponentArrayInterpreter)
function CA.getaxes(cai::ComponentArrayInterpreter)
cai.axes
end


get_concrete(cai::ComponentArrayInterpreter) = StaticComponentArrayInterpreter{cai.axes}()


"""
ComponentArrayInterpreter(; kwargs...)
ComponentArrayInterpreter(::AbstractComponentArray)
@@ -108,71 +130,116 @@ The other constructors allow constructing arrays with additional dimensions.
"""
function ComponentArrayInterpreter(; kwargs...)
ComponentArrayInterpreter(values(kwargs))
end,
end
function ComponentArrayInterpreter(component_shapes::NamedTuple)
component_counts = map(prod, component_shapes)
n = sum(component_counts)
x = 1:n
is_end = cumsum(component_counts)
is_start = (0, is_end[1:(end-1)]...) .+ 1
#g = (x[i_start:i_end] for (i_start, i_end) in zip(is_start, is_end))
g = (reshape(x[i_start:i_end], shape) for (i_start, i_end, shape) in zip(is_start, is_end, component_shapes))
xc = CA.ComponentVector(; zip(propertynames(component_counts), g)...)
ComponentArrayInterpreter(xc)
#component_counts = map(prod, component_shapes)
# avoid constructing a template first, but create axes
# n = sum(component_counts)
# x = 1:n
# is_end = cumsum(component_counts)
# #is_start = (0, is_end[1:(end-1)]...) .+ 1 # problems with Zygote
# is_start = Iterators.flatten((1:1, is_end[1:(end-1)] .+ 1))
# g = (reshape(x[i_start:i_end], shape) for (i_start, i_end, shape) in zip(is_start, is_end, component_shapes))
# xc = CA.ComponentVector(; zip(propertynames(component_counts), g)...)
# #nt = NamedTuple{propertynames(component_counts)}(g)
# ComponentArrayInterpreter(xc)
axs = map(x -> (x isa Integer ? CA.Shaped1DAxis((x,)) : CA.ShapedAxis(x),), component_shapes)
ax = compose_axes(axs)
m1 = ComponentArrayInterpreter((ax,))
end
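Editor's sketch of the shape-based constructor rewritten in this hunk (component names and sizes are arbitrary): the interpreter stores only axes, and applying it to a flat vector returns a ComponentArray view.

cai = ComponentArrayInterpreter(a = 2, b = (2, 3))   # a 2-vector component and a 2×3 shaped component
v = collect(1.0:length(cai))                         # flat vector of length 8
cv = cai(v)
@assert cv.a == [1.0, 2.0]
@assert size(cv.b) == (2, 3)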

function ComponentArrayInterpreter(vc::CA.AbstractComponentArray)
ComponentArrayInterpreter(CA.getaxes(vc))
end



# Attach axes to matrices and arrays of ComponentArrays
# with ComponentArrays in the first dimensions (e.g. rownames of a matrix or array)
function ComponentArrayInterpreter(
ca::CA.AbstractComponentArray, n_dims::NTuple{N,<:Integer}) where N
ca::CA.AbstractComponentArray, n_dims::NTuple{N,<:Integer}) where {N}
ComponentArrayInterpreter(CA.getaxes(ca), n_dims)
end
function ComponentArrayInterpreter(
cai::AbstractComponentArrayInterpreter, n_dims::NTuple{N,<:Integer}) where N
cai::AbstractComponentArrayInterpreter, n_dims::NTuple{N,<:Integer}) where {N}
ComponentArrayInterpreter(CA.getaxes(cai), n_dims)
end
function ComponentArrayInterpreter(
axes::NTuple{M, <:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N}
axes::NTuple{M,<:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N}
axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...)
ComponentArrayInterpreter(axes_ext)
end

# support also for other AbstractComponentArrayInterpreter types
# in a type-stable way by providing the Tuple of dimensions as a value type
"""
stack_ca_int(cai::AbstractComponentArrayInterpreter, ::Val{n_dims})

Interpret the first dimension of an Array as a ComponentArray. Provide the Tuple
of following dimensions by a value type, e.g. `Val((n_col, n_z))`.
"""
function stack_ca_int(
cai::IT, ::Val{n_dims}) where {IT<:AbstractComponentArrayInterpreter,n_dims}
@assert n_dims isa NTuple{N,<:Integer} where {N}
IT.name.wrapper(CA.getaxes(cai), n_dims)::IT.name.wrapper
end
function StaticComponentArrayInterpreter(
axes::NTuple{M,<:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N}
axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...)
StaticComponentArrayInterpreter{axes_ext}()
end

# with ComponentArrays in the last dimensions (e.g. columnnames of a matrix)
function ComponentArrayInterpreter(
n_dims::NTuple{N,<:Integer}, ca::CA.AbstractComponentArray) where N
n_dims::NTuple{N,<:Integer}, ca::CA.AbstractComponentArray) where {N}
ComponentArrayInterpreter(n_dims, CA.getaxes(ca))
end
function ComponentArrayInterpreter(
n_dims::NTuple{N,<:Integer}, cai::AbstractComponentArrayInterpreter) where N
n_dims::NTuple{N,<:Integer}, cai::AbstractComponentArrayInterpreter) where {N}
ComponentArrayInterpreter(n_dims, CA.getaxes(cai))
end
function ComponentArrayInterpreter(
n_dims::NTuple{N,<:Integer}, axes::NTuple{M, <:CA.AbstractAxis}) where {N,M}
n_dims::NTuple{N,<:Integer}, axes::NTuple{M,<:CA.AbstractAxis}) where {N,M}
axes_ext = (map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., axes...)
ComponentArrayInterpreter(axes_ext)
end

function stack_ca_int(
::Val{n_dims}, cai::IT) where {IT<:AbstractComponentArrayInterpreter,n_dims}
@assert n_dims isa NTuple{N,<:Integer} where {N}
IT.name.wrapper(n_dims, CA.getaxes(cai))::IT.name.wrapper
end
function StaticComponentArrayInterpreter(
n_dims::NTuple{N,<:Integer}, axes::NTuple{M,<:CA.AbstractAxis}) where {N,M}
axes_ext = (map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., axes...)
StaticComponentArrayInterpreter{axes_ext}()
end
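Editor's sketch of the new `stack_ca_int` helpers (assuming they are reachable from the package namespace; component names are invented): passing the extra dimensions as a value type preserves the interpreter's wrapper type, so its length stays inferable.

int_row = get_concrete(ComponentArrayInterpreter(a = 2, b = 3))   # 5 entries per column

# component axis first, followed by plain dimensions of sizes 4 and 5
int_arr = stack_ca_int(int_row, Val((4, 5)))
A = int_arr(rand(length(int_row), 4, 5))
@assert size(A) == (5, 4, 5)

# reversed argument order places the component axis last
int_col = stack_ca_int(Val((4,)), int_row)
B = int_col(rand(4, length(int_row)))
@assert size(B) == (4, 5)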


# ambiguity with two empty Tuples (edge case that does not make sense)
# Empty ComponentVector with no other array dimensions -> empty componentVector
function ComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{})
ComponentArrayInterpreter(CA.ComponentVector())
ComponentArrayInterpreter((CA.Axis(),))
end
function StaticComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{})
StaticComponentArrayInterpreter{(CA.Axis(),)}()
end

# concatenate several 1d ComponentArrayInterpreters
function compose_interpreters(; kwargs...)
compose_interpreters(values(kwargs))
end

function compose_interpreters(ints::NamedTuple)
axtuples = map(x -> CA.getaxes(x), ints)
axc = compose_axes(axtuples)
intc = ComponentArrayInterpreter((axc,))
return (intc)
end
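A short sketch of concatenating two 1-d interpreters with the new `compose_interpreters` (component names are arbitrary; assumes the keyword names become the top-level components of the composed interpreter).

int_P = ComponentArrayInterpreter(r0 = 1, K = 1)
int_M = ComponentArrayInterpreter(w = 3)
int_all = compose_interpreters(P = int_P, M = int_M)

θ = int_all(collect(1.0:length(int_all)))
@assert length(int_all) == length(int_P) + length(int_M)
@assert θ.M.w == [3.0, 4.0, 5.0]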


# not exported, but required for testing
_get_ComponentArrayInterpreter_axes(::StaticComponentArrayInterpreter{AX}) where {AX} = AX
_get_ComponentArrayInterpreter_axes(cai::ComponentArrayInterpreter) = cai.axes


_axis_length(ax::CA.AbstractAxis) = lastindex(ax) - firstindex(ax) + 1
_axis_length(::CA.FlatAxis) = 0
_axis_length(::CA.UnitRange) = 0
@@ -199,15 +266,43 @@ function flatten1(cv::CA.ComponentVector)
end
end


"""
get_positions(cai::AbstractComponentArrayInterpreter)

Create a NamedTuple of integer indices for each component.
Assumes that interpreter results in a one-dimensional array, i.e. in a ComponentVector.
"""
function get_positions(cai::AbstractComponentArrayInterpreter)
@assert length(CA.getaxes(cai)) == 1
#@assert length(CA.getaxes(cai)) == 1
cv = cai(1:length(cai))
(; (k => cv[k] for k in keys(cv))... )
keys_cv = keys(cv)
# splatting creates Problems with Zygote
#keys_cv isa Tuple ? (; (k => CA.getdata(cv[k]) for k in keys_cv)...) : CA.getdata(cv)
keys_cv isa Tuple ? NamedTuple{keys_cv}(map(k -> CA.getdata(cv[k]), keys_cv)) : CA.getdata(cv)
end
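Editor's sketch of `get_positions` (names arbitrary): each component is mapped to its integer indices within the flat vector, returned as plain arrays so the result stays Zygote-friendly.

cai = ComponentArrayInterpreter(a = 2, b = (2, 2))
pos = get_positions(cai)
@assert pos.a == [1, 2]
@assert pos.b == [3 5; 4 6]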

function tmpf(v;
cv,
cai::AbstractComponentArrayInterpreter=get_concrete(ComponentArrayInterpreter(cv)))
cai(v)
end

function tmpf1(v; cai)
caic = get_concrete(cai)
#caic(v)
Test.@inferred tmpf(v, cv=nothing, cai=caic)
end

function tmpf2(v; cai::AbstractComponentArrayInterpreter)
caic = get_concrete(cai)
#caic = cai
cv = Test.@inferred caic(v) # inferred inside tmpf2
#cv = caic(v) # inferred inside tmpf2
vv = tmpf(v; cv=nothing, cai=caic)
#vv = tmpf(v; cv)
#cv.x
#sum(cv) # not inferred on Union cv (axis not known)
#cv.x::AbstractVector{eltype(vv)} # not sufficient
# need to specify concrete return type, but can rely on eltype
sum(vv)::eltype(vv) # need to specify return type
end