Skip to content

Commit 756ae48

Browse files
refactor: improved n-ary addition
1 parent 1e46b57 commit 756ae48

File tree

1 file changed

+62
-29
lines changed

1 file changed

+62
-29
lines changed

src/types.jl

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,39 +1346,72 @@ sub_t(a,b) = promote_symtype(-, symtype(a), symtype(b))
13461346
sub_t(a) = promote_symtype(-, symtype(a))
13471347

13481348
import Base: (+), (-), (*), (//), (/), (\), (^)
1349-
function +(a::SN, b::SN)
1350-
!issafecanon(+, a,b) && return term(+, a, b) # Don't flatten if args have metadata
1351-
if isadd(a) && isadd(b)
1352-
return Add(add_t(a,b),
1353-
a.coeff + b.coeff,
1354-
_merge(+, a.dict, b.dict, filter=_iszero))
1355-
elseif isadd(a)
1356-
coeff, dict = makeadd(1, 0, b)
1357-
return Add(add_t(a,b), a.coeff + coeff, _merge(+, a.dict, dict, filter=_iszero))
1358-
elseif isadd(b)
1359-
return b + a
1360-
end
1361-
coeff, dict = makeadd(1, 0, a, b)
1362-
Add(add_t(a,b), coeff, dict)
1363-
end
1364-
1365-
function +(a::Number, b::SN)
1366-
tmp = unwrap(a)
1367-
if tmp !== a
1368-
return tmp + b
1369-
end
1370-
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
1371-
iszero(a) && return b
1372-
if isadd(b)
1373-
Add(add_t(a,b), a + b.coeff, b.dict)
1349+
1350+
function +(a::SN, bs::Union{SN, Number}...)
1351+
isempty(bs) && return a
1352+
# entries where `!issafecanon`
1353+
unsafes = SmallV{Any}()
1354+
# coeff and dict of the `Add`
1355+
coeff = 0
1356+
dict = sdict()
1357+
# type of the `Add`
1358+
T = symtype(a)
1359+
1360+
# handle `a` separately
1361+
if issafecanon(+, a)
1362+
if isadd(a)
1363+
coeff = a.coeff
1364+
dict = copy(a.dict)
1365+
elseif ismul(a)
1366+
v = a.coeff
1367+
a′ = Mul(symtype(a), 1, copy(a.dict); metadata = a.metadata)
1368+
dict[a′] = v
1369+
else
1370+
dict[a] = 1
1371+
end
13741372
else
1375-
Add(add_t(a,b), makeadd(1, a, b)...)
1373+
push!(unsafes, a)
1374+
end
1375+
1376+
for b in bs
1377+
T = promote_symtype(+, T, symtype(b))
1378+
if !issafecanon(+, b)
1379+
push!(unsafes, b)
1380+
continue
1381+
end
1382+
if b isa Number
1383+
coeff += b
1384+
continue
1385+
end
1386+
if isadd(b)
1387+
coeff += b.coeff
1388+
for (k, v) in b.dict
1389+
dict[k] = get(dict, k, 0) + v
1390+
end
1391+
elseif ismul(b)
1392+
v = b.coeff
1393+
b′ = Mul(symtype(b), 1, copy(b.dict); metadata = b.metadata)
1394+
dict[b′] = get(dict, b′, 0) + v
1395+
else
1396+
dict[b] = get(dict, b, 0) + 1
1397+
end
1398+
end
1399+
# remove entries multiplied by zero
1400+
filter!(dict) do kvp
1401+
!iszero(kvp[2])
13761402
end
1377-
end
13781403

1379-
+(a::SN, b::Number) = b + a
1404+
result = isempty(dict) ? coeff : Add(T, coeff, dict)
1405+
if !isempty(unsafes)
1406+
push!(unsafes, result)
1407+
result = Term{T}(+, unsafes)
1408+
end
1409+
return result
1410+
end
13801411

1381-
+(a::SN) = a
1412+
function +(a::Number, b::SN, bs::Union{SN, Number}...)
1413+
return +(b, a, bs...)
1414+
end
13821415

13831416
function -(a::SN)
13841417
!issafecanon(*, a) && return term(-, a)

0 commit comments

Comments
 (0)