11module LazyNamedDimsArrays
22
3+ using AbstractTrees: AbstractTrees
34using WrappedUnions: @wrapped , unwrap
45using NamedDimsArrays:
56 NamedDimsArrays,
67 AbstractNamedDimsArray,
78 AbstractNamedDimsArrayStyle,
9+ NamedDimsArray,
810 dename,
11+ dimnames,
912 inds
1013using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments
1114
15+ # Custom version of `AbstractTrees.printnode` to
16+ # avoid type piracy when overloading on `AbstractNamedDimsArray`.
17+ printnode (io:: IO , x) = AbstractTrees. printnode (io, x)
18+ function printnode (io:: IO , a:: AbstractNamedDimsArray )
19+ show (io, collect (dimnames (a)))
20+ return nothing
21+ end
22+
1223struct Mul{A}
1324 arguments:: Vector{A}
1425end
@@ -21,6 +32,13 @@ TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args)
2132TermInterface. operation (m:: Mul ) = *
2233TermInterface. sorted_arguments (m:: Mul ) = arguments (m)
2334TermInterface. sorted_children (m:: Mul ) = sorted_arguments (a)
35+ ismul (x) = false
36+ ismul (m:: Mul ) = true
37+ function Base. show (io:: IO , m:: Mul )
38+ args = map (arg -> sprint (printnode, arg), arguments (m))
39+ print (io, " (" , join (args, " $(operation (m)) " ), " )" )
40+ return nothing
41+ end
2442
2543@wrapped struct LazyNamedDimsArray{
2644 T, A <: AbstractNamedDimsArray{T} ,
3048
3149function NamedDimsArrays. inds (a:: LazyNamedDimsArray )
3250 u = unwrap (a)
33- if u isa AbstractNamedDimsArray
51+ if ! iscall (u)
3452 return inds (u)
35- elseif u isa Mul
53+ elseif ismul (u)
3654 return mapreduce (inds, symdiff, arguments (u))
3755 else
3856 return error (" Variant not supported." )
3957 end
4058end
4159function NamedDimsArrays. dename (a:: LazyNamedDimsArray )
4260 u = unwrap (a)
43- if u isa AbstractNamedDimsArray
61+ if ! iscall (u)
4462 return dename (u)
45- elseif u isa Mul
46- return dename (materialize (a), inds (a))
4763 else
4864 return error (" Variant not supported." )
4965 end
5066end
5167
5268function TermInterface. arguments (a:: LazyNamedDimsArray )
5369 u = unwrap (a)
54- if u isa AbstractNamedDimsArray
70+ if ! iscall (u)
5571 return error (" No arguments." )
56- elseif u isa Mul
72+ elseif ismul (u)
5773 return arguments (u)
5874 else
5975 return error (" Variant not supported." )
@@ -75,24 +91,24 @@ function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata
7591 if head ≡ *
7692 return LazyNamedDimsArray (maketerm (Mul, head, args, metadata))
7793 else
78- return error (" Only product terms supported right now." )
94+ return error (" Only mul supported right now." )
7995 end
8096end
8197function TermInterface. operation (a:: LazyNamedDimsArray )
8298 u = unwrap (a)
83- if u isa AbstractNamedDimsArray
99+ if ! iscall (u)
84100 return error (" No operation." )
85- elseif u isa Mul
101+ elseif ismul (u)
86102 return operation (u)
87103 else
88104 return error (" Variant not supported." )
89105 end
90106end
91107function TermInterface. sorted_arguments (a:: LazyNamedDimsArray )
92108 u = unwrap (a)
93- if u isa AbstractNamedDimsArray
109+ if ! iscall (u)
94110 return error (" No arguments." )
95- elseif u isa Mul
111+ elseif ismul (u)
96112 return sorted_arguments (u)
97113 else
98114 return error (" Variant not supported." )
@@ -101,25 +117,75 @@ end
101117function TermInterface. sorted_children (a:: LazyNamedDimsArray )
102118 return sorted_arguments (a)
103119end
120+ ismul (a:: LazyNamedDimsArray ) = ismul (unwrap (a))
121+
122+ function AbstractTrees. children (a:: LazyNamedDimsArray )
123+ if ! iscall (a)
124+ return ()
125+ else
126+ return arguments (a)
127+ end
128+ end
129+ function AbstractTrees. nodevalue (a:: LazyNamedDimsArray )
130+ if ! iscall (a)
131+ return unwrap (a)
132+ else
133+ return operation (a)
134+ end
135+ end
104136
105137using Base. Broadcast: materialize
106138function Base. Broadcast. materialize (a:: LazyNamedDimsArray )
107139 u = unwrap (a)
108- if u isa AbstractNamedDimsArray
140+ if ! iscall (u)
109141 return u
110- elseif u isa Mul
142+ elseif ismul (u)
111143 return mapfoldl (materialize, operation (u), arguments (u))
112144 else
113145 return error (" Variant not supported." )
114146 end
115147end
116148Base. copy (a:: LazyNamedDimsArray ) = materialize (a)
117149
150+ function Base.:(== )(a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray )
151+ u1, u2 = unwrap .((a1, a2))
152+ if ! iscall (u1) && ! iscall (u2)
153+ return u1 == u2
154+ elseif ismul (u1) && ismul (u2)
155+ return arguments (u1) == arguments (u2)
156+ else
157+ return false
158+ end
159+ end
160+
161+ function printnode (io:: IO , a:: LazyNamedDimsArray )
162+ return printnode (io, unwrap (a))
163+ end
164+ function AbstractTrees. printnode (io:: IO , a:: LazyNamedDimsArray )
165+ return printnode (io, a)
166+ end
167+ function Base. show (io:: IO , a:: LazyNamedDimsArray )
168+ if ! iscall (a)
169+ return show (io, unwrap (a))
170+ else
171+ return printnode (io, a)
172+ end
173+ end
174+ function Base. show (io:: IO , mime:: MIME"text/plain" , a:: LazyNamedDimsArray )
175+ if ! iscall (a)
176+ @invoke show (io, mime, a:: AbstractNamedDimsArray )
177+ return nothing
178+ else
179+ show (io, a)
180+ return nothing
181+ end
182+ end
183+
118184function Base.:* (a:: LazyNamedDimsArray )
119185 u = unwrap (a)
120- if u isa AbstractNamedDimsArray
186+ if ! iscall (u)
121187 return LazyNamedDimsArray (Mul ([lazy (u)]))
122- elseif u isa Mul
188+ elseif ismul (u)
123189 return a
124190 else
125191 return error (" Variant not supported." )
@@ -191,4 +257,67 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a)
191257 return - a
192258end
193259
260+ struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}} } <: AbstractArray{T, N}
261+ name:: Name
262+ axes:: Axes
263+ function SymbolicArray {T} (name, ax:: Tuple{Vararg{AbstractUnitRange{<:Integer}}} ) where {T}
264+ N = length (ax)
265+ return new {T, N, typeof(name), typeof(ax)} (name, ax)
266+ end
267+ end
268+ function SymbolicArray (name, ax:: Tuple{Vararg{AbstractUnitRange{<:Integer}}} )
269+ return SymbolicArray {Any} (name, ax)
270+ end
271+ function SymbolicArray {T} (name, ax:: AbstractUnitRange... ) where {T}
272+ return SymbolicArray {T} (name, ax)
273+ end
274+ function SymbolicArray (name, ax:: AbstractUnitRange... )
275+ return SymbolicArray {Any} (name, ax)
276+ end
277+ symname (a:: SymbolicArray ) = getfield (a, :name )
278+ Base. axes (a:: SymbolicArray ) = getfield (a, :axes )
279+ Base. size (a:: SymbolicArray ) = length .(axes (a))
280+ function Base.:(== )(a:: SymbolicArray , b:: SymbolicArray )
281+ return symname (a) == symname (b) && axes (a) == axes (b)
282+ end
283+ function Base. show (io:: IO , mime:: MIME"text/plain" , a:: SymbolicArray )
284+ Base. summary (io, a)
285+ println (io, " :" )
286+ print (io, repr (symname (a)))
287+ return nothing
288+ end
289+ function Base. show (io:: IO , a:: SymbolicArray )
290+ print (io, " SymbolicArray(" , symname (a), " , " , size (a), " )" )
291+ return nothing
292+ end
293+ using AbstractTrees: AbstractTrees
294+ function AbstractTrees. printnode (io:: IO , a:: SymbolicArray )
295+ print (io, repr (symname (a)))
296+ return nothing
297+ end
298+ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N} , DimNames} =
299+ NamedDimsArray{T, N, Parent, DimNames}
300+ function symnameddims (name)
301+ return lazy (NamedDimsArray (SymbolicArray (name), ()))
302+ end
303+ function printnode (io:: IO , a:: SymbolicNamedDimsArray )
304+ print (io, symname (dename (a)))
305+ if ndims (a) > 0
306+ print (io, " [" , join (dimnames (a), " ," ), " ]" )
307+ end
308+ return nothing
309+ end
310+ function Base.:(== )(a:: SymbolicNamedDimsArray , b:: SymbolicNamedDimsArray )
311+ return issetequal (inds (a), inds (b)) && dename (a) == dename (b)
312+ end
313+ function Base.:* (a:: SymbolicNamedDimsArray , b:: SymbolicNamedDimsArray )
314+ return lazy (a) * lazy (b)
315+ end
316+ function Base.:* (a:: SymbolicNamedDimsArray , b:: LazyNamedDimsArray )
317+ return lazy (a) * b
318+ end
319+ function Base.:* (a:: LazyNamedDimsArray , b:: SymbolicNamedDimsArray )
320+ return a * lazy (b)
321+ end
322+
194323end
0 commit comments