Skip to content

Commit d932692

Browse files
authored
Add lazy Summed (#17)
1 parent 1aa087f commit d932692

File tree

6 files changed

+352
-190
lines changed

6 files changed

+352
-190
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MapBroadcast"
22
uuid = "ebd9b9da-f48d-417c-9660-449667d60261"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.9"
4+
version = "0.1.10"
55

66
[deps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"

src/MapBroadcast.jl

Lines changed: 2 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -1,187 +1,6 @@
11
module MapBroadcast
2-
# Convert broadcast call to map call by capturing array arguments
3-
# with `map_args` and creating a map function with `map_function`.
4-
# Logic from https://github.com/Jutho/Strided.jl/blob/v2.0.4/src/broadcast.jl.
52

6-
using Base.Broadcast:
7-
Broadcast, BroadcastStyle, Broadcasted, broadcasted, combine_eltypes, instantiate
8-
using BlockArrays: mortar
9-
using Compat: allequal
10-
using FillArrays: Fill
11-
12-
const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}
13-
14-
# Get the arguments of the map expression that
15-
# is equivalent to the broadcast expression.
16-
function map_args(bc::Broadcasted)
17-
return map_args_flatten(bc)
18-
end
19-
20-
function map_args_flatten(bc::Broadcasted, args_rest...)
21-
return (map_args_flatten(bc.args...)..., map_args_flatten(args_rest...)...)
22-
end
23-
function map_args_flatten(arg1::AbstractArray, args_rest...)
24-
return (arg1, map_args_flatten(args_rest...)...)
25-
end
26-
map_args_flatten(arg1, args_rest...) = map_args_flatten(args_rest...)
27-
map_args_flatten() = ()
28-
29-
struct MapFunction{F,Args<:Tuple} <: Function
30-
f::F
31-
args::Args
32-
end
33-
struct Arg end
34-
35-
# Get the function of the map expression that
36-
# is equivalent to the broadcast expression.
37-
# Returns a `MapFunction`.
38-
function map_function(bc::Broadcasted)
39-
return map_function_arg(bc)
40-
end
41-
map_function_args(args::Tuple{}) = args
42-
function map_function_args(args::Tuple)
43-
return (map_function_arg(args[1]), map_function_args(Base.tail(args))...)
44-
end
45-
function map_function_arg(bc::Broadcasted)
46-
return MapFunction(bc.f, map_function_args(bc.args))
47-
end
48-
map_function_arg(a::WrappedScalarArgs) = a[]
49-
map_function_arg(a::AbstractArray) = Arg()
50-
map_function_arg(a) = a
51-
52-
# Evaluate MapFunction
53-
(f::MapFunction)(args...) = apply_arg(f, args)[1]
54-
function apply_arg(f::MapFunction, args)
55-
mapfunction_args, args′ = apply_args(f.args, args)
56-
return f.f(mapfunction_args...), args′
57-
end
58-
apply_arg(mapfunction_arg::Arg, args) = args[1], Base.tail(args)
59-
apply_arg(mapfunction_arg, args) = mapfunction_arg, args
60-
function apply_args(mapfunction_args::Tuple, args)
61-
mapfunction_args1, args′ = apply_arg(mapfunction_args[1], args)
62-
mapfunction_args_rest, args′′ = apply_args(Base.tail(mapfunction_args), args′)
63-
return (mapfunction_args1, mapfunction_args_rest...), args′′
64-
end
65-
apply_args(mapfunction_args::Tuple{}, args) = mapfunction_args, args
66-
67-
is_map_expr_or_arg(arg::AbstractArray) = true
68-
is_map_expr_or_arg(arg::Any) = false
69-
function is_map_expr_or_arg(bc::Broadcasted)
70-
return all(is_map_expr_or_arg, bc.args)
71-
end
72-
function is_map_expr(bc::Broadcasted)
73-
return is_map_expr_or_arg(bc)
74-
end
75-
76-
abstract type ExprStyle end
77-
struct MapExpr <: ExprStyle end
78-
struct NotMapExpr <: ExprStyle end
79-
80-
ExprStyle(bc::Broadcasted) = is_map_expr(bc) ? MapExpr() : NotMapExpr()
81-
82-
abstract type AbstractMapped <: Base.AbstractBroadcasted end
83-
84-
function check_shape(::Type{Bool}, args...)
85-
return allequal(axes, args)
86-
end
87-
function check_shape(args...)
88-
if !check_shape(Bool, args...)
89-
throw(DimensionMismatch("Mismatched shapes $(axes.(args))."))
90-
end
91-
return nothing
92-
end
93-
94-
# Promote the shape of the arguments to support broadcasting
95-
# over dimensions by expanding singleton dimensions.
96-
function promote_shape(ax, f, args::AbstractArray...)
97-
if allequal((ax, axes.(args)...))
98-
return f, args
99-
end
100-
return f, promote_shape_tile(ax, args...)
101-
end
102-
function promote_shape_tile(common_axes, args::AbstractArray...)
103-
return map(arg -> tile(arg, common_axes), args)
104-
end
105-
106-
# Catch the case of zero arguments, like `a .= 2`.
107-
function promote_shape(ax, f)
108-
return identity, (Fill(f(), ax),)
109-
end
110-
111-
# Extend by repeating value up to length.
112-
function extend(t::Tuple, value, length)
113-
return ntuple(i -> get(t, i, value), length)
114-
end
115-
116-
# Handles logic of expanding singleton dimensions
117-
# to match an array shape in broadcasting
118-
# by tiling or repeating the input array.
119-
function tile(a::AbstractArray, ax)
120-
axes(a) == ax && return a
121-
# Must be one-based for now.
122-
@assert all(isone, first.(ax))
123-
@assert all(isone, first.(axes(a)))
124-
ndim = length(ax)
125-
size′ = extend(size(a), 1, ndim)
126-
a′ = reshape(a, size′)
127-
target_size = length.(ax)
128-
fillsize = ntuple(ndim) do dim
129-
size′[dim] == target_size[dim] && return 1
130-
isone(size′[dim]) && return target_size[dim]
131-
return throw(DimensionMismatch("Dimensions $(axes(a)) and $ax don't match."))
132-
end
133-
return mortar(Fill(a′, fillsize))
134-
end
135-
136-
struct Mapped{Style<:Union{Nothing,BroadcastStyle},Axes,F,Args<:Tuple} <: AbstractMapped
137-
style::Style
138-
f::F
139-
args::Args
140-
axes::Axes
141-
function Mapped(style, f, args, axes)
142-
check_shape(args...)
143-
return new{typeof(style),typeof(axes),typeof(f),typeof(args)}(style, f, args, axes)
144-
end
145-
end
146-
147-
function Mapped(bc::Broadcasted)
148-
return Mapped(ExprStyle(bc), bc)
149-
end
150-
function Mapped(::NotMapExpr, bc::Broadcasted)
151-
f = map_function(bc)
152-
ax = axes(bc)
153-
f, args = promote_shape(ax, f, map_args(bc)...)
154-
return Mapped(bc.style, f, args, ax)
155-
end
156-
function Mapped(::MapExpr, bc::Broadcasted)
157-
f = bc.f
158-
ax = axes(bc)
159-
f, args = promote_shape(ax, f, bc.args...)
160-
return Mapped(bc.style, f, args, ax)
161-
end
162-
163-
function Broadcast.Broadcasted(m::Mapped)
164-
return Broadcasted(m.style, m.f, m.args, m.axes)
165-
end
166-
167-
function mapped(f, args...)
168-
check_shape(args...)
169-
return Mapped(broadcasted(f, args...))
170-
end
171-
172-
Base.similar(m::Mapped, elt::Type) = similar(Broadcasted(m), elt)
173-
Base.similar(m::Mapped, elt::Type, ax) = similar(Broadcasted(m), elt, ax)
174-
Base.axes(m::Mapped) = axes(Broadcasted(m))
175-
# Equivalent to:
176-
# map(m.f, m.args...)
177-
# copy(Broadcasted(m))
178-
function Base.copy(m::Mapped)
179-
elt = combine_eltypes(m.f, m.args)
180-
# TODO: Handle case of non-concrete eltype.
181-
@assert Base.isconcretetype(elt)
182-
return copyto!(similar(m, elt), m)
183-
end
184-
Base.copyto!(dest::AbstractArray, m::Mapped) = map!(m.f, dest, m.args...)
185-
Broadcast.instantiate(m::Mapped) = Mapped(instantiate(Broadcasted(m)))
3+
include("mapped.jl")
4+
include("linearcombination.jl")
1865

1876
end

src/linearcombination.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using Base.Broadcast: Broadcasted
2+
struct LinearCombination{C} <: Function
3+
coefficients::C
4+
end
5+
coefficients(a::LinearCombination) = a.coefficients
6+
function (f::LinearCombination)(args...)
7+
return mapreduce(*,+,coefficients(f),args)
8+
end
9+
10+
struct Summed{Style,N,C<:NTuple{N},A<:NTuple{N}}
11+
style::Style
12+
coefficients::C
13+
arguments::A
14+
end
15+
Summed(a::Summed) = a
16+
coefficients(a::Summed) = a.coefficients
17+
arguments(a::Summed) = a.arguments
18+
style(a::Summed) = a.style
19+
LinearCombination(a::Summed) = LinearCombination(coefficients(a))
20+
using Base.Broadcast: combine_axes
21+
Base.axes(a::Summed) = combine_axes(a.arguments...)
22+
function Base.eltype(a::Summed)
23+
cts = typeof.(coefficients(a))
24+
elts = eltype.(arguments(a))
25+
ts = map((ct, elt) -> Base.promote_op(*, ct, elt), cts, elts)
26+
return Base.promote_op(+, ts...)
27+
end
28+
function Base.getindex(a::Summed, I...)
29+
return mapreduce(+, coefficients(a), arguments(a)) do c, a
30+
return c * a[I...]
31+
end
32+
end
33+
using Base.Broadcast: combine_styles
34+
function Summed(coefficients::Tuple, arguments::Tuple)
35+
return Summed(combine_styles(arguments...), coefficients, arguments)
36+
end
37+
Summed(a) = Summed((one(eltype(a)),), (a,))
38+
function Base.:+(a::Summed, b::Summed)
39+
return Summed(
40+
(coefficients(a)..., coefficients(b)...), (arguments(a)..., arguments(b)...)
41+
)
42+
end
43+
Base.:-(a::Summed, b::Summed) = a + (-b)
44+
Base.:+(a::Summed, b::AbstractArray) = a + Summed(b)
45+
Base.:-(a::Summed, b::AbstractArray) = a - Summed(b)
46+
Base.:+(a::AbstractArray, b::Summed) = Summed(a) + b
47+
Base.:-(a::AbstractArray, b::Summed) = Summed(a) - b
48+
Base.:*(c::Number, a::Summed) = Summed(c .* coefficients(a), arguments(a))
49+
Base.:*(a::Summed, c::Number) = c * a
50+
Base.:/(a::Summed, c::Number) = Summed(coefficients(a) ./ c, arguments(a))
51+
Base.:-(a::Summed) = -one(eltype(a)) * a
52+
53+
Base.similar(a::Summed) = similar(a, eltype(a))
54+
Base.similar(a::Summed, elt::Type) = similar(a, elt, axes(a))
55+
Base.similar(a::Summed, ax::Tuple) = similar(a, eltype(a), ax)
56+
function Base.similar(a::Summed, elt::Type, ax::Tuple)
57+
return similar(Broadcasted(a), elt, ax)
58+
end
59+
Base.copy(a::Summed) = copyto!(similar(a), a)
60+
function Base.copyto!(dest::AbstractArray, a::Summed)
61+
return copyto!(dest, Broadcasted(a))
62+
end
63+
function Broadcast.Broadcasted(a::Summed)
64+
f = LinearCombination(a)
65+
return Broadcasted(style(a), f, arguments(a), axes(a))
66+
end
67+
68+
using Base.Broadcast: Broadcast
69+
Broadcast.BroadcastStyle(a::Type{<:Summed{<:Style}}) where {Style} = Style()
70+
Broadcast.broadcastable(a::Summed) = a
71+
Broadcast.materialize(a::Summed) = copy(a)
72+
Broadcast.materialize!(dest, a::Summed) = copyto!(dest, a)

0 commit comments

Comments
 (0)