Skip to content

Commit 5b9565f

Browse files
authored
Return PBm parameters P and Ms separately (#22)
* implement extend_stacked_nrow * transposed zetas to speed up transformations * work on type stability * implement HybridProblemInterpreters ** that provide type-stable Interpreters of all variantes (n_site + n_batch / Ms + Mst ) * implement type-stable combine_axes so that gen generate variants based on ony axes of thetaP and thetaM (and n_sites_and_batch) * rewrite scenario to value-type to make ComponentArrayProblem type-stable * transform already sampled residuals to site-first form ** And sample zeta_resid in parameter-last form * fix error of sampling normal rather than uniform * return P and Ms parameters separately to communicate shape ** rather than fiddling with various ComponentArrayInterpreters. * implement apply_trans_f for single parameter vector * check type inference, adapt Solvers * test distribution of generated residuals * test scaling variance * update HybridProblem for narrower types
1 parent f4ef19d commit 5b9565f

31 files changed

+2351
-1013
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2323
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2424
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2525
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
26+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2627

2728
[weakdeps]
2829
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -60,6 +61,7 @@ StableRNGs = "1.0.2"
6061
StaticArrays = "1.9.13"
6162
StatsBase = "0.34.4"
6263
StatsFuns = "1.3.2"
64+
Test = "1.10"
6365
julia = "1.10"
6466

6567
[workspace]

dev/doubleMM.jl

Lines changed: 185 additions & 222 deletions
Large diffs are not rendered by default.

ext/HybridVariationalInferenceCUDAExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,11 @@ function uutri2vec_gpu!(v::Union{CUDA.CuVector,CUDA.CuDeviceVector}, X::Abstract
7777
return nothing # important
7878
end
7979

80-
function HVI._create_random(rng, ::CUDA.CuVector{T}, dims...) where {T}
80+
function HVI._create_randn(rng, v::CUDA.CuVector{T,M}, dims...) where {T,M}
8181
# ignores rng
8282
# https://discourse.julialang.org/t/help-using-cuda-zygote-and-random-numbers/123458/4?u=bgctw
83-
ChainRulesCore.@ignore_derivatives CUDA.randn(dims...)
83+
res = ChainRulesCore.@ignore_derivatives CUDA.randn(dims...)
84+
res::CUDA.CuArray{T, length(dims),M}
8485
end
8586

8687

ext/HybridVariationalInferenceFluxExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ end
1818

1919
function HVI.apply_model(app::FluxApplicator, x, ϕ)
2020
m = app.rebuild(ϕ)
21-
m(x)
21+
res = m(x)
22+
res
2223
end
2324

2425
# struct FluxGPUDataHandler <: AbstractGPUDataHandler end
@@ -38,15 +39,15 @@ end
3839

3940
function HVI.construct_3layer_MLApplicator(
4041
rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:Flux};
41-
scenario::NTuple = ())
42+
scenario::Val{scen}) where scen
4243
(;θM) = get_hybridproblem_par_templates(prob; scenario)
4344
n_out = length(θM)
4445
n_covar = get_hybridproblem_n_covar(prob; scenario)
4546
n_pbm_covars = length(get_hybridproblem_pbmpar_covars(prob; scenario))
4647
n_input = n_covar + n_pbm_covars
4748
#(; n_covar, n_θM) = get_hybridproblem_sizes(prob; scenario)
4849
float_type = get_hybridproblem_float_type(prob; scenario)
49-
is_using_dropout = :use_dropout scenario
50+
is_using_dropout = :use_dropout scen
5051
is_using_dropout && error("dropout scenario not supported with Flux yet.")
5152
g_chain = Flux.Chain(
5253
# dense layer with bias that maps to 8 outputs and applies `tanh` activation

ext/HybridVariationalInferenceSimpleChainsExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ HVI.apply_model(app::SimpleChainsApplicator, x, ϕ) = app.m(x, ϕ)
1919

2020
function HVI.construct_3layer_MLApplicator(
2121
rng::AbstractRNG, prob::HVI.AbstractHybridProblem, ::Val{:SimpleChains};
22-
scenario::NTuple = ())
22+
scenario::Val{scen}) where scen
2323
n_covar = get_hybridproblem_n_covar(prob; scenario)
2424
n_pbm_covars = length(get_hybridproblem_pbmpar_covars(prob; scenario))
2525
n_input = n_covar + n_pbm_covars
2626
FloatType = get_hybridproblem_float_type(prob; scenario)
2727
(;θM) = get_hybridproblem_par_templates(prob; scenario)
2828
n_out = length(θM)
29-
is_using_dropout = :use_dropout scenario
29+
is_using_dropout = :use_dropout scen
3030
g_chain = if is_using_dropout
3131
SimpleChain(
3232
static(n_input), # input dimension (optional)

src/AbstractHybridProblem.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ returns a Tuple of
4040
"""
4141
function get_hybridproblem_MLapplicator end
4242

43-
function get_hybridproblem_MLapplicator(prob::AbstractHybridProblem; scenario = ())
43+
function get_hybridproblem_MLapplicator(
44+
prob::AbstractHybridProblem; scenario::Val{scen} = Val(())) where scen
4445
get_hybridproblem_MLapplicator(Random.default_rng(), prob; scenario)
4546
end
4647

@@ -202,13 +203,13 @@ end
202203
Put relevant parts of the DataLoader to gpu, depending on scenario.
203204
"""
204205
function gdev_hybridproblem_dataloader(dataloader::MLUtils.DataLoader;
205-
scenario = (),
206+
scenario::Val{scen} = Val(()),
206207
gdev = gpu_device(),
207-
gdev_M = :use_gpu scenario ? gdev : identity,
208-
gdev_P = :f_on_gpu scenario ? gdev : identity,
208+
gdev_M = :use_gpu _val_value(scenario) ? gdev : identity,
209+
gdev_P = :f_on_gpu _val_value(scenario) ? gdev : identity,
209210
batchsize = dataloader.batchsize,
210211
partial = dataloader.partial
211-
)
212+
) where scen
212213
xM, xP, y_o, y_unc, i_sites = dataloader.data
213214
xM_dev = gdev_M(xM)
214215
xP_dev, y_o_dev, y_unc_dev = (gdev_P(xP), gdev_P(y_o), gdev_P(y_unc))

src/ComponentArrayInterpreter.jl

Lines changed: 131 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ Returns a ComponentArray with underlying data `v`.
2121
"""
2222
function as_ca end
2323

24-
function Base.length(cai::AbstractComponentArrayInterpreter)
24+
function Base.length(cai::AbstractComponentArrayInterpreter)
2525
prod(_axis_length.(CA.getaxes(cai)))
2626
end
2727

28-
29-
(interpreter::AbstractComponentArrayInterpreter)(v::AbstractArray) = as_ca(v, interpreter)
28+
function (interpreter::AbstractComponentArrayInterpreter)(v::AbstractArray{ET}) where ET
29+
as_ca(v, interpreter)::CA.ComponentArray{ET}
30+
end
3031

3132
"""
3233
Concrete version of `AbstractComponentArrayInterpreter` that stores an axis
@@ -39,11 +40,35 @@ Use `get_concrete(cai::ComponentArrayInterpreter)` to pass a concrete version to
3940
performance-critical functions.
4041
"""
4142
struct StaticComponentArrayInterpreter{AX} <: AbstractComponentArrayInterpreter end
42-
function as_ca(v::AbstractArray, ::StaticComponentArrayInterpreter{AX}) where {AX}
43+
function as_ca(v::AbstractArray, ::StaticComponentArrayInterpreter{AX}) where {AX}
4344
vr = reshape(v, _axis_length.(AX))
44-
CA.ComponentArray(vr, AX)
45+
CA.ComponentArray(vr, AX)::CA.ComponentArray{eltype(v)}
4546
end
4647

48+
function StaticComponentArrayInterpreter(component_shapes::NamedTuple)
49+
axs = map(component_shapes) do valx
50+
x = _val_value(valx)
51+
ax = x isa Integer ? CA.Shaped1DAxis((x,)) : CA.ShapedAxis(x)
52+
(ax,)
53+
end
54+
axc = compose_axes(axs)
55+
StaticComponentArrayInterpreter{(axc,)}()
56+
end
57+
function StaticComponentArrayInterpreter(ca::CA.ComponentArray)
58+
ax = CA.getaxes(ca)
59+
StaticComponentArrayInterpreter{ax}()
60+
end
61+
62+
# concatenate from several other ArrayInterpreters, keep static
63+
# did not manage to get it inferred, better use get_concrete(ComponentArrayInterpreter)
64+
# also does not save allocations
65+
# function StaticComponentArrayInterpreter(; kwargs...)
66+
# ints = values(kwargs)
67+
# axc = compose_axes(ints)
68+
# intc = StaticComponentArrayInterpreter{(axc,)}()
69+
# return(intc)
70+
# end
71+
4772
# function Base.length(::StaticComponentArrayInterpreter{AX}) where {AX}
4873
# #sum(length, typeof(AX).parameters[1])
4974
# prod(_axis_length.(AX))
@@ -55,7 +80,6 @@ end
5580

5681
get_concrete(cai::StaticComponentArrayInterpreter) = cai
5782

58-
5983
"""
6084
Non-Concrete version of `AbstractComponentArrayInterpreter` that avoids storing
6185
additional type parameters.
@@ -66,23 +90,21 @@ not allow compiler-inferred `length` to construct StaticArrays.
6690
Use `get_concrete(cai::ComponentArrayInterpreter)` to pass a concrete version to
6791
performance-critical functions.
6892
"""
69-
struct ComponentArrayInterpreter <: AbstractComponentArrayInterpreter
93+
struct ComponentArrayInterpreter <: AbstractComponentArrayInterpreter
7094
axes::Tuple #{T, <:CA.AbstractAxis}
7195
end
7296

73-
function as_ca(v::AbstractArray, cai::ComponentArrayInterpreter)
74-
vr = reshape(v, _axis_length.(cai.axes))
75-
CA.ComponentArray(vr, cai.axes)
97+
function as_ca(v::AbstractArray, cai::ComponentArrayInterpreter)
98+
vr = reshape(CA.getdata(v), _axis_length.(cai.axes))
99+
CA.ComponentArray(vr, cai.axes)::CA.ComponentArray{eltype(v)}
76100
end
77101

78-
function CA.getaxes(cai::ComponentArrayInterpreter)
102+
function CA.getaxes(cai::ComponentArrayInterpreter)
79103
cai.axes
80104
end
81105

82-
83106
get_concrete(cai::ComponentArrayInterpreter) = StaticComponentArrayInterpreter{cai.axes}()
84107

85-
86108
"""
87109
ComponentArrayInterpreter(; kwargs...)
88110
ComponentArrayInterpreter(::AbstractComponentArray)
@@ -108,71 +130,116 @@ The other constructors allow constructing arrays with additional dimensions.
108130
"""
109131
function ComponentArrayInterpreter(; kwargs...)
110132
ComponentArrayInterpreter(values(kwargs))
111-
end,
133+
end
112134
function ComponentArrayInterpreter(component_shapes::NamedTuple)
113-
component_counts = map(prod, component_shapes)
114-
n = sum(component_counts)
115-
x = 1:n
116-
is_end = cumsum(component_counts)
117-
is_start = (0, is_end[1:(end-1)]...) .+ 1
118-
#g = (x[i_start:i_end] for (i_start, i_end) in zip(is_start, is_end))
119-
g = (reshape(x[i_start:i_end], shape) for (i_start, i_end, shape) in zip(is_start, is_end, component_shapes))
120-
xc = CA.ComponentVector(; zip(propertynames(component_counts), g)...)
121-
ComponentArrayInterpreter(xc)
135+
#component_counts = map(prod, component_shapes)
136+
# avoid constructing a template first, but create axes
137+
# n = sum(component_counts)
138+
# x = 1:n
139+
# is_end = cumsum(component_counts)
140+
# #is_start = (0, is_end[1:(end-1)]...) .+ 1 # problems with Zygote
141+
# is_start = Iterators.flatten((1:1, is_end[1:(end-1)] .+ 1))
142+
# g = (reshape(x[i_start:i_end], shape) for (i_start, i_end, shape) in zip(is_start, is_end, component_shapes))
143+
# xc = CA.ComponentVector(; zip(propertynames(component_counts), g)...)
144+
# #nt = NamedTuple{propertynames(component_counts)}(g)
145+
# ComponentArrayInterpreter(xc)
146+
axs = map(x -> (x isa Integer ? CA.Shaped1DAxis((x,)) : CA.ShapedAxis(x),), component_shapes)
147+
ax = compose_axes(axs)
148+
m1 = ComponentArrayInterpreter((ax,))
122149
end
123150

124151
function ComponentArrayInterpreter(vc::CA.AbstractComponentArray)
125152
ComponentArrayInterpreter(CA.getaxes(vc))
126153
end
127154

128-
129-
130155
# Attach axes to matrices and arrays of ComponentArrays
131156
# with ComponentArrays in the first dimensions (e.g. rownames of a matrix or array)
132157
function ComponentArrayInterpreter(
133-
ca::CA.AbstractComponentArray, n_dims::NTuple{N,<:Integer}) where N
158+
ca::CA.AbstractComponentArray, n_dims::NTuple{N,<:Integer}) where {N}
134159
ComponentArrayInterpreter(CA.getaxes(ca), n_dims)
135160
end
136161
function ComponentArrayInterpreter(
137-
cai::AbstractComponentArrayInterpreter, n_dims::NTuple{N,<:Integer}) where N
162+
cai::AbstractComponentArrayInterpreter, n_dims::NTuple{N,<:Integer}) where {N}
138163
ComponentArrayInterpreter(CA.getaxes(cai), n_dims)
139164
end
140165
function ComponentArrayInterpreter(
141-
axes::NTuple{M, <:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N}
166+
axes::NTuple{M,<:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N}
142167
axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...)
143168
ComponentArrayInterpreter(axes_ext)
144169
end
145170

171+
# support also for other AbstractComponentArrayInterpreter types
172+
# in a type-stable way by providing the Tuple of dimensions as a value type
173+
"""
174+
stack_ca_int(cai::AbstractComponentArrayInterpreter, ::Val{n_dims})
175+
176+
Interpret the first dimension of an Array as a ComponentArray. Provide the Tuple
177+
of following dimensions by a value type, e.g. `Val((n_col, n_z))`.
178+
"""
179+
function stack_ca_int(
180+
cai::IT, ::Val{n_dims}) where {IT<:AbstractComponentArrayInterpreter,n_dims}
181+
@assert n_dims isa NTuple{N,<:Integer} where {N}
182+
IT.name.wrapper(CA.getaxes(cai), n_dims)::IT.name.wrapper
183+
end
184+
function StaticComponentArrayInterpreter(
185+
axes::NTuple{M,<:CA.AbstractAxis}, n_dims::NTuple{N,<:Integer}) where {M,N}
186+
axes_ext = (axes..., map(n_dim -> CA.Axis(i=1:n_dim), n_dims)...)
187+
StaticComponentArrayInterpreter{axes_ext}()
188+
end
189+
146190
# with ComponentArrays in the last dimensions (e.g. columnnames of a matrix)
147191
function ComponentArrayInterpreter(
148-
n_dims::NTuple{N,<:Integer}, ca::CA.AbstractComponentArray) where N
192+
n_dims::NTuple{N,<:Integer}, ca::CA.AbstractComponentArray) where {N}
149193
ComponentArrayInterpreter(n_dims, CA.getaxes(ca))
150194
end
151195
function ComponentArrayInterpreter(
152-
n_dims::NTuple{N,<:Integer}, cai::AbstractComponentArrayInterpreter) where N
196+
n_dims::NTuple{N,<:Integer}, cai::AbstractComponentArrayInterpreter) where {N}
153197
ComponentArrayInterpreter(n_dims, CA.getaxes(cai))
154198
end
155199
function ComponentArrayInterpreter(
156-
n_dims::NTuple{N,<:Integer}, axes::NTuple{M, <:CA.AbstractAxis}) where {N,M}
200+
n_dims::NTuple{N,<:Integer}, axes::NTuple{M,<:CA.AbstractAxis}) where {N,M}
157201
axes_ext = (map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., axes...)
158202
ComponentArrayInterpreter(axes_ext)
159203
end
160204

205+
function stack_ca_int(
206+
::Val{n_dims}, cai::IT) where {IT<:AbstractComponentArrayInterpreter,n_dims}
207+
@assert n_dims isa NTuple{N,<:Integer} where {N}
208+
IT.name.wrapper(n_dims, CA.getaxes(cai))::IT.name.wrapper
209+
end
210+
function StaticComponentArrayInterpreter(
211+
n_dims::NTuple{N,<:Integer}, axes::NTuple{M,<:CA.AbstractAxis}) where {N,M}
212+
axes_ext = (map(n_dim -> CA.Axis(i=1:n_dim), n_dims)..., axes...)
213+
StaticComponentArrayInterpreter{axes_ext}()
214+
end
215+
161216

162217
# ambuiguity with two empty Tuples (edge prob that does not make sense)
163218
# Empty ComponentVector with no other array dimensions -> empty componentVector
164219
function ComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{})
165-
ComponentArrayInterpreter(CA.ComponentVector())
220+
ComponentArrayInterpreter((CA.Axis(),))
221+
end
222+
function StaticComponentArrayInterpreter(n_dims1::Tuple{}, n_dims2::Tuple{})
223+
StaticComponentArrayInterpreter{(CA.Axis(),)}()
166224
end
167225

226+
# concatenate several 1d ComponentArrayInterpreters
227+
function compose_interpreters(; kwargs...)
228+
compose_interpreters(values(kwargs))
229+
end
168230

231+
function compose_interpreters(ints::NamedTuple)
232+
axtuples = map(x -> CA.getaxes(x), ints)
233+
axc = compose_axes(axtuples)
234+
intc = ComponentArrayInterpreter((axc,))
235+
return (intc)
236+
end
169237

170238

171239
# not exported, but required for testing
172240
_get_ComponentArrayInterpreter_axes(::StaticComponentArrayInterpreter{AX}) where {AX} = AX
173241
_get_ComponentArrayInterpreter_axes(cai::ComponentArrayInterpreter) = cai.axes
174242

175-
176243
_axis_length(ax::CA.AbstractAxis) = lastindex(ax) - firstindex(ax) + 1
177244
_axis_length(::CA.FlatAxis) = 0
178245
_axis_length(::CA.UnitRange) = 0
@@ -199,15 +266,43 @@ function flatten1(cv::CA.ComponentVector)
199266
end
200267
end
201268

202-
203269
"""
204270
get_positions(cai::AbstractComponentArrayInterpreter)
205271
206272
Create a NamedTuple of integer indices for each component.
207273
Assumes that interpreter results in a one-dimensional array, i.e. in a ComponentVector.
208274
"""
209275
function get_positions(cai::AbstractComponentArrayInterpreter)
210-
@assert length(CA.getaxes(cai)) == 1
276+
#@assert length(CA.getaxes(cai)) == 1
211277
cv = cai(1:length(cai))
212-
(; (k => cv[k] for k in keys(cv))... )
278+
keys_cv = keys(cv)
279+
# splatting creates Problems with Zygote
280+
#keys_cv isa Tuple ? (; (k => CA.getdata(cv[k]) for k in keys_cv)...) : CA.getdata(cv)
281+
keys_cv isa Tuple ? NamedTuple{keys_cv}(map(k -> CA.getdata(cv[k]), keys_cv)) : CA.getdata(cv)
282+
end
283+
284+
function tmpf(v;
285+
cv,
286+
cai::AbstractComponentArrayInterpreter=get_concrete(ComponentArrayInterpreter(cv)))
287+
cai(v)
288+
end
289+
290+
function tmpf1(v; cai)
291+
caic = get_concrete(cai)
292+
#caic(v)
293+
Test.@inferred tmpf(v, cv=nothing, cai=caic)
294+
end
295+
296+
function tmpf2(v; cai::AbstractComponentArrayInterpreter)
297+
caic = get_concrete(cai)
298+
#caic = cai
299+
cv = Test.@inferred caic(v) # inferred inside tmpf2
300+
#cv = caic(v) # inferred inside tmpf2
301+
vv = tmpf(v; cv=nothing, cai=caic)
302+
#vv = tmpf(v; cv)
303+
#cv.x
304+
#sum(cv) # not inferred on Union cv (axis not know)
305+
#cv.x::AbstractVector{eltype(vv)} # not sufficient
306+
# need to specify concrete return type, but can rely on eltype
307+
sum(vv)::eltype(vv) # need to specify return type
213308
end

0 commit comments

Comments
 (0)