Skip to content

Commit 277d6ea

Browse files
committed
proper type promotion for Mul and Pow
1 parent 4247bd1 commit 277d6ea

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

src/fast-terms.jl

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function makeadd(sign, coeff, xs...)
6262
continue
6363
end
6464
if x isa Mul
65-
k = Mul(1, x.dict)
65+
k = Mul(symtype(x), 1, x.dict)
6666
v = sign * x.coeff + get(d, k, 0)
6767
else
6868
k = x
@@ -133,7 +133,7 @@ struct Mul{X, T<:Number, D} <: Symbolic{X}
133133
sorted_args_cache::Ref{Any}
134134
end
135135

136-
function Mul(a,b)
136+
function Mul(T, a,b)
137137
isempty(b) && return a
138138
if _isone(a) && length(b) == 1
139139
pair = first(b)
@@ -143,7 +143,7 @@ function Mul(a,b)
143143
return Pow(first(pair), last(pair))
144144
end
145145
else
146-
Mul{Number, typeof(a), typeof(b)}(a,b, Ref{Any}(nothing))
146+
Mul{T, typeof(a), typeof(b)}(a,b, Ref{Any}(nothing))
147147
end
148148
end
149149

@@ -189,24 +189,26 @@ function makemul(sign, coeff, xs...; d=sdict())
189189
end
190190
end
191191
end
192-
Mul(coeff, d)
192+
(coeff, d)
193193
end
194194

195195
mul_t(a,b) = promote_symtype(*, symtype(a), symtype(b))
196196
mul_t(a) = promote_symtype(*, symtype(a))
197197

198198
*(a::SN) = a
199199

200-
*(a::SN, b::SN) = makemul(1, 1, a, b)
200+
*(a::SN, b::SN) = Mul(mul_t(a,b), makemul(1, 1, a, b)...)
201201

202-
*(a::Mul, b::Mul) = Mul(a.coeff * b.coeff, _merge(+, a.dict, b.dict, filter=_iszero))
202+
*(a::Mul, b::Mul) = Mul(mul_t(a, b),
203+
a.coeff * b.coeff,
204+
_merge(+, a.dict, b.dict, filter=_iszero))
203205

204-
*(a::Number, b::SN) = iszero(a) ? a : isone(a) ? b : makemul(1,a, b)
206+
*(a::Number, b::SN) = iszero(a) ? a : isone(a) ? b : Mul(mul_t(a, b), makemul(1,a, b)...)
205207

206-
*(b::SN, a::Number) = iszero(a) ? a : isone(a) ? b : makemul(1,a, b)
208+
*(b::SN, a::Number) = iszero(a) ? a : isone(a) ? b : Mul(mul_t(a, b), makemul(1,a, b)...)
207209

208210
function /(a::Union{SN,Number}, b::SN)
209-
a * makemul(-1, 1, b)
211+
a * Mul(promote_symtype(/, 1, symtype(b)), makemul(-1, 1, b)...)
210212
end
211213

212214
\(a::SN, b::Union{Number, SN}) = b / a
@@ -228,7 +230,7 @@ end
228230
function Pow(a,b)
229231
_iszero(b) && return 1
230232
_isone(b) && return a
231-
Pow{Number, typeof(a), typeof(b)}(a,b)
233+
Pow{promote_symtype(^, symtype(a), symtype(b)), typeof(a), typeof(b)}(a,b)
232234
end
233235

234236
symtype(a::Pow{X}) where {X} = X
@@ -252,14 +254,17 @@ Base.show(io::IO, p::Pow) = show_term(io, p)
252254
^(a::Number, b::SN) = Pow(a, b)
253255

254256
function ^(a::Mul, b::Number)
255-
Mul(a.coeff ^ b, mapvalues((k, v) -> b*v, a.dict))
257+
Mul(promote_symtype(^, symtype(a), symtype(b)),
258+
a.coeff ^ b, mapvalues((k, v) -> b*v, a.dict))
256259
end
257260

258261
function *(a::Mul, b::Pow)
259262
if b.exp isa Number
260-
Mul(a.coeff, _merge(+, a.dict, Base.ImmutableDict(b.base=>b.exp), filter=_iszero))
263+
Mul(mul_t(a, b),
264+
a.coeff, _merge(+, a.dict, Base.ImmutableDict(b.base=>b.exp), filter=_iszero))
261265
else
262-
Mul(a.coeff, _merge(+, a.dict, Base.ImmutableDict(b=>1), filter=_iszero))
266+
Mul(mul_t(a, b),
267+
a.coeff, _merge(+, a.dict, Base.ImmutableDict(b=>1), filter=_iszero))
263268
end
264269
end
265270

@@ -292,9 +297,9 @@ end
292297

293298
function similarterm(p::Union{Mul, Add, Pow}, f, args)
294299
if f === (+)
295-
Add(makeadd(1, 0, args...)...)
300+
Add(symtype(p), makeadd(1, 0, args...)...)
296301
elseif f == (*)
297-
makemul(1, 1, args...)
302+
Mul(symtype(p), makemul(1, 1, args...)...)
298303
elseif f == (^)
299304
Pow(args...)
300305
else

0 commit comments

Comments
 (0)