Skip to content

Commit f6a1709

Browse files
YingboMashashi
andcommitted
Fix similarterm with Add
Co-authored-by: "Shashi Gowda" <[email protected]>
1 parent b813ac5 commit f6a1709

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

src/types.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,11 @@ and the key (in Add) should instead be used to store the actual coefficient
510510
function makeadd(sign, coeff, xs...)
511511
d = sdict()
512512
for x in xs
513+
if x isa Add
514+
coeff += x.coeff
515+
_merge!(+, d, x.dict, filter=_iszero)
516+
continue
517+
end
513518
if x isa Number
514519
coeff += x
515520
continue
@@ -631,7 +636,7 @@ function makemul(coeff, xs...; d=sdict())
631636
coeff *= x
632637
elseif x isa Mul
633638
coeff *= x.coeff
634-
d = _merge(+, d, x.dict, filter=_iszero)
639+
_merge!(+, d, x.dict, filter=_iszero)
635640
else
636641
v = 1 + get(d, x, 0)
637642
if _iszero(v)
@@ -721,8 +726,9 @@ end
721726

722727
*(a::Pow, b::Mul) = b * a
723728

724-
function _merge(f, d, others...; filter=x->false)
725-
acc = copy(d)
729+
_merge(f, d, others...; filter=x->false) = _merge!(f, copy(d), others...; filter=filter)
730+
function _merge!(f, d, others...; filter=x->false)
731+
acc = d
726732
for other in others
727733
for (k, v) in other
728734
v = f(v)
@@ -781,8 +787,10 @@ AbstractTrees.children(x::Term) = arguments(x)
781787
AbstractTrees.children(x::Union{Add, Mul}) = map(y->TreePrint(x isa Add ? (:*) : (:^), y), collect(pairs(x.dict)))
782788
AbstractTrees.children(x::Union{Pow}) = [x.base, x.exp]
783789
AbstractTrees.children(x::TreePrint) = [x.x[1], x.x[2]]
784-
function print_tree(x::Union{Term, Add, Mul, Pow})
785-
AbstractTrees.print_tree(stdout, x, withinds=true) do io, y, inds
790+
791+
print_tree(x; maxdepth=Inf, kw...) = print_tree(stdout, x; maxdepth=maxdepth, kw...)
792+
function print_tree(_io::IO, x::Union{Term, Add, Mul, Pow}; kw...)
793+
AbstractTrees.print_tree(_io, x; withinds=true, kw...) do io, y, inds
786794
if istree(y)
787795
print(io, operation(y))
788796
elseif y isa TreePrint

test/basics.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ end
115115
@test repr(a + -1*b) == "a - b"
116116
end
117117

118+
@testset "similarterm with Add" begin
119+
@syms a b c
120+
@test isequal(SymbolicUtils.similarterm((b + c), +, [a, (b+c)]).dict, Dict(a=>1,b=>1,c=>1))
121+
end
122+
118123
toterm(t) = Term{symtype(t)}(operation(t), arguments(t))
119124

120125
@testset "diffs" begin

0 commit comments

Comments
 (0)