Skip to content

Commit 128606f

Browse files
authored
Merge pull request #169 from JuliaSymbolics/s/_promote_symtype
use mapfoldl instead of rec_promote_symtype, fix it on Sym{<:FnType}
2 parents ee3a702 + f48c9ff commit 128606f

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

src/methods.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,6 @@ for f in monadic
101101
@eval (::$(typeof(f)))(a::Symbolic) = term($f, a)
102102
end
103103

104-
rec_promote_symtype(f) = promote_symtype(f)
105-
rec_promote_symtype(f, x) = promote_symtype(f, x)
106-
rec_promote_symtype(f, x,y) = promote_symtype(f, x,y)
107-
rec_promote_symtype(f, x,y,z...) = rec_promote_symtype(f, promote_symtype(f, x,y), z...)
108-
109-
110104
Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
111105
Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
112106

src/types.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ end
286286

287287
istree(t::Term) = true
288288

289-
Term(f, args) = Term{rec_promote_symtype(f, map(symtype, args)...)}(f, args)
289+
Term(f, args) = Term{_promote_symtype(f, args)}(f, args)
290290

291291
operation(x::Term) = getfield(x, :f)
292292

@@ -301,9 +301,24 @@ function Base.hash(t::Term{T}, salt::UInt) where {T}
301301
t.hash[] = hashvec(arguments(t), hash(operation(t), hash(T, salt)))
302302
end
303303

304+
_promote_symtype(f::Sym, args) = promote_symtype(f, map(symtype, args)...)
305+
function _promote_symtype(f, args)
306+
if length(args) == 0
307+
promote_symtype(f)
308+
elseif length(args) == 1
309+
promote_symtype(f, symtype(args[1]))
310+
elseif length(args) == 2
311+
promote_symtype(f, symtype(args[1]), symtype(args[2]))
312+
else
313+
# TODO: maybe restrict it only to functions that are Associative
314+
mapfoldl(symtype, (x,y) -> promote_symtype(f, x, y), args)
315+
end
316+
end
317+
318+
304319
function term(f, args...; type = nothing)
305320
if type === nothing
306-
T = rec_promote_symtype(f, map(symtype, args)...)
321+
T = _promote_symtype(f, args)
307322
else
308323
T = type
309324
end
@@ -703,10 +718,10 @@ end
703718

704719
function similarterm(p::Union{Mul, Add, Pow}, f, args)
705720
if f === (+)
706-
T = rec_promote_symtype(f, map(symtype, args)...)
721+
T = _promote_symtype(f, args)
707722
Add(T, makeadd(1, 0, args...)...)
708723
elseif f == (*)
709-
T = rec_promote_symtype(f, map(symtype, args)...)
724+
T = _promote_symtype(f, args)
710725
Mul(T, makemul(1, args...)...)
711726
elseif f == (^) && length(args) == 2
712727
Pow(args...)

0 commit comments

Comments
 (0)