@@ -21,12 +21,13 @@ Returns a ComponentArray with underlying data `v`.
2121"""
2222function as_ca end
2323
24- function Base. length (cai:: AbstractComponentArrayInterpreter )
24+ function Base. length (cai:: AbstractComponentArrayInterpreter )
2525 prod (_axis_length .(CA. getaxes (cai)))
2626end
2727
28-
29- (interpreter:: AbstractComponentArrayInterpreter )(v:: AbstractArray ) = as_ca (v, interpreter)
28+ function (interpreter:: AbstractComponentArrayInterpreter )(v:: AbstractArray{ET} ) where ET
29+ as_ca (v, interpreter):: CA.ComponentArray{ET}
30+ end
3031
3132"""
3233Concrete version of `AbstractComponentArrayInterpreter` that stores an axis
@@ -39,11 +40,35 @@ Use `get_concrete(cai::ComponentArrayInterpreter)` to pass a concrete version to
3940performance-critical functions.
4041"""
4142struct StaticComponentArrayInterpreter{AX} <: AbstractComponentArrayInterpreter end
42- function as_ca (v:: AbstractArray , :: StaticComponentArrayInterpreter{AX} ) where {AX}
43+ function as_ca (v:: AbstractArray , :: StaticComponentArrayInterpreter{AX} ) where {AX}
4344 vr = reshape (v, _axis_length .(AX))
44- CA. ComponentArray (vr, AX)
45+ CA. ComponentArray (vr, AX):: CA.ComponentArray{eltype(v)}
4546end
4647
48+ function StaticComponentArrayInterpreter (component_shapes:: NamedTuple )
49+ axs = map (component_shapes) do valx
50+ x = _val_value (valx)
51+ ax = x isa Integer ? CA. Shaped1DAxis ((x,)) : CA. ShapedAxis (x)
52+ (ax,)
53+ end
54+ axc = compose_axes (axs)
55+ StaticComponentArrayInterpreter {(axc,)} ()
56+ end
57+ function StaticComponentArrayInterpreter (ca:: CA.ComponentArray )
58+ ax = CA. getaxes (ca)
59+ StaticComponentArrayInterpreter {ax} ()
60+ end
61+
62+ # concatenate from several other ArrayInterpreters, keep static
63+ # did not manage to get it inferred, better use get_concrete(ComponentArrayInterpreter)
64+ # also does not save allocations
65+ # function StaticComponentArrayInterpreter(; kwargs...)
66+ # ints = values(kwargs)
67+ # axc = compose_axes(ints)
68+ # intc = StaticComponentArrayInterpreter{(axc,)}()
69+ # return(intc)
70+ # end
71+
4772# function Base.length(::StaticComponentArrayInterpreter{AX}) where {AX}
4873# #sum(length, typeof(AX).parameters[1])
4974# prod(_axis_length.(AX))
5580
5681get_concrete (cai:: StaticComponentArrayInterpreter ) = cai
5782
58-
5983"""
6084Non-Concrete version of `AbstractComponentArrayInterpreter` that avoids storing
6185additional type parameters.
@@ -66,23 +90,21 @@ not allow compiler-inferred `length` to construct StaticArrays.
6690Use `get_concrete(cai::ComponentArrayInterpreter)` to pass a concrete version to
6791performance-critical functions.
6892"""
69- struct ComponentArrayInterpreter <: AbstractComponentArrayInterpreter
93+ struct ComponentArrayInterpreter <: AbstractComponentArrayInterpreter
7094 axes:: Tuple # {T, <:CA.AbstractAxis}
7195end
7296
73- function as_ca (v:: AbstractArray , cai:: ComponentArrayInterpreter )
74- vr = reshape (v , _axis_length .(cai. axes))
75- CA. ComponentArray (vr, cai. axes)
97+ function as_ca (v:: AbstractArray , cai:: ComponentArrayInterpreter )
98+ vr = reshape (CA . getdata (v) , _axis_length .(cai. axes))
99+ CA. ComponentArray (vr, cai. axes):: CA.ComponentArray{eltype(v)}
76100end
77101
78- function CA. getaxes (cai:: ComponentArrayInterpreter )
102+ function CA. getaxes (cai:: ComponentArrayInterpreter )
79103 cai. axes
80104end
81105
82-
83106get_concrete (cai:: ComponentArrayInterpreter ) = StaticComponentArrayInterpreter {cai.axes} ()
84107
85-
86108"""
87109 ComponentArrayInterpreter(; kwargs...)
88110 ComponentArrayInterpreter(::AbstractComponentArray)
@@ -108,71 +130,116 @@ The other constructors allow constructing arrays with additional dimensions.
108130"""
109131function ComponentArrayInterpreter (; kwargs... )
110132 ComponentArrayInterpreter (values (kwargs))
111- end ,
133+ end
112134function ComponentArrayInterpreter (component_shapes:: NamedTuple )
113- component_counts = map (prod, component_shapes)
114- n = sum (component_counts)
115- x = 1 : n
116- is_end = cumsum (component_counts)
117- is_start = (0 , is_end[1 : (end - 1 )]. .. ) .+ 1
118- # g = (x[i_start:i_end] for (i_start, i_end) in zip(is_start, is_end))
119- g = (reshape (x[i_start: i_end], shape) for (i_start, i_end, shape) in zip (is_start, is_end, component_shapes))
120- xc = CA. ComponentVector (; zip (propertynames (component_counts), g)... )
121- ComponentArrayInterpreter (xc)
135+ # component_counts = map(prod, component_shapes)
136+ # avoid constructing a template first, but create axes
137+ # n = sum(component_counts)
138+ # x = 1:n
139+ # is_end = cumsum(component_counts)
140+ # #is_start = (0, is_end[1:(end-1)]...) .+ 1 # problems with Zygote
141+ # is_start = Iterators.flatten((1:1, is_end[1:(end-1)] .+ 1))
142+ # g = (reshape(x[i_start:i_end], shape) for (i_start, i_end, shape) in zip(is_start, is_end, component_shapes))
143+ # xc = CA.ComponentVector(; zip(propertynames(component_counts), g)...)
144+ # #nt = NamedTuple{propertynames(component_counts)}(g)
145+ # ComponentArrayInterpreter(xc)
146+ axs = map (x -> (x isa Integer ? CA. Shaped1DAxis ((x,)) : CA. ShapedAxis (x),), component_shapes)
147+ ax = compose_axes (axs)
148+ m1 = ComponentArrayInterpreter ((ax,))
122149end
123150
124151function ComponentArrayInterpreter (vc:: CA.AbstractComponentArray )
125152 ComponentArrayInterpreter (CA. getaxes (vc))
126153end
127154
128-
129-
130155# Attach axes to matrices and arrays of ComponentArrays
131156# with ComponentArrays in the first dimensions (e.g. rownames of a matrix or array)
132157function ComponentArrayInterpreter (
133- ca:: CA.AbstractComponentArray , n_dims:: NTuple{N,<:Integer} ) where N
158+ ca:: CA.AbstractComponentArray , n_dims:: NTuple{N,<:Integer} ) where {N}
134159 ComponentArrayInterpreter (CA. getaxes (ca), n_dims)
135160end
136161function ComponentArrayInterpreter (
137- cai:: AbstractComponentArrayInterpreter , n_dims:: NTuple{N,<:Integer} ) where N
162+ cai:: AbstractComponentArrayInterpreter , n_dims:: NTuple{N,<:Integer} ) where {N}
138163 ComponentArrayInterpreter (CA. getaxes (cai), n_dims)
139164end
140165function ComponentArrayInterpreter (
141- axes:: NTuple{M, <:CA.AbstractAxis} , n_dims:: NTuple{N,<:Integer} ) where {M,N}
166+ axes:: NTuple{M,<:CA.AbstractAxis} , n_dims:: NTuple{N,<:Integer} ) where {M,N}
142167 axes_ext = (axes... , map (n_dim -> CA. Axis (i= 1 : n_dim), n_dims)... )
143168 ComponentArrayInterpreter (axes_ext)
144169end
145170
171+ # support also for other AbstractComponentArrayInterpreter types
172+ # in a type-stable way by providing the Tuple of dimensions as a value type
173+ """
174+ stack_ca_int(cai::AbstractComponentArrayInterpreter, ::Val{n_dims})
175+
176+ Interpret the first dimension of an Array as a ComponentArray. Provide the Tuple
177+ of following dimensions by a value type, e.g. `Val((n_col, n_z))`.
178+ """
179+ function stack_ca_int (
180+ cai:: IT , :: Val{n_dims} ) where {IT<: AbstractComponentArrayInterpreter ,n_dims}
181+ @assert n_dims isa NTuple{N,<: Integer } where {N}
182+ IT. name. wrapper (CA. getaxes (cai), n_dims):: IT.name.wrapper
183+ end
184+ function StaticComponentArrayInterpreter (
185+ axes:: NTuple{M,<:CA.AbstractAxis} , n_dims:: NTuple{N,<:Integer} ) where {M,N}
186+ axes_ext = (axes... , map (n_dim -> CA. Axis (i= 1 : n_dim), n_dims)... )
187+ StaticComponentArrayInterpreter {axes_ext} ()
188+ end
189+
146190# with ComponentArrays in the last dimensions (e.g. columnnames of a matrix)
147191function ComponentArrayInterpreter (
148- n_dims:: NTuple{N,<:Integer} , ca:: CA.AbstractComponentArray ) where N
192+ n_dims:: NTuple{N,<:Integer} , ca:: CA.AbstractComponentArray ) where {N}
149193 ComponentArrayInterpreter (n_dims, CA. getaxes (ca))
150194end
151195function ComponentArrayInterpreter (
152- n_dims:: NTuple{N,<:Integer} , cai:: AbstractComponentArrayInterpreter ) where N
196+ n_dims:: NTuple{N,<:Integer} , cai:: AbstractComponentArrayInterpreter ) where {N}
153197 ComponentArrayInterpreter (n_dims, CA. getaxes (cai))
154198end
155199function ComponentArrayInterpreter (
156- n_dims:: NTuple{N,<:Integer} , axes:: NTuple{M, <:CA.AbstractAxis} ) where {N,M}
200+ n_dims:: NTuple{N,<:Integer} , axes:: NTuple{M,<:CA.AbstractAxis} ) where {N,M}
157201 axes_ext = (map (n_dim -> CA. Axis (i= 1 : n_dim), n_dims)... , axes... )
158202 ComponentArrayInterpreter (axes_ext)
159203end
160204
205+ function stack_ca_int (
206+ :: Val{n_dims} , cai:: IT ) where {IT<: AbstractComponentArrayInterpreter ,n_dims}
207+ @assert n_dims isa NTuple{N,<: Integer } where {N}
208+ IT. name. wrapper (n_dims, CA. getaxes (cai)):: IT.name.wrapper
209+ end
210+ function StaticComponentArrayInterpreter (
211+ n_dims:: NTuple{N,<:Integer} , axes:: NTuple{M,<:CA.AbstractAxis} ) where {N,M}
212+ axes_ext = (map (n_dim -> CA. Axis (i= 1 : n_dim), n_dims)... , axes... )
213+ StaticComponentArrayInterpreter {axes_ext} ()
214+ end
215+
161216
162217# ambuiguity with two empty Tuples (edge prob that does not make sense)
163218# Empty ComponentVector with no other array dimensions -> empty componentVector
164219function ComponentArrayInterpreter (n_dims1:: Tuple{} , n_dims2:: Tuple{} )
165- ComponentArrayInterpreter (CA. ComponentVector ())
220+ ComponentArrayInterpreter ((CA. Axis (),))
221+ end
222+ function StaticComponentArrayInterpreter (n_dims1:: Tuple{} , n_dims2:: Tuple{} )
223+ StaticComponentArrayInterpreter {(CA.Axis(),)} ()
166224end
167225
226+ # concatenate several 1d ComponentArrayInterpreters
227+ function compose_interpreters (; kwargs... )
228+ compose_interpreters (values (kwargs))
229+ end
168230
231+ function compose_interpreters (ints:: NamedTuple )
232+ axtuples = map (x -> CA. getaxes (x), ints)
233+ axc = compose_axes (axtuples)
234+ intc = ComponentArrayInterpreter ((axc,))
235+ return (intc)
236+ end
169237
170238
171239# not exported, but required for testing
172240_get_ComponentArrayInterpreter_axes (:: StaticComponentArrayInterpreter{AX} ) where {AX} = AX
173241_get_ComponentArrayInterpreter_axes (cai:: ComponentArrayInterpreter ) = cai. axes
174242
175-
176243_axis_length (ax:: CA.AbstractAxis ) = lastindex (ax) - firstindex (ax) + 1
177244_axis_length (:: CA.FlatAxis ) = 0
178245_axis_length (:: CA.UnitRange ) = 0
@@ -199,15 +266,43 @@ function flatten1(cv::CA.ComponentVector)
199266 end
200267end
201268
202-
203269"""
204270 get_positions(cai::AbstractComponentArrayInterpreter)
205271
206272Create a NamedTuple of integer indices for each component.
207273Assumes that interpreter results in a one-dimensional array, i.e. in a ComponentVector.
208274"""
209275function get_positions (cai:: AbstractComponentArrayInterpreter )
210- @assert length (CA. getaxes (cai)) == 1
276+ # @assert length(CA.getaxes(cai)) == 1
211277 cv = cai (1 : length (cai))
212- (; (k => cv[k] for k in keys (cv)). .. )
278+ keys_cv = keys (cv)
279+ # splatting creates Problems with Zygote
280+ # keys_cv isa Tuple ? (; (k => CA.getdata(cv[k]) for k in keys_cv)...) : CA.getdata(cv)
281+ keys_cv isa Tuple ? NamedTuple {keys_cv} (map (k -> CA. getdata (cv[k]), keys_cv)) : CA. getdata (cv)
282+ end
283+
284+ function tmpf (v;
285+ cv,
286+ cai:: AbstractComponentArrayInterpreter = get_concrete (ComponentArrayInterpreter (cv)))
287+ cai (v)
288+ end
289+
290+ function tmpf1 (v; cai)
291+ caic = get_concrete (cai)
292+ # caic(v)
293+ Test. @inferred tmpf (v, cv= nothing , cai= caic)
294+ end
295+
296+ function tmpf2 (v; cai:: AbstractComponentArrayInterpreter )
297+ caic = get_concrete (cai)
298+ # caic = cai
299+ cv = Test. @inferred caic (v) # inferred inside tmpf2
300+ # cv = caic(v) # inferred inside tmpf2
301+ vv = tmpf (v; cv= nothing , cai= caic)
302+ # vv = tmpf(v; cv)
303+ # cv.x
304+ # sum(cv) # not inferred on Union cv (axis not know)
305+ # cv.x::AbstractVector{eltype(vv)} # not sufficient
306+ # need to specify concrete return type, but can rely on eltype
307+ sum (vv):: eltype (vv) # need to specify return type
213308end
0 commit comments