11module LazyNamedDimsArrays
22
3+ using AbstractTrees: AbstractTrees
34using WrappedUnions: @wrapped , unwrap
45using NamedDimsArrays:
56 NamedDimsArrays,
67 AbstractNamedDimsArray,
78 AbstractNamedDimsArrayStyle,
9+ NamedDimsArray,
810 dename,
11+ dimnames,
912 inds
13+ using .. SymbolicArrays: SymbolicArrays, SymbolicArray
1014using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments
1115
16+ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N} , DimNames} =
17+ NamedDimsArray{T, N, Parent, DimNames}
18+ function symnameddims (name)
19+ return lazy (NamedDimsArray (SymbolicArray (name), ()))
20+ end
21+ function printnode (io:: IO , a:: SymbolicNamedDimsArray )
22+ print (io, SymbolicArrays. name (dename (a)))
23+ print (io, " [" , join (dimnames (a), " ," ), " ]" )
24+ return nothing
25+ end
26+ function Base.:(== )(a:: SymbolicNamedDimsArray , b:: SymbolicNamedDimsArray )
27+ return issetequal (inds (a), inds (b)) && dename (a) == dename (b)
28+ end
29+ function Base.:* (a:: SymbolicNamedDimsArray , b:: SymbolicNamedDimsArray )
30+ return lazy (a) * lazy (b)
31+ end
32+ function Base.:* (a:: SymbolicNamedDimsArray , b:: LazyNamedDimsArray )
33+ return lazy (a) * b
34+ end
35+ function Base.:* (a:: LazyNamedDimsArray , b:: SymbolicNamedDimsArray )
36+ return a * lazy (b)
37+ end
38+
39+ # Custom version of `AbstractTrees.printnode` to
40+ # avoid type piracy when overloading on `AbstractNamedDimsArray`.
41+ printnode (io:: IO , x) = AbstractTrees. printnode (io, x)
42+ function printnode (io:: IO , a:: AbstractNamedDimsArray )
43+ show (io, collect (dimnames (a)))
44+ return nothing
45+ end
46+
1247struct Mul{A}
1348 arguments:: Vector{A}
1449end
@@ -21,6 +56,13 @@ TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args)
2156TermInterface. operation (m:: Mul ) = *
2257TermInterface. sorted_arguments (m:: Mul ) = arguments (m)
2358TermInterface. sorted_children (m:: Mul ) = sorted_arguments (a)
59+ ismul (x) = false
60+ ismul (m:: Mul ) = true
61+ function Base. show (io:: IO , m:: Mul )
62+ args = map (arg -> sprint (printnode, arg), arguments (m))
63+ print (io, " (" , join (args, " $(operation (m)) " ), " )" )
64+ return nothing
65+ end
2466
2567@wrapped struct LazyNamedDimsArray{
2668 T, A <: AbstractNamedDimsArray{T} ,
3072
3173function NamedDimsArrays. inds (a:: LazyNamedDimsArray )
3274 u = unwrap (a)
33- if u isa AbstractNamedDimsArray
75+ if ! iscall (u)
3476 return inds (u)
35- elseif u isa Mul
77+ elseif ismul (u)
3678 return mapreduce (inds, symdiff, arguments (u))
3779 else
3880 return error (" Variant not supported." )
3981 end
4082end
4183function NamedDimsArrays. dename (a:: LazyNamedDimsArray )
4284 u = unwrap (a)
43- if u isa AbstractNamedDimsArray
85+ if ! iscall (u)
4486 return dename (u)
45- elseif u isa Mul
46- return dename (materialize (a), inds (a))
4787 else
4888 return error (" Variant not supported." )
4989 end
5090end
5191
5292function TermInterface. arguments (a:: LazyNamedDimsArray )
5393 u = unwrap (a)
54- if u isa AbstractNamedDimsArray
94+ if ! iscall (u)
5595 return error (" No arguments." )
56- elseif u isa Mul
96+ elseif ismul (u)
5797 return arguments (u)
5898 else
5999 return error (" Variant not supported." )
@@ -75,24 +115,24 @@ function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata
75115 if head ≡ *
76116 return LazyNamedDimsArray (maketerm (Mul, head, args, metadata))
77117 else
78- return error (" Only product terms supported right now." )
118+ return error (" Only mul supported right now." )
79119 end
80120end
81121function TermInterface. operation (a:: LazyNamedDimsArray )
82122 u = unwrap (a)
83- if u isa AbstractNamedDimsArray
123+ if ! iscall (u)
84124 return error (" No operation." )
85- elseif u isa Mul
125+ elseif ismul (u)
86126 return operation (u)
87127 else
88128 return error (" Variant not supported." )
89129 end
90130end
91131function TermInterface. sorted_arguments (a:: LazyNamedDimsArray )
92132 u = unwrap (a)
93- if u isa AbstractNamedDimsArray
133+ if ! iscall (u)
94134 return error (" No arguments." )
95- elseif u isa Mul
135+ elseif ismul (u)
96136 return sorted_arguments (u)
97137 else
98138 return error (" Variant not supported." )
@@ -101,25 +141,75 @@ end
101141function TermInterface. sorted_children (a:: LazyNamedDimsArray )
102142 return sorted_arguments (a)
103143end
144+ ismul (a:: LazyNamedDimsArray ) = ismul (unwrap (a))
145+
146+ function AbstractTrees. children (a:: LazyNamedDimsArray )
147+ if ! iscall (a)
148+ return ()
149+ else
150+ return arguments (a)
151+ end
152+ end
153+ function AbstractTrees. nodevalue (a:: LazyNamedDimsArray )
154+ if ! iscall (a)
155+ return unwrap (a)
156+ else
157+ return operation (a)
158+ end
159+ end
104160
105161using Base. Broadcast: materialize
106162function Base. Broadcast. materialize (a:: LazyNamedDimsArray )
107163 u = unwrap (a)
108- if u isa AbstractNamedDimsArray
164+ if ! iscall (u)
109165 return u
110- elseif u isa Mul
166+ elseif ismul (u)
111167 return mapfoldl (materialize, operation (u), arguments (u))
112168 else
113169 return error (" Variant not supported." )
114170 end
115171end
116172Base. copy (a:: LazyNamedDimsArray ) = materialize (a)
117173
174+ function Base.:(== )(a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray )
175+ u1, u2 = unwrap .((a1, a2))
176+ if ! iscall (u1) && ! iscall (u2)
177+ return u1 == u2
178+ elseif ismul (u1) && ismul (u2)
179+ return arguments (u1) == arguments (u2)
180+ else
181+ return false
182+ end
183+ end
184+
185+ function printnode (io:: IO , a:: LazyNamedDimsArray )
186+ return printnode (io, unwrap (a))
187+ end
188+ function AbstractTrees. printnode (io:: IO , a:: LazyNamedDimsArray )
189+ return printnode (io, a)
190+ end
191+ function Base. show (io:: IO , a:: LazyNamedDimsArray )
192+ if ! iscall (a)
193+ return show (io, unwrap (a))
194+ else
195+ return printnode (io, a)
196+ end
197+ end
198+ function Base. show (io:: IO , mime:: MIME"text/plain" , a:: LazyNamedDimsArray )
199+ if ! iscall (a)
200+ @invoke show (io, mime, a:: AbstractNamedDimsArray )
201+ return nothing
202+ else
203+ show (io, a)
204+ return nothing
205+ end
206+ end
207+
118208function Base.:* (a:: LazyNamedDimsArray )
119209 u = unwrap (a)
120- if u isa AbstractNamedDimsArray
210+ if ! iscall (u)
121211 return LazyNamedDimsArray (Mul ([lazy (u)]))
122- elseif u isa Mul
212+ elseif ismul (u)
123213 return a
124214 else
125215 return error (" Variant not supported." )
0 commit comments