|
1 | 1 | module MapBroadcast |
2 | | -# Convert broadcast call to map call by capturing array arguments |
3 | | -# with `map_args` and creating a map function with `map_function`. |
4 | | -# Logic from https://github.com/Jutho/Strided.jl/blob/v2.0.4/src/broadcast.jl. |
5 | 2 |
|
6 | | -using Base.Broadcast: |
7 | | - Broadcast, BroadcastStyle, Broadcasted, broadcasted, combine_eltypes, instantiate |
8 | | -using BlockArrays: mortar |
9 | | -using Compat: allequal |
10 | | -using FillArrays: Fill |
11 | | - |
12 | | -const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}} |
13 | | - |
14 | | -# Get the arguments of the map expression that |
15 | | -# is equivalent to the broadcast expression. |
16 | | -function map_args(bc::Broadcasted) |
17 | | - return map_args_flatten(bc) |
18 | | -end |
19 | | - |
20 | | -function map_args_flatten(bc::Broadcasted, args_rest...) |
21 | | - return (map_args_flatten(bc.args...)..., map_args_flatten(args_rest...)...) |
22 | | -end |
23 | | -function map_args_flatten(arg1::AbstractArray, args_rest...) |
24 | | - return (arg1, map_args_flatten(args_rest...)...) |
25 | | -end |
26 | | -map_args_flatten(arg1, args_rest...) = map_args_flatten(args_rest...) |
27 | | -map_args_flatten() = () |
28 | | - |
29 | | -struct MapFunction{F,Args<:Tuple} <: Function |
30 | | - f::F |
31 | | - args::Args |
32 | | -end |
33 | | -struct Arg end |
34 | | - |
35 | | -# Get the function of the map expression that |
36 | | -# is equivalent to the broadcast expression. |
37 | | -# Returns a `MapFunction`. |
38 | | -function map_function(bc::Broadcasted) |
39 | | - return map_function_arg(bc) |
40 | | -end |
41 | | -map_function_args(args::Tuple{}) = args |
42 | | -function map_function_args(args::Tuple) |
43 | | - return (map_function_arg(args[1]), map_function_args(Base.tail(args))...) |
44 | | -end |
45 | | -function map_function_arg(bc::Broadcasted) |
46 | | - return MapFunction(bc.f, map_function_args(bc.args)) |
47 | | -end |
48 | | -map_function_arg(a::WrappedScalarArgs) = a[] |
49 | | -map_function_arg(a::AbstractArray) = Arg() |
50 | | -map_function_arg(a) = a |
51 | | - |
52 | | -# Evaluate MapFunction |
53 | | -(f::MapFunction)(args...) = apply_arg(f, args)[1] |
54 | | -function apply_arg(f::MapFunction, args) |
55 | | - mapfunction_args, args′ = apply_args(f.args, args) |
56 | | - return f.f(mapfunction_args...), args′ |
57 | | -end |
58 | | -apply_arg(mapfunction_arg::Arg, args) = args[1], Base.tail(args) |
59 | | -apply_arg(mapfunction_arg, args) = mapfunction_arg, args |
60 | | -function apply_args(mapfunction_args::Tuple, args) |
61 | | - mapfunction_args1, args′ = apply_arg(mapfunction_args[1], args) |
62 | | - mapfunction_args_rest, args′′ = apply_args(Base.tail(mapfunction_args), args′) |
63 | | - return (mapfunction_args1, mapfunction_args_rest...), args′′ |
64 | | -end |
65 | | -apply_args(mapfunction_args::Tuple{}, args) = mapfunction_args, args |
66 | | - |
67 | | -is_map_expr_or_arg(arg::AbstractArray) = true |
68 | | -is_map_expr_or_arg(arg::Any) = false |
69 | | -function is_map_expr_or_arg(bc::Broadcasted) |
70 | | - return all(is_map_expr_or_arg, bc.args) |
71 | | -end |
72 | | -function is_map_expr(bc::Broadcasted) |
73 | | - return is_map_expr_or_arg(bc) |
74 | | -end |
75 | | - |
76 | | -abstract type ExprStyle end |
77 | | -struct MapExpr <: ExprStyle end |
78 | | -struct NotMapExpr <: ExprStyle end |
79 | | - |
80 | | -ExprStyle(bc::Broadcasted) = is_map_expr(bc) ? MapExpr() : NotMapExpr() |
81 | | - |
82 | | -abstract type AbstractMapped <: Base.AbstractBroadcasted end |
83 | | - |
84 | | -function check_shape(::Type{Bool}, args...) |
85 | | - return allequal(axes, args) |
86 | | -end |
87 | | -function check_shape(args...) |
88 | | - if !check_shape(Bool, args...) |
89 | | - throw(DimensionMismatch("Mismatched shapes $(axes.(args)).")) |
90 | | - end |
91 | | - return nothing |
92 | | -end |
93 | | - |
94 | | -# Promote the shape of the arguments to support broadcasting |
95 | | -# over dimensions by expanding singleton dimensions. |
96 | | -function promote_shape(ax, f, args::AbstractArray...) |
97 | | - if allequal((ax, axes.(args)...)) |
98 | | - return f, args |
99 | | - end |
100 | | - return f, promote_shape_tile(ax, args...) |
101 | | -end |
102 | | -function promote_shape_tile(common_axes, args::AbstractArray...) |
103 | | - return map(arg -> tile(arg, common_axes), args) |
104 | | -end |
105 | | - |
106 | | -# Catch the case of zero arguments, like `a .= 2`. |
107 | | -function promote_shape(ax, f) |
108 | | - return identity, (Fill(f(), ax),) |
109 | | -end |
110 | | - |
111 | | -# Extend by repeating value up to length. |
112 | | -function extend(t::Tuple, value, length) |
113 | | - return ntuple(i -> get(t, i, value), length) |
114 | | -end |
115 | | - |
116 | | -# Handles logic of expanding singleton dimensions |
117 | | -# to match an array shape in broadcasting |
118 | | -# by tiling or repeating the input array. |
119 | | -function tile(a::AbstractArray, ax) |
120 | | - axes(a) == ax && return a |
121 | | - # Must be one-based for now. |
122 | | - @assert all(isone, first.(ax)) |
123 | | - @assert all(isone, first.(axes(a))) |
124 | | - ndim = length(ax) |
125 | | - size′ = extend(size(a), 1, ndim) |
126 | | - a′ = reshape(a, size′) |
127 | | - target_size = length.(ax) |
128 | | - fillsize = ntuple(ndim) do dim |
129 | | - size′[dim] == target_size[dim] && return 1 |
130 | | - isone(size′[dim]) && return target_size[dim] |
131 | | - return throw(DimensionMismatch("Dimensions $(axes(a)) and $ax don't match.")) |
132 | | - end |
133 | | - return mortar(Fill(a′, fillsize)) |
134 | | -end |
135 | | - |
136 | | -struct Mapped{Style<:Union{Nothing,BroadcastStyle},Axes,F,Args<:Tuple} <: AbstractMapped |
137 | | - style::Style |
138 | | - f::F |
139 | | - args::Args |
140 | | - axes::Axes |
141 | | - function Mapped(style, f, args, axes) |
142 | | - check_shape(args...) |
143 | | - return new{typeof(style),typeof(axes),typeof(f),typeof(args)}(style, f, args, axes) |
144 | | - end |
145 | | -end |
146 | | - |
147 | | -function Mapped(bc::Broadcasted) |
148 | | - return Mapped(ExprStyle(bc), bc) |
149 | | -end |
150 | | -function Mapped(::NotMapExpr, bc::Broadcasted) |
151 | | - f = map_function(bc) |
152 | | - ax = axes(bc) |
153 | | - f, args = promote_shape(ax, f, map_args(bc)...) |
154 | | - return Mapped(bc.style, f, args, ax) |
155 | | -end |
156 | | -function Mapped(::MapExpr, bc::Broadcasted) |
157 | | - f = bc.f |
158 | | - ax = axes(bc) |
159 | | - f, args = promote_shape(ax, f, bc.args...) |
160 | | - return Mapped(bc.style, f, args, ax) |
161 | | -end |
162 | | - |
163 | | -function Broadcast.Broadcasted(m::Mapped) |
164 | | - return Broadcasted(m.style, m.f, m.args, m.axes) |
165 | | -end |
166 | | - |
167 | | -function mapped(f, args...) |
168 | | - check_shape(args...) |
169 | | - return Mapped(broadcasted(f, args...)) |
170 | | -end |
171 | | - |
172 | | -Base.similar(m::Mapped, elt::Type) = similar(Broadcasted(m), elt) |
173 | | -Base.similar(m::Mapped, elt::Type, ax) = similar(Broadcasted(m), elt, ax) |
174 | | -Base.axes(m::Mapped) = axes(Broadcasted(m)) |
175 | | -# Equivalent to: |
176 | | -# map(m.f, m.args...) |
177 | | -# copy(Broadcasted(m)) |
178 | | -function Base.copy(m::Mapped) |
179 | | - elt = combine_eltypes(m.f, m.args) |
180 | | - # TODO: Handle case of non-concrete eltype. |
181 | | - @assert Base.isconcretetype(elt) |
182 | | - return copyto!(similar(m, elt), m) |
183 | | -end |
184 | | -Base.copyto!(dest::AbstractArray, m::Mapped) = map!(m.f, dest, m.args...) |
185 | | -Broadcast.instantiate(m::Mapped) = Mapped(instantiate(Broadcasted(m))) |
| 3 | +include("mapped.jl") |
| 4 | +include("linearcombination.jl") |
186 | 5 |
|
187 | 6 | end |
0 commit comments