Skip to content

Commit 7ea4fad

Browse files
Revert "refactor: improved n-ary addition"
This reverts commit 756ae48.
1 parent a066a04 commit 7ea4fad

File tree

1 file changed

+29
-62
lines changed

1 file changed

+29
-62
lines changed

src/types.jl

Lines changed: 29 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,72 +1346,39 @@ sub_t(a,b) = promote_symtype(-, symtype(a), symtype(b))
13461346
sub_t(a) = promote_symtype(-, symtype(a))
13471347

13481348
import Base: (+), (-), (*), (//), (/), (\), (^)
1349-
1350-
function +(a::SN, bs::Union{SN, Number}...)
1351-
isempty(bs) && return a
1352-
# entries where `!issafecanon`
1353-
unsafes = SmallV{Any}()
1354-
# coeff and dict of the `Add`
1355-
coeff = 0
1356-
dict = sdict()
1357-
# type of the `Add`
1358-
T = symtype(a)
1359-
1360-
# handle `a` separately
1361-
if issafecanon(+, a)
1362-
if isadd(a)
1363-
coeff = a.coeff
1364-
dict = copy(a.dict)
1365-
elseif ismul(a)
1366-
v = a.coeff
1367-
a′ = Mul(symtype(a), 1, copy(a.dict); metadata = a.metadata)
1368-
dict[a′] = v
1369-
else
1370-
dict[a] = 1
1371-
end
1372-
else
1373-
push!(unsafes, a)
1374-
end
1375-
1376-
for b in bs
1377-
T = promote_symtype(+, T, symtype(b))
1378-
if !issafecanon(+, b)
1379-
push!(unsafes, b)
1380-
continue
1381-
end
1382-
if b isa Number
1383-
coeff += b
1384-
continue
1385-
end
1386-
if isadd(b)
1387-
coeff += b.coeff
1388-
for (k, v) in b.dict
1389-
dict[k] = get(dict, k, 0) + v
1390-
end
1391-
elseif ismul(b)
1392-
v = b.coeff
1393-
b′ = Mul(symtype(b), 1, copy(b.dict); metadata = b.metadata)
1394-
dict[b′] = get(dict, b′, 0) + v
1395-
else
1396-
dict[b] = get(dict, b, 0) + 1
1397-
end
1398-
end
1399-
# remove entries multiplied by zero
1400-
filter!(dict) do kvp
1401-
!iszero(kvp[2])
1349+
function +(a::SN, b::SN)
1350+
!issafecanon(+, a,b) && return term(+, a, b) # Don't flatten if args have metadata
1351+
if isadd(a) && isadd(b)
1352+
return Add(add_t(a,b),
1353+
a.coeff + b.coeff,
1354+
_merge(+, a.dict, b.dict, filter=_iszero))
1355+
elseif isadd(a)
1356+
coeff, dict = makeadd(1, 0, b)
1357+
return Add(add_t(a,b), a.coeff + coeff, _merge(+, a.dict, dict, filter=_iszero))
1358+
elseif isadd(b)
1359+
return b + a
1360+
end
1361+
coeff, dict = makeadd(1, 0, a, b)
1362+
Add(add_t(a,b), coeff, dict)
1363+
end
1364+
1365+
function +(a::Number, b::SN)
1366+
tmp = unwrap(a)
1367+
if tmp !== a
1368+
return tmp + b
14021369
end
1403-
1404-
result = isempty(dict) ? coeff : Add(T, coeff, dict)
1405-
if !isempty(unsafes)
1406-
push!(unsafes, result)
1407-
result = Term{T}(+, unsafes)
1370+
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
1371+
iszero(a) && return b
1372+
if isadd(b)
1373+
Add(add_t(a,b), a + b.coeff, b.dict)
1374+
else
1375+
Add(add_t(a,b), makeadd(1, a, b)...)
14081376
end
1409-
return result
14101377
end
14111378

1412-
function +(a::Number, b::SN, bs::Union{SN, Number}...)
1413-
return +(b, a, bs...)
1414-
end
1379+
+(a::SN, b::Number) = b + a
1380+
1381+
+(a::SN) = a
14151382

14161383
function -(a::SN)
14171384
!issafecanon(*, a) && return term(-, a)

0 commit comments

Comments
 (0)