Skip to content

Commit 36f1875

Browse files
committed
Add missing file
1 parent 3f162f3 commit 36f1875

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

src/linearcombination.jl

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
using Base.Broadcast: Broadcasted
2+
struct LinearCombination{C} <: Function
3+
coefficients::C
4+
end
5+
coefficients(a::LinearCombination) = a.coefficients
6+
function (f::LinearCombination)(args...)
7+
return mapreduce(*,+,coefficients(f),args)
8+
end
9+
10+
struct Sum{Style,C<:Tuple,A<:Tuple}
11+
style::Style
12+
coefficients::C
13+
arguments::A
14+
end
15+
coefficients(a::Sum) = a.coefficients
16+
arguments(a::Sum) = a.arguments
17+
style(a::Sum) = a.style
18+
LinearCombination(a::Sum) = LinearCombination(coefficients(a))
19+
using Base.Broadcast: combine_axes
20+
Base.axes(a::Sum) = combine_axes(a.arguments...)
21+
function Base.eltype(a::Sum)
22+
cts = typeof.(coefficients(a))
23+
elts = eltype.(arguments(a))
24+
ts = map((ct, elt) -> Base.promote_op(*, ct, elt), cts, elts)
25+
return Base.promote_op(+, ts...)
26+
end
27+
using Base.Broadcast: combine_styles
28+
function Sum(coefficients::Tuple, arguments::Tuple)
29+
return Sum(combine_styles(arguments...), coefficients, arguments)
30+
end
31+
Sum(a) = Sum((one(eltype(a)),), (a,))
32+
function Base.:+(a::Sum, b::Sum)
33+
return Sum((coefficients(a)..., coefficients(b)...), (arguments(a)..., arguments(b)...))
34+
end
35+
Base.:-(a::Sum, b::Sum) = a + (-b)
36+
Base.:+(a::Sum, b::AbstractArray) = a + Sum(b)
37+
Base.:-(a::Sum, b::AbstractArray) = a - Sum(b)
38+
Base.:+(a::AbstractArray, b::Sum) = Sum(a) + b
39+
Base.:-(a::AbstractArray, b::Sum) = Sum(a) - b
40+
Base.:*(c::Number, a::Sum) = Sum(c .* coefficients(a), arguments(a))
41+
Base.:*(a::Sum, c::Number) = c * a
42+
Base.:/(a::Sum, c::Number) = Sum(coefficients(a) ./ c, arguments(a))
43+
Base.:-(a::Sum) = -1 * a
44+
45+
function Base.copy(a::Sum)
46+
return copyto!(similar(a), a)
47+
end
48+
Base.similar(a::Sum) = similar(a, eltype(a))
49+
Base.similar(a::Sum, elt::Type) = similar(a, elt, axes(a))
50+
function Base.copyto!(dest::AbstractArray, a::Sum)
51+
f = LinearCombination(a)
52+
dest .= f.(arguments(a)...)
53+
return dest
54+
end
55+
function Broadcast.Broadcasted(a::Sum)
56+
f = LinearCombination(a)
57+
return Broadcasted(style(a), f, arguments(a), axes(a))
58+
end
59+
function Base.similar(a::Sum, elt::Type, ax::Tuple)
60+
return similar(Broadcasted(a), elt, ax)
61+
end
62+
63+
using Base.Broadcast: Broadcast, AbstractArrayStyle, DefaultArrayStyle
64+
Broadcast.materialize(a::Sum) = copy(a)
65+
Broadcast.materialize!(dest, a::Sum) = copyto!(dest, a)
66+
struct SumStyle <: AbstractArrayStyle{Any} end
67+
Broadcast.broadcastable(a::Sum) = a
68+
Broadcast.BroadcastStyle(::Type{<:Sum}) = SumStyle()
69+
Broadcast.BroadcastStyle(style::SumStyle, ::AbstractArrayStyle) = style
70+
# Fix ambiguity error with Base.
71+
Broadcast.BroadcastStyle(style::SumStyle, ::DefaultArrayStyle) = style
72+
function Broadcast.broadcasted(::SumStyle, f, as...)
73+
return error("Arbitrary broadcasting not supported for SumStyle.")
74+
end
75+
function Broadcast.broadcasted(::SumStyle, ::typeof(+), a, b::Sum)
76+
return Sum(a) + b
77+
end
78+
function Broadcast.broadcasted(::SumStyle, ::typeof(+), a::Sum, b)
79+
return a + Sum(b)
80+
end
81+
function Broadcast.broadcasted(::SumStyle, ::typeof(+), a::Sum, b::Sum)
82+
return a + b
83+
end
84+
function Broadcast.broadcasted(::SumStyle, ::typeof(*), c::Number, a)
85+
return c * Sum(a)
86+
end
87+
function Broadcast.broadcasted(::SumStyle, ::typeof(*), c::Number, a::Sum)
88+
return c * a
89+
end
90+
function Broadcast.broadcasted(::SumStyle, ::typeof(/), a::Sum, c::Number)
91+
return Sum(a) / c
92+
end

0 commit comments

Comments
 (0)