@@ -7,52 +7,119 @@ using NamedDimsArrays:
77 AbstractNamedDimsArrayStyle,
88 dename,
99 inds
10+ using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments
1011
11- struct Prod{A}
12- factors:: Vector{A}
13- end
12+ struct Mul{A}
13+ arguments:: Vector{A}
14+ end
15+ TermInterface. arguments (m:: Mul ) = getfield (m, :arguments )
16+ TermInterface. children (m:: Mul ) = arguments (m)
17+ TermInterface. head (m:: Mul ) = operation (m)
18+ TermInterface. iscall (m:: Mul ) = true
19+ TermInterface. isexpr (m:: Mul ) = iscall (m)
20+ TermInterface. maketerm (:: Type{Mul} , head:: typeof (* ), args, metadata) = Mul (args)
21+ TermInterface. operation (m:: Mul ) = *
22+ TermInterface. sorted_arguments (m:: Mul ) = arguments (m)
23+ TermInterface. sorted_children (m:: Mul ) = sorted_arguments (a)
1424
1525@wrapped struct LazyNamedDimsArray{
1626 T, A <: AbstractNamedDimsArray{T} ,
1727 } <: AbstractNamedDimsArray{T, Any}
18- union:: Union{A, Prod {LazyNamedDimsArray{T, A}}}
28+ union:: Union{A, Mul {LazyNamedDimsArray{T, A}}}
1929end
2030
2131function NamedDimsArrays. inds (a:: LazyNamedDimsArray )
22- if unwrap (a) isa AbstractNamedDimsArray
23- return inds (unwrap (a))
24- elseif unwrap (a) isa Prod
25- return mapreduce (inds, symdiff, unwrap (a). factors)
32+ u = unwrap (a)
33+ if u isa AbstractNamedDimsArray
34+ return inds (u)
35+ elseif u isa Mul
36+ return mapreduce (inds, symdiff, arguments (u))
2637 else
2738 return error (" Variant not supported." )
2839 end
2940end
3041function NamedDimsArrays. dename (a:: LazyNamedDimsArray )
31- if unwrap (a) isa AbstractNamedDimsArray
32- return dename (unwrap (a))
33- elseif unwrap (a) isa Prod
42+ u = unwrap (a)
43+ if u isa AbstractNamedDimsArray
44+ return dename (u)
45+ elseif u isa Mul
3446 return dename (materialize (a), inds (a))
3547 else
3648 return error (" Variant not supported." )
3749 end
3850end
3951
52+ function TermInterface. arguments (a:: LazyNamedDimsArray )
53+ u = unwrap (a)
54+ if u isa AbstractNamedDimsArray
55+ return error (" No arguments." )
56+ elseif u isa Mul
57+ return arguments (u)
58+ else
59+ return error (" Variant not supported." )
60+ end
61+ end
62+ function TermInterface. children (a:: LazyNamedDimsArray )
63+ return arguments (a)
64+ end
65+ function TermInterface. head (a:: LazyNamedDimsArray )
66+ return operation (a)
67+ end
68+ function TermInterface. iscall (a:: LazyNamedDimsArray )
69+ return iscall (unwrap (a))
70+ end
71+ function TermInterface. isexpr (a:: LazyNamedDimsArray )
72+ return iscall (a)
73+ end
74+ function TermInterface. maketerm (:: Type{LazyNamedDimsArray} , head, args, metadata)
75+ if head ≡ *
76+ return LazyNamedDimsArray (maketerm (Mul, head, args, metadata))
77+ else
78+ return error (" Only product terms supported right now." )
79+ end
80+ end
81+ function TermInterface. operation (a:: LazyNamedDimsArray )
82+ u = unwrap (a)
83+ if u isa AbstractNamedDimsArray
84+ return error (" No operation." )
85+ elseif u isa Mul
86+ return operation (u)
87+ else
88+ return error (" Variant not supported." )
89+ end
90+ end
91+ function TermInterface. sorted_arguments (a:: LazyNamedDimsArray )
92+ u = unwrap (a)
93+ if u isa AbstractNamedDimsArray
94+ return error (" No arguments." )
95+ elseif u isa Mul
96+ return sorted_arguments (u)
97+ else
98+ return error (" Variant not supported." )
99+ end
100+ end
101+ function TermInterface. sorted_children (a:: LazyNamedDimsArray )
102+ return sorted_arguments (a)
103+ end
104+
40105using Base. Broadcast: materialize
41106function Base. Broadcast. materialize (a:: LazyNamedDimsArray )
42- if unwrap (a) isa AbstractNamedDimsArray
43- return unwrap (a)
44- elseif unwrap (a) isa Prod
45- return prod (materialize, unwrap (a). factors)
107+ u = unwrap (a)
108+ if u isa AbstractNamedDimsArray
109+ return u
110+ elseif u isa Mul
111+ return mapfoldl (materialize, operation (u), arguments (u))
46112 else
47113 return error (" Variant not supported." )
48114 end
49115end
50116Base. copy (a:: LazyNamedDimsArray ) = materialize (a)
51117
52118function Base.:* (a:: LazyNamedDimsArray )
53- if unwrap (a) isa AbstractNamedDimsArray
54- return LazyNamedDimsArray (Prod ([lazy (unwrap (a))]))
55- elseif unwrap (a) isa Prod
119+ u = unwrap (a)
120+ if u isa AbstractNamedDimsArray
121+ return LazyNamedDimsArray (Mul ([lazy (u)]))
122+ elseif u isa Mul
56123 return a
57124 else
58125 return error (" Variant not supported." )
61128
62129function Base.:* (a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray )
63130 # Nested by default.
64- return LazyNamedDimsArray (Prod ([a1, a2]))
131+ return LazyNamedDimsArray (Mul ([a1, a2]))
65132end
66133function Base.:+ (a1:: LazyNamedDimsArray , a2:: LazyNamedDimsArray )
67134 return error (" Not implemented." )
85152function LazyNamedDimsArray (a:: AbstractNamedDimsArray )
86153 return LazyNamedDimsArray {eltype(a), typeof(a)} (a)
87154end
88- function LazyNamedDimsArray (a:: Prod {LazyNamedDimsArray{T, A}} ) where {T, A}
155+ function LazyNamedDimsArray (a:: Mul {LazyNamedDimsArray{T, A}} ) where {T, A}
89156 return LazyNamedDimsArray {T, A} (a)
90157end
91158function lazy (a:: AbstractNamedDimsArray )
@@ -124,59 +191,4 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a)
124191 return - a
125192end
126193
127- using TermInterface: TermInterface
128- # arguments, arity, children, head, iscall, operation
129- function TermInterface. arguments (a:: LazyNamedDimsArray )
130- if unwrap (a) isa AbstractNamedDimsArray
131- return error (" No arguments." )
132- elseif unwrap (a) isa Prod
133- unwrap (a). factors
134- else
135- return error (" Variant not supported." )
136- end
137- end
138- function TermInterface. children (a:: LazyNamedDimsArray )
139- return TermInterface. arguments (a)
140- end
141- function TermInterface. head (a:: LazyNamedDimsArray )
142- return TermInterface. operation (a)
143- end
144- function TermInterface. iscall (a:: LazyNamedDimsArray )
145- if unwrap (a) isa AbstractNamedDimsArray
146- return false
147- elseif unwrap (a) isa Prod
148- return true
149- else
150- return false
151- end
152- end
153- function TermInterface. isexpr (a:: LazyNamedDimsArray )
154- return TermInterface. iscall (a)
155- end
156- function TermInterface. maketerm (:: Type{LazyNamedDimsArray} , head, args, metadata)
157- if head ≡ prod
158- return LazyNamedDimsArray (Prod (args))
159- else
160- return error (" Only product terms supported right now." )
161- end
162- end
163- function TermInterface. operation (a:: LazyNamedDimsArray )
164- if unwrap (a) isa AbstractNamedDimsArray
165- return error (" No operation." )
166- elseif unwrap (a) isa Prod
167- prod
168- else
169- return error (" Variant not supported." )
170- end
171- end
172- function TermInterface. sorted_arguments (a:: LazyNamedDimsArray )
173- if unwrap (a) isa AbstractNamedDimsArray
174- return error (" No arguments." )
175- elseif unwrap (a) isa Prod
176- return TermInterface. arguments (a)
177- else
178- return error (" Variant not supported." )
179- end
180- end
181-
182194end
0 commit comments