Skip to content

Commit 8ab757e

Browse files
authored
Merge pull request #195 from JuliaSymbolics/myb/rec
Fix for non-associative functions
2 parents e91c00b + 18a54f5 commit 8ab757e

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/types.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,9 @@ function Base.hash(t::Term{T}, salt::UInt) where {T}
290290
return h′
291291
end
292292

293+
isassociative(::Any) = false
294+
isassociative(::Union{typeof(+),typeof(*)}) = true
295+
293296
_promote_symtype(f::Sym, args) = promote_symtype(f, map(symtype, args)...)
294297
function _promote_symtype(f, args)
295298
if length(args) == 0
@@ -298,9 +301,10 @@ function _promote_symtype(f, args)
298301
promote_symtype(f, symtype(args[1]))
299302
elseif length(args) == 2
300303
promote_symtype(f, symtype(args[1]), symtype(args[2]))
301-
else
302-
# TODO: maybe restrict it only to functions that are Associative
304+
elseif isassociative(f)
303305
mapfoldl(symtype, (x,y) -> promote_symtype(f, x, y), args)
306+
else
307+
promote_symtype(f, map(symtype, args)...)
304308
end
305309
end
306310

test/basics.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ end
6666
@test isequal(a + x, Add(Number, 0, Dict(a=>1, x=>1)))
6767
@test isequal(a + z, Add(Number, 0, Dict(a=>1, z=>1)))
6868

69+
foo(w, z, a, b) = 1.0
70+
SymbolicUtils.promote_symtype(::typeof(foo), args...) = Real
71+
@test SymbolicUtils._promote_symtype(foo, (w, z, a, b,)) === Real
72+
6973
# promote_symtype of identity
7074
@test isequal(Term(identity, [w]), Term{Complex}(identity, [w]))
7175
@test isequal(+(w), w)

0 commit comments

Comments
 (0)