Skip to content

Commit 4247bd1

Browse files
committed
proper type promotion for Add
1 parent e5cb78d commit 4247bd1

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

src/fast-terms.jl

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ struct Add{X, T<:Number, D} <: Symbolic{X}
1818
sorted_args_cache::Ref{Any}
1919
end
2020

21-
function Add(coeff, dict)
21+
function Add(T, coeff, dict)
2222
if isempty(dict)
2323
return coeff
2424
elseif _iszero(coeff) && length(dict) == 1
2525
k,v = first(dict)
2626
return _isone(v) ? k : makemul(1, v, k)
2727
end
28-
Add{Number, typeof(coeff), typeof(dict)}(coeff,dict, Ref{Any}(nothing))
28+
29+
Add{T, typeof(coeff), typeof(dict)}(coeff, dict, Ref{Any}(nothing))
2930
end
3031

3132
symtype(a::Add{X}) where {X} = X
@@ -76,33 +77,42 @@ function makeadd(sign, coeff, xs...)
7677
coeff, d
7778
end
7879

80+
add_t(a,b) = promote_symtype(+, symtype(a), symtype(b))
81+
sub_t(a,b) = promote_symtype(-, symtype(a), symtype(b))
82+
sub_t(a) = promote_symtype(-, symtype(a))
83+
7984
function +(a::SN, b::SN)
8085
if a isa Add
8186
coeff, dict = makeadd(1, 0, b)
82-
return Add(a.coeff + coeff, _merge(+, a.dict, dict, filter=_iszero))
87+
T = promote_symtype(+, symtype(a), symtype(b))
88+
return Add(add_t(a,b), a.coeff + coeff, _merge(+, a.dict, dict, filter=_iszero))
8389
elseif b isa Add
8490
return b + a
8591
end
86-
Add(makeadd(1, 0, a, b)...)
92+
Add(add_t(a,b), makeadd(1, 0, a, b)...)
8793
end
8894

89-
+(a::Number, b::SN) = Add(makeadd(1, a, b)...)
95+
+(a::Number, b::SN) = Add(add_t(a,b), makeadd(1, a, b)...)
9096

91-
+(a::SN, b::Number) = Add(makeadd(1, b, a)...)
97+
+(a::SN, b::Number) = Add(add_t(a,b), makeadd(1, b, a)...)
9298

9399
+(a::SN) = a
94100

95-
+(a::Add, b::Add) = Add(a.coeff + b.coeff, _merge(+, a.dict, b.dict, filter=_iszero))
101+
+(a::Add, b::Add) = Add(add_t(a,b),
102+
a.coeff + b.coeff,
103+
_merge(+, a.dict, b.dict, filter=_iszero))
96104

97-
+(a::Number, b::Add) = iszero(a) ? b : Add(a + b.coeff, b.dict)
105+
+(a::Number, b::Add) = iszero(a) ? b : Add(add_t(a,b), a + b.coeff, b.dict)
98106

99-
+(b::Add, a::Number) = iszero(a) ? b : Add(a + b.coeff, b.dict)
107+
+(b::Add, a::Number) = iszero(a) ? b : Add(add_t(a,b), a + b.coeff, b.dict)
100108

101-
-(a::Add) = Add(-a.coeff, mapvalues((_,v) -> -v, a.dict))
109+
-(a::Add) = Add(sub_t(a), -a.coeff, mapvalues((_,v) -> -v, a.dict))
102110

103-
-(a::SN) = Add(makeadd(-1, 0, a)...)
111+
-(a::SN) = Add(sub_t(a), makeadd(-1, 0, a)...)
104112

105-
-(a::Add, b::Add) = Add(a.coeff - b.coeff, _merge(-, a.dict, b.dict, filter=_iszero))
113+
-(a::Add, b::Add) = Add(sub_t(a,b),
114+
a.coeff - b.coeff,
115+
_merge(-, a.dict, b.dict, filter=_iszero))
106116

107117
-(a::SN, b::SN) = a + (-b)
108118

@@ -182,6 +192,9 @@ function makemul(sign, coeff, xs...; d=sdict())
182192
Mul(coeff, d)
183193
end
184194

195+
mul_t(a,b) = promote_symtype(*, symtype(a), symtype(b))
196+
mul_t(a) = promote_symtype(*, symtype(a))
197+
185198
*(a::SN) = a
186199

187200
*(a::SN, b::SN) = makemul(1, 1, a, b)

0 commit comments

Comments
 (0)