@@ -44,6 +44,13 @@ function printnode_nameddims(io::IO, a::AbstractNamedDimsArray)
4444end
4545
4646# Generic lazy functionality.
47+ function maketerm_lazy (type:: Type , head, args, metadata)
48+ if head ≡ *
49+ return type (maketerm (Mul, head, args, metadata))
50+ else
51+ return error (" Only mul supported right now." )
52+ end
53+ end
4754function getindex_lazy (a:: AbstractArray , I... )
4855 u = unwrap (a)
4956 if ! iscall (u)
@@ -144,7 +151,7 @@ function map_arguments_lazy(f, a)
144151 if ! iscall (u)
145152 return error (" No arguments to map." )
146153 elseif ismul (u)
147- return unspecify_type_parameters ( typeof (a)) (map_arguments (f, u))
154+ return lazy (map_arguments (f, u))
148155 else
149156 return error (" Variant not supported." )
150157 end
@@ -180,68 +187,105 @@ function show_lazy(io::IO, mime::MIME"text/plain", a)
180187 return nothing
181188 end
182189end
190+ add_lazy (a1, a2) = error (" Not implemented." )
191+ sub_lazy (a) = error (" Not implemented." )
192+ sub_lazy (a1, a2) = error (" Not implemented." )
193+ function mul_lazy (a)
194+ u = unwrap (a)
195+ if ! iscall (u)
196+ return lazy (Mul ([a]))
197+ elseif ismul (u)
198+ return a
199+ else
200+ return error (" Variant not supported." )
201+ end
202+ end
203+ # Note that this is nested by default.
204+ mul_lazy (a1, a2) = lazy (Mul ([a1, a2]))
205+ mul_lazy (c:: Number , a) = error (" Not implemented." )
206+ mul_lazy (a, c:: Number ) = error (" Not implemented." )
207+ div_lazy (a, c:: Number ) = error (" Not implemented." )
208+
209+ # NamedDimsArrays.jl interface.
210+ function inds_lazy (a)
211+ u = unwrap (a)
212+ if ! iscall (u)
213+ return inds (u)
214+ elseif ismul (u)
215+ return mapreduce (inds, symdiff, arguments (u))
216+ else
217+ return error (" Variant not supported." )
218+ end
219+ end
220+ function dename_lazy (a)
221+ u = unwrap (a)
222+ if ! iscall (u)
223+ return dename (u)
224+ else
225+ return error (" Variant not supported." )
226+ end
227+ end
183228
184229# Lazy broadcasting.
185230struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end
186231function Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , f, as... )
187232 return error (" Arbitrary broadcasting not supported for LazyNamedDimsArray." )
188233end
189234# Linear operations.
190- function Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (+ ), a1, a2)
191- return a1 + a2
192- end
193- function Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (- ), a1, a2)
194- return a1 - a2
195- end
196- function Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (* ), c:: Number , a)
197- return c * a
198- end
199- function Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (* ), a, c:: Number )
200- return a * c
201- end
202- # Fix ambiguity error.
203- function Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (* ), a:: Number , b:: Number )
204- return a * b
235+ Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (+ ), a1, a2) = a1 + a2
236+ Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (- ), a1, a2) = a1 - a2
237+ Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (* ), c:: Number , a) = c * a
238+ Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (* ), a, c:: Number ) = a * c
239+ Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (* ), a:: Number , b:: Number ) = a * b
240+ Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (/ ), a, c:: Number ) = a / c
241+ Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (- ), a) = - a
242+
243+ # Generic functionality for Applied types, like `Mul`, `Add`, etc.
244+ ismul (a) = operation (a) ≡ *
245+ head_applied (a) = operation (a)
246+ iscall_applied (a) = true
247+ isexpr_applied (a) = iscall (a)
248+ function show_applied (io:: IO , a)
249+ args = map (arg -> sprint (AbstractTrees. printnode, arg), arguments (a))
250+ print (io, " (" , join (args, " $(operation (a)) " ), " )" )
251+ return nothing
205252end
206- function Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (/ ), a, c:: Number )
207- return a / c
253+ sorted_arguments_applied (a) = arguments (a)
254+ children_applied (a) = arguments (a)
255+ sorted_children_applied (a) = sorted_arguments (a)
256+ function maketerm_applied (type, head, args, metadata)
257+ term = type (args)
258+ @assert head ≡ operation (term)
259+ return term
208260end
209- function Broadcast. broadcasted (:: LazyNamedDimsArrayStyle , :: typeof (- ), a)
210- return - a
261+ map_arguments_applied (f, a) = unspecify_type_parameters (typeof (a))(map (f, arguments (a)))
262+ function hash_applied (a, h:: UInt64 )
263+ h = hash (Symbol (unspecify_type_parameters (typeof (a))), h)
264+ for arg in arguments (a)
265+ h = hash (arg, h)
266+ end
267+ return h
211268end
212269
213- # Generic functionality for Applied types, like `Mul`, `Add`, etc.
214270abstract type Applied end
215- TermInterface. head (m:: Applied ) = operation (m)
216- TermInterface. iscall (m:: Applied ) = true
217- TermInterface. isexpr (m:: Applied ) = iscall (m)
218- function Base. show (io:: IO , m:: Applied )
219- args = map (arg -> sprint (AbstractTrees. printnode, arg), arguments (m))
220- print (io, " (" , join (args, " $(operation (m)) " ), " )" )
221- return nothing
222- end
223- TermInterface. sorted_children (m:: Applied ) = sorted_arguments (m)
271+ TermInterface. head (a:: Applied ) = head_applied (a)
272+ TermInterface. iscall (a:: Applied ) = iscall_applied (a)
273+ TermInterface. isexpr (a:: Applied ) = isexpr_applied (a)
274+ Base. show (io:: IO , a:: Applied ) = show_applied (io, a)
275+ TermInterface. sorted_arguments (a:: Applied ) = sorted_arguments_applied (a)
276+ TermInterface. children (a:: Applied ) = children_applied (a)
277+ TermInterface. sorted_children (a:: Applied ) = sorted_children_applied (a)
278+ function TermInterface. maketerm (type:: Type{<:Applied} , head, args, metadata)
279+ return maketerm_applied (type, head, args, metadata)
280+ end
281+ map_arguments (f, a:: Applied ) = map_arguments_applied (f, a)
282+ Base. hash (a:: Applied , h:: UInt64 ) = hash_applied (a, h)
224283
225284struct Mul{A} <: Applied
226285 arguments:: Vector{A}
227286end
228287TermInterface. arguments (m:: Mul ) = getfield (m, :arguments )
229- TermInterface. children (m:: Mul ) = arguments (m)
230- TermInterface. maketerm (:: Type{Mul} , head:: typeof (* ), args, metadata) = Mul (args)
231288TermInterface. operation (m:: Mul ) = *
232- TermInterface. sorted_arguments (m:: Mul ) = arguments (m)
233- ismul (x) = false
234- ismul (m:: Mul ) = true
235- function Base. hash (m:: Mul , h:: UInt64 )
236- h = hash (:Mul , h)
237- for arg in arguments (m)
238- h = hash (arg, h)
239- end
240- return h
241- end
242- function map_arguments (f, m:: Mul )
243- return Mul (map (f, arguments (m)))
244- end
245289
246290@wrapped struct LazyNamedDimsArray{
247291 T, A <: AbstractNamedDimsArray{T} ,
@@ -254,139 +298,53 @@ end
254298function LazyNamedDimsArray (a:: Mul{LazyNamedDimsArray{T, A}} ) where {T, A}
255299 return LazyNamedDimsArray {T, A} (a)
256300end
257- function lazy (a:: AbstractNamedDimsArray )
258- return LazyNamedDimsArray (a)
259- end
260-
261- function NamedDimsArrays. inds (a:: LazyNamedDimsArray )
262- u = unwrap (a)
263- if ! iscall (u)
264- return inds (u)
265- elseif ismul (u)
266- return mapreduce (inds, symdiff, arguments (u))
267- else
268- return error (" Variant not supported." )
269- end
270- end
271- function NamedDimsArrays. dename (a:: LazyNamedDimsArray )
272- u = unwrap (a)
273- if ! iscall (u)
274- return dename (u)
275- else
276- return error (" Variant not supported." )
277- end
278- end
279- function TermInterface. maketerm (:: Type{LazyNamedDimsArray} , head, args, metadata)
280- if head ≡ *
281- return LazyNamedDimsArray (maketerm (Mul, head, args, metadata))
282- else
283- return error (" Only mul supported right now." )
284- end
285- end
286-
287- # Derived functionality.
288- function Base. getindex (a:: LazyNamedDimsArray , I:: Int... )
289- return getindex_lazy (a, I... )
290- end
291- function TermInterface. arguments (a:: LazyNamedDimsArray )
292- return arguments_lazy (a)
293- end
294- function TermInterface. children (a:: LazyNamedDimsArray )
295- return children_lazy (a)
296- end
297- function TermInterface. head (a:: LazyNamedDimsArray )
298- return head_lazy (a)
299- end
300- function TermInterface. iscall (a:: LazyNamedDimsArray )
301- return iscall_lazy (a)
302- end
303- function TermInterface. isexpr (a:: LazyNamedDimsArray )
304- return isexpr_lazy (a)
305- end
306- function TermInterface. operation (a:: LazyNamedDimsArray )
307- return operation_lazy (a)
308- end
309- function TermInterface. sorted_arguments (a:: LazyNamedDimsArray )
310- return sorted_arguments_lazy (a)
311- end
312- function AbstractTrees. children (a:: LazyNamedDimsArray )
313- return abstracttrees_children_lazy (a)
314- end
315- function TermInterface. sorted_children (a:: LazyNamedDimsArray )
316- return sorted_children_lazy (a)
317- end
318- ismul (a:: LazyNamedDimsArray ) = ismul_lazy (a)
319- function AbstractTrees. nodevalue (a:: LazyNamedDimsArray )
320- return nodevalue_lazy (a)
321- end
322- function Base. Broadcast. materialize (a:: LazyNamedDimsArray )
323- return materialize_lazy (a)
324- end
325- Base. copy (a:: LazyNamedDimsArray ) = copy_lazy (a)
326- function Base.:(== )(a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray )
327- return equals_lazy (a1, a2)
328- end
329- function Base. hash (a:: LazyNamedDimsArray , h:: UInt64 )
330- return hash_lazy (a, h)
331- end
332- function map_arguments (f, a:: LazyNamedDimsArray )
333- return map_arguments_lazy (f, a)
334- end
335- function substitute (a:: LazyNamedDimsArray , substitutions)
336- return substitute_lazy (a, substitutions)
337- end
338- function AbstractTrees. printnode (io:: IO , a:: LazyNamedDimsArray )
339- return printnode_lazy (io, a)
340- end
341- function printnode_nameddims (io:: IO , a:: LazyNamedDimsArray )
342- return printnode_lazy (io, a)
343- end
344- function Base. show (io:: IO , a:: LazyNamedDimsArray )
345- return show_lazy (io, a)
346- end
347- function Base. show (io:: IO , mime:: MIME"text/plain" , a:: LazyNamedDimsArray )
348- return show_lazy (io, mime, a)
349- end
350-
351- function Base.:* (a:: LazyNamedDimsArray )
352- u = unwrap (a)
353- if ! iscall (u)
354- return LazyNamedDimsArray (Mul ([lazy (u)]))
355- elseif ismul (u)
356- return a
357- else
358- return error (" Variant not supported." )
359- end
360- end
301+ lazy (a:: LazyNamedDimsArray ) = a
302+ lazy (a:: AbstractNamedDimsArray ) = LazyNamedDimsArray (a)
303+ lazy (a:: Mul{<:LazyNamedDimsArray} ) = LazyNamedDimsArray (a)
361304
362- function Base.:* (a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray )
363- # Nested by default.
364- return LazyNamedDimsArray (Mul ([a1, a2]))
365- end
366- function Base.:+ (a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray )
367- return error (" Not implemented." )
368- end
369- function Base.:- (a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray )
370- return error (" Not implemented." )
371- end
372- function Base.:* (c:: Number , a:: LazyNamedDimsArray )
373- return error (" Not implemented." )
374- end
375- function Base.:* (a:: LazyNamedDimsArray , c:: Number )
376- return error (" Not implemented." )
377- end
378- function Base.:/ (a:: LazyNamedDimsArray , c:: Number )
379- return error (" Not implemented." )
380- end
381- function Base.:- (a:: LazyNamedDimsArray )
382- return error (" Not implemented." )
383- end
305+ NamedDimsArrays. inds (a:: LazyNamedDimsArray ) = inds_lazy (a)
306+ NamedDimsArrays. dename (a:: LazyNamedDimsArray ) = dename_lazy (a)
384307
385308# Broadcasting
386309function Base. BroadcastStyle (:: Type{<:LazyNamedDimsArray} )
387310 return LazyNamedDimsArrayStyle ()
388311end
389312
313+ # Derived functionality.
314+ function TermInterface. maketerm (type:: Type{LazyNamedDimsArray} , head, args, metadata)
315+ return maketerm_lazy (type, head, args, metadata)
316+ end
317+ Base. getindex (a:: LazyNamedDimsArray , I:: Int... ) = getindex_lazy (a, I... )
318+ TermInterface. arguments (a:: LazyNamedDimsArray ) = arguments_lazy (a)
319+ TermInterface. children (a:: LazyNamedDimsArray ) = children_lazy (a)
320+ TermInterface. head (a:: LazyNamedDimsArray ) = head_lazy (a)
321+ TermInterface. iscall (a:: LazyNamedDimsArray ) = iscall_lazy (a)
322+ TermInterface. isexpr (a:: LazyNamedDimsArray ) = isexpr_lazy (a)
323+ TermInterface. operation (a:: LazyNamedDimsArray ) = operation_lazy (a)
324+ TermInterface. sorted_arguments (a:: LazyNamedDimsArray ) = sorted_arguments_lazy (a)
325+ AbstractTrees. children (a:: LazyNamedDimsArray ) = abstracttrees_children_lazy (a)
326+ TermInterface. sorted_children (a:: LazyNamedDimsArray ) = sorted_children_lazy (a)
327+ ismul (a:: LazyNamedDimsArray ) = ismul_lazy (a)
328+ AbstractTrees. nodevalue (a:: LazyNamedDimsArray ) = nodevalue_lazy (a)
329+ Base. Broadcast. materialize (a:: LazyNamedDimsArray ) = materialize_lazy (a)
330+ Base. copy (a:: LazyNamedDimsArray ) = copy_lazy (a)
331+ Base.:(== )(a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray ) = equals_lazy (a1, a2)
332+ Base. hash (a:: LazyNamedDimsArray , h:: UInt64 ) = hash_lazy (a, h)
333+ map_arguments (f, a:: LazyNamedDimsArray ) = map_arguments_lazy (f, a)
334+ substitute (a:: LazyNamedDimsArray , substitutions) = substitute_lazy (a, substitutions)
335+ AbstractTrees. printnode (io:: IO , a:: LazyNamedDimsArray ) = printnode_lazy (io, a)
336+ printnode_nameddims (io:: IO , a:: LazyNamedDimsArray ) = printnode_lazy (io, a)
337+ Base. show (io:: IO , a:: LazyNamedDimsArray ) = show_lazy (io, a)
338+ Base. show (io:: IO , mime:: MIME"text/plain" , a:: LazyNamedDimsArray ) = show_lazy (io, mime, a)
339+ Base.:* (a:: LazyNamedDimsArray ) = mul_lazy (a)
340+ Base.:* (a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray ) = mul_lazy (a1, a2)
341+ Base.:+ (a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray ) = add_lazy (a1, a2)
342+ Base.:- (a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray ) = sub_lazy (a1, a2)
343+ Base.:* (c:: Number , a:: LazyNamedDimsArray ) = mul_lazy (c, a)
344+ Base.:* (a:: LazyNamedDimsArray , c:: Number ) = mul_lazy (a, c)
345+ Base.:/ (a:: LazyNamedDimsArray , c:: Number ) = div_lazy (a, c)
346+ Base.:- (a:: LazyNamedDimsArray ) = sub_lazy (a)
347+
390348struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}} } <: AbstractArray{T, N}
391349 name:: Name
392350 axes:: Axes
@@ -448,20 +406,12 @@ function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray)
448406 end
449407 return nothing
450408end
451- function printnode_nameddims (io:: IO , a:: SymbolicNamedDimsArray )
452- return AbstractTrees. printnode (io, a)
453- end
409+ printnode_nameddims (io:: IO , a:: SymbolicNamedDimsArray ) = AbstractTrees. printnode (io, a)
454410function Base.:(== )(a:: SymbolicNamedDimsArray , b:: SymbolicNamedDimsArray )
455411 return issetequal (inds (a), inds (b)) && dename (a) == dename (b)
456412end
457- function Base.:* (a:: SymbolicNamedDimsArray , b:: SymbolicNamedDimsArray )
458- return lazy (a) * lazy (b)
459- end
460- function Base.:* (a:: SymbolicNamedDimsArray , b:: LazyNamedDimsArray )
461- return lazy (a) * b
462- end
463- function Base.:* (a:: LazyNamedDimsArray , b:: SymbolicNamedDimsArray )
464- return a * lazy (b)
465- end
413+ Base.:* (a:: SymbolicNamedDimsArray , b:: SymbolicNamedDimsArray ) = lazy (a) * lazy (b)
414+ Base.:* (a:: SymbolicNamedDimsArray , b:: LazyNamedDimsArray ) = lazy (a) * b
415+ Base.:* (a:: LazyNamedDimsArray , b:: SymbolicNamedDimsArray ) = a * lazy (b)
466416
467417end
0 commit comments