Skip to content

Commit 18f7ed6

Browse files
shashiYingboMa
andcommitted
delete rec_promote_symtype, use foldl and fast paths, fix it on Sym{FnType}
Co-authored-by: "Yingbo Ma" <[email protected]>
1 parent 7d60e4a commit 18f7ed6

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

src/methods.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,6 @@ for f in monadic
101101
@eval (::$(typeof(f)))(a::Symbolic) = term($f, a)
102102
end
103103

104-
rec_promote_symtype(f) = promote_symtype(f)
105-
rec_promote_symtype(f, x) = promote_symtype(f, x)
106-
rec_promote_symtype(f, x,y) = promote_symtype(f, x,y)
107-
rec_promote_symtype(f, x,y,z...) = rec_promote_symtype(f, promote_symtype(f, x,y), z...)
108-
109-
110104
Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
111105
Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
112106

src/types.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ end
284284

285285
istree(t::Term) = true
286286

287-
Term(f, args) = Term{rec_promote_symtype(f, map(symtype, args)...)}(f, args)
287+
Term(f, args) = Term{_promote_symtype(f, args)}(f, args)
288288

289289
operation(x::Term) = getfield(x, :f)
290290

@@ -297,9 +297,22 @@ function Base.hash(t::Term{T}, salt::UInt) where {T}
297297
hashvec(arguments(t), hash(operation(t), hash(T, salt)))
298298
end
299299

300+
_promote_symtype(f::Sym, args) = promote_symtype(f, map(symtype, args)...)
301+
function _promote_symtype(f, args)
302+
if length(args) == 1
303+
promote_symtype(f, symtype(args[1]))
304+
elseif length(args) == 2
305+
promote_symtype(f, symtype(args[1]), symtype(args[2]))
306+
else
307+
# TODO: maybe restrict it only to functions that are Associative
308+
mapfoldl(symtype, (x,y) -> promote_symtype(f, x, y), args)
309+
end
310+
end
311+
312+
300313
function term(f, args...; type = nothing)
301314
if type === nothing
302-
T = rec_promote_symtype(f, map(symtype, args)...)
315+
T = _promote_symtype(f, args)
303316
else
304317
T = type
305318
end
@@ -689,10 +702,10 @@ end
689702

690703
function similarterm(p::Union{Mul, Add, Pow}, f, args)
691704
if f === (+)
692-
T = rec_promote_symtype(f, map(symtype, args)...)
705+
T = _promote_symtype(f, args)
693706
Add(T, makeadd(1, 0, args...)...)
694707
elseif f == (*)
695-
T = rec_promote_symtype(f, map(symtype, args)...)
708+
T = _promote_symtype(f, args)
696709
Mul(T, makemul(1, args...)...)
697710
elseif f == (^) && length(args) == 2
698711
Pow(args...)

0 commit comments

Comments
 (0)