Skip to content

Commit 6ddfdf1

Browse files
committed
More reorg
1 parent 96a9ebf commit 6ddfdf1

File tree

1 file changed

+133
-183
lines changed

1 file changed

+133
-183
lines changed

src/lazynameddimsarrays.jl

Lines changed: 133 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ function printnode_nameddims(io::IO, a::AbstractNamedDimsArray)
4444
end
4545

4646
# Generic lazy functionality.
47+
function maketerm_lazy(type::Type, head, args, metadata)
48+
if head *
49+
return type(maketerm(Mul, head, args, metadata))
50+
else
51+
return error("Only mul supported right now.")
52+
end
53+
end
4754
function getindex_lazy(a::AbstractArray, I...)
4855
u = unwrap(a)
4956
if !iscall(u)
@@ -144,7 +151,7 @@ function map_arguments_lazy(f, a)
144151
if !iscall(u)
145152
return error("No arguments to map.")
146153
elseif ismul(u)
147-
return unspecify_type_parameters(typeof(a))(map_arguments(f, u))
154+
return lazy(map_arguments(f, u))
148155
else
149156
return error("Variant not supported.")
150157
end
@@ -180,68 +187,105 @@ function show_lazy(io::IO, mime::MIME"text/plain", a)
180187
return nothing
181188
end
182189
end
190+
add_lazy(a1, a2) = error("Not implemented.")
191+
sub_lazy(a) = error("Not implemented.")
192+
sub_lazy(a1, a2) = error("Not implemented.")
193+
function mul_lazy(a)
194+
u = unwrap(a)
195+
if !iscall(u)
196+
return lazy(Mul([a]))
197+
elseif ismul(u)
198+
return a
199+
else
200+
return error("Variant not supported.")
201+
end
202+
end
203+
# Note that this is nested by default.
204+
mul_lazy(a1, a2) = lazy(Mul([a1, a2]))
205+
mul_lazy(c::Number, a) = error("Not implemented.")
206+
mul_lazy(a, c::Number) = error("Not implemented.")
207+
div_lazy(a, c::Number) = error("Not implemented.")
208+
209+
# NamedDimsArrays.jl interface.
210+
function inds_lazy(a)
211+
u = unwrap(a)
212+
if !iscall(u)
213+
return inds(u)
214+
elseif ismul(u)
215+
return mapreduce(inds, symdiff, arguments(u))
216+
else
217+
return error("Variant not supported.")
218+
end
219+
end
220+
function dename_lazy(a)
221+
u = unwrap(a)
222+
if !iscall(u)
223+
return dename(u)
224+
else
225+
return error("Variant not supported.")
226+
end
227+
end
183228

184229
# Lazy broadcasting.
185230
struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end
186231
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...)
187232
return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.")
188233
end
189234
# Linear operations.
190-
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2)
191-
return a1 + a2
192-
end
193-
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2)
194-
return a1 - a2
195-
end
196-
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a)
197-
return c * a
198-
end
199-
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number)
200-
return a * c
201-
end
202-
# Fix ambiguity error.
203-
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number)
204-
return a * b
235+
Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2) = a1 + a2
236+
Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2) = a1 - a2
237+
Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a) = c * a
238+
Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number) = a * c
239+
Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number) = a * b
240+
Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number) = a / c
241+
Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) = -a
242+
243+
# Generic functionality for Applied types, like `Mul`, `Add`, etc.
244+
ismul(a) = operation(a) *
245+
head_applied(a) = operation(a)
246+
iscall_applied(a) = true
247+
isexpr_applied(a) = iscall(a)
248+
function show_applied(io::IO, a)
249+
args = map(arg -> sprint(AbstractTrees.printnode, arg), arguments(a))
250+
print(io, "(", join(args, " $(operation(a)) "), ")")
251+
return nothing
205252
end
206-
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number)
207-
return a / c
253+
sorted_arguments_applied(a) = arguments(a)
254+
children_applied(a) = arguments(a)
255+
sorted_children_applied(a) = sorted_arguments(a)
256+
function maketerm_applied(type, head, args, metadata)
257+
term = type(args)
258+
@assert head operation(term)
259+
return term
208260
end
209-
function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a)
210-
return -a
261+
map_arguments_applied(f, a) = unspecify_type_parameters(typeof(a))(map(f, arguments(a)))
262+
function hash_applied(a, h::UInt64)
263+
h = hash(Symbol(unspecify_type_parameters(typeof(a))), h)
264+
for arg in arguments(a)
265+
h = hash(arg, h)
266+
end
267+
return h
211268
end
212269

213-
# Generic functionality for Applied types, like `Mul`, `Add`, etc.
214270
abstract type Applied end
215-
TermInterface.head(m::Applied) = operation(m)
216-
TermInterface.iscall(m::Applied) = true
217-
TermInterface.isexpr(m::Applied) = iscall(m)
218-
function Base.show(io::IO, m::Applied)
219-
args = map(arg -> sprint(AbstractTrees.printnode, arg), arguments(m))
220-
print(io, "(", join(args, " $(operation(m)) "), ")")
221-
return nothing
222-
end
223-
TermInterface.sorted_children(m::Applied) = sorted_arguments(m)
271+
TermInterface.head(a::Applied) = head_applied(a)
272+
TermInterface.iscall(a::Applied) = iscall_applied(a)
273+
TermInterface.isexpr(a::Applied) = isexpr_applied(a)
274+
Base.show(io::IO, a::Applied) = show_applied(io, a)
275+
TermInterface.sorted_arguments(a::Applied) = sorted_arguments_applied(a)
276+
TermInterface.children(a::Applied) = children_applied(a)
277+
TermInterface.sorted_children(a::Applied) = sorted_children_applied(a)
278+
function TermInterface.maketerm(type::Type{<:Applied}, head, args, metadata)
279+
return maketerm_applied(type, head, args, metadata)
280+
end
281+
map_arguments(f, a::Applied) = map_arguments_applied(f, a)
282+
Base.hash(a::Applied, h::UInt64) = hash_applied(a, h)
224283

225284
struct Mul{A} <: Applied
226285
arguments::Vector{A}
227286
end
228287
TermInterface.arguments(m::Mul) = getfield(m, :arguments)
229-
TermInterface.children(m::Mul) = arguments(m)
230-
TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args)
231288
TermInterface.operation(m::Mul) = *
232-
TermInterface.sorted_arguments(m::Mul) = arguments(m)
233-
ismul(x) = false
234-
ismul(m::Mul) = true
235-
function Base.hash(m::Mul, h::UInt64)
236-
h = hash(:Mul, h)
237-
for arg in arguments(m)
238-
h = hash(arg, h)
239-
end
240-
return h
241-
end
242-
function map_arguments(f, m::Mul)
243-
return Mul(map(f, arguments(m)))
244-
end
245289

246290
@wrapped struct LazyNamedDimsArray{
247291
T, A <: AbstractNamedDimsArray{T},
@@ -254,139 +298,53 @@ end
254298
function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A}
255299
return LazyNamedDimsArray{T, A}(a)
256300
end
257-
function lazy(a::AbstractNamedDimsArray)
258-
return LazyNamedDimsArray(a)
259-
end
260-
261-
function NamedDimsArrays.inds(a::LazyNamedDimsArray)
262-
u = unwrap(a)
263-
if !iscall(u)
264-
return inds(u)
265-
elseif ismul(u)
266-
return mapreduce(inds, symdiff, arguments(u))
267-
else
268-
return error("Variant not supported.")
269-
end
270-
end
271-
function NamedDimsArrays.dename(a::LazyNamedDimsArray)
272-
u = unwrap(a)
273-
if !iscall(u)
274-
return dename(u)
275-
else
276-
return error("Variant not supported.")
277-
end
278-
end
279-
function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata)
280-
if head *
281-
return LazyNamedDimsArray(maketerm(Mul, head, args, metadata))
282-
else
283-
return error("Only mul supported right now.")
284-
end
285-
end
286-
287-
# Derived functionality.
288-
function Base.getindex(a::LazyNamedDimsArray, I::Int...)
289-
return getindex_lazy(a, I...)
290-
end
291-
function TermInterface.arguments(a::LazyNamedDimsArray)
292-
return arguments_lazy(a)
293-
end
294-
function TermInterface.children(a::LazyNamedDimsArray)
295-
return children_lazy(a)
296-
end
297-
function TermInterface.head(a::LazyNamedDimsArray)
298-
return head_lazy(a)
299-
end
300-
function TermInterface.iscall(a::LazyNamedDimsArray)
301-
return iscall_lazy(a)
302-
end
303-
function TermInterface.isexpr(a::LazyNamedDimsArray)
304-
return isexpr_lazy(a)
305-
end
306-
function TermInterface.operation(a::LazyNamedDimsArray)
307-
return operation_lazy(a)
308-
end
309-
function TermInterface.sorted_arguments(a::LazyNamedDimsArray)
310-
return sorted_arguments_lazy(a)
311-
end
312-
function AbstractTrees.children(a::LazyNamedDimsArray)
313-
return abstracttrees_children_lazy(a)
314-
end
315-
function TermInterface.sorted_children(a::LazyNamedDimsArray)
316-
return sorted_children_lazy(a)
317-
end
318-
ismul(a::LazyNamedDimsArray) = ismul_lazy(a)
319-
function AbstractTrees.nodevalue(a::LazyNamedDimsArray)
320-
return nodevalue_lazy(a)
321-
end
322-
function Base.Broadcast.materialize(a::LazyNamedDimsArray)
323-
return materialize_lazy(a)
324-
end
325-
Base.copy(a::LazyNamedDimsArray) = copy_lazy(a)
326-
function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
327-
return equals_lazy(a1, a2)
328-
end
329-
function Base.hash(a::LazyNamedDimsArray, h::UInt64)
330-
return hash_lazy(a, h)
331-
end
332-
function map_arguments(f, a::LazyNamedDimsArray)
333-
return map_arguments_lazy(f, a)
334-
end
335-
function substitute(a::LazyNamedDimsArray, substitutions)
336-
return substitute_lazy(a, substitutions)
337-
end
338-
function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray)
339-
return printnode_lazy(io, a)
340-
end
341-
function printnode_nameddims(io::IO, a::LazyNamedDimsArray)
342-
return printnode_lazy(io, a)
343-
end
344-
function Base.show(io::IO, a::LazyNamedDimsArray)
345-
return show_lazy(io, a)
346-
end
347-
function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray)
348-
return show_lazy(io, mime, a)
349-
end
350-
351-
function Base.:*(a::LazyNamedDimsArray)
352-
u = unwrap(a)
353-
if !iscall(u)
354-
return LazyNamedDimsArray(Mul([lazy(u)]))
355-
elseif ismul(u)
356-
return a
357-
else
358-
return error("Variant not supported.")
359-
end
360-
end
301+
lazy(a::LazyNamedDimsArray) = a
302+
lazy(a::AbstractNamedDimsArray) = LazyNamedDimsArray(a)
303+
lazy(a::Mul{<:LazyNamedDimsArray}) = LazyNamedDimsArray(a)
361304

362-
function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
363-
# Nested by default.
364-
return LazyNamedDimsArray(Mul([a1, a2]))
365-
end
366-
function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
367-
return error("Not implemented.")
368-
end
369-
function Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray)
370-
return error("Not implemented.")
371-
end
372-
function Base.:*(c::Number, a::LazyNamedDimsArray)
373-
return error("Not implemented.")
374-
end
375-
function Base.:*(a::LazyNamedDimsArray, c::Number)
376-
return error("Not implemented.")
377-
end
378-
function Base.:/(a::LazyNamedDimsArray, c::Number)
379-
return error("Not implemented.")
380-
end
381-
function Base.:-(a::LazyNamedDimsArray)
382-
return error("Not implemented.")
383-
end
305+
NamedDimsArrays.inds(a::LazyNamedDimsArray) = inds_lazy(a)
306+
NamedDimsArrays.dename(a::LazyNamedDimsArray) = dename_lazy(a)
384307

385308
# Broadcasting
386309
function Base.BroadcastStyle(::Type{<:LazyNamedDimsArray})
387310
return LazyNamedDimsArrayStyle()
388311
end
389312

313+
# Derived functionality.
314+
function TermInterface.maketerm(type::Type{LazyNamedDimsArray}, head, args, metadata)
315+
return maketerm_lazy(type, head, args, metadata)
316+
end
317+
Base.getindex(a::LazyNamedDimsArray, I::Int...) = getindex_lazy(a, I...)
318+
TermInterface.arguments(a::LazyNamedDimsArray) = arguments_lazy(a)
319+
TermInterface.children(a::LazyNamedDimsArray) = children_lazy(a)
320+
TermInterface.head(a::LazyNamedDimsArray) = head_lazy(a)
321+
TermInterface.iscall(a::LazyNamedDimsArray) = iscall_lazy(a)
322+
TermInterface.isexpr(a::LazyNamedDimsArray) = isexpr_lazy(a)
323+
TermInterface.operation(a::LazyNamedDimsArray) = operation_lazy(a)
324+
TermInterface.sorted_arguments(a::LazyNamedDimsArray) = sorted_arguments_lazy(a)
325+
AbstractTrees.children(a::LazyNamedDimsArray) = abstracttrees_children_lazy(a)
326+
TermInterface.sorted_children(a::LazyNamedDimsArray) = sorted_children_lazy(a)
327+
ismul(a::LazyNamedDimsArray) = ismul_lazy(a)
328+
AbstractTrees.nodevalue(a::LazyNamedDimsArray) = nodevalue_lazy(a)
329+
Base.Broadcast.materialize(a::LazyNamedDimsArray) = materialize_lazy(a)
330+
Base.copy(a::LazyNamedDimsArray) = copy_lazy(a)
331+
Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = equals_lazy(a1, a2)
332+
Base.hash(a::LazyNamedDimsArray, h::UInt64) = hash_lazy(a, h)
333+
map_arguments(f, a::LazyNamedDimsArray) = map_arguments_lazy(f, a)
334+
substitute(a::LazyNamedDimsArray, substitutions) = substitute_lazy(a, substitutions)
335+
AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) = printnode_lazy(io, a)
336+
printnode_nameddims(io::IO, a::LazyNamedDimsArray) = printnode_lazy(io, a)
337+
Base.show(io::IO, a::LazyNamedDimsArray) = show_lazy(io, a)
338+
Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) = show_lazy(io, mime, a)
339+
Base.:*(a::LazyNamedDimsArray) = mul_lazy(a)
340+
Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = mul_lazy(a1, a2)
341+
Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = add_lazy(a1, a2)
342+
Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = sub_lazy(a1, a2)
343+
Base.:*(c::Number, a::LazyNamedDimsArray) = mul_lazy(c, a)
344+
Base.:*(a::LazyNamedDimsArray, c::Number) = mul_lazy(a, c)
345+
Base.:/(a::LazyNamedDimsArray, c::Number) = div_lazy(a, c)
346+
Base.:-(a::LazyNamedDimsArray) = sub_lazy(a)
347+
390348
struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N}
391349
name::Name
392350
axes::Axes
@@ -448,20 +406,12 @@ function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray)
448406
end
449407
return nothing
450408
end
451-
function printnode_nameddims(io::IO, a::SymbolicNamedDimsArray)
452-
return AbstractTrees.printnode(io, a)
453-
end
409+
printnode_nameddims(io::IO, a::SymbolicNamedDimsArray) = AbstractTrees.printnode(io, a)
454410
function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
455411
return issetequal(inds(a), inds(b)) && dename(a) == dename(b)
456412
end
457-
function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray)
458-
return lazy(a) * lazy(b)
459-
end
460-
function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray)
461-
return lazy(a) * b
462-
end
463-
function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray)
464-
return a * lazy(b)
465-
end
413+
Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) = lazy(a) * lazy(b)
414+
Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) = lazy(a) * b
415+
Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) = a * lazy(b)
466416

467417
end

0 commit comments

Comments
 (0)