Skip to content

Commit 490011e

Browse files
authored
Merge pull request #480 from bowenszhu/similarterm
Call `BasicSymbolic` arithmetic operations to construct `similarterm`
2 parents 6964c76 + 86de974 commit 490011e

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

src/types.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -535,15 +535,12 @@ function TermInterface.similarterm(t::Type{<:BasicSymbolic{<:Number}}, f, args,
535535
if T === nothing
536536
T = _promote_symtype(f, args)
537537
end
538-
if f === (+)
539-
Add(T, makeadd(1, 0, args...)...; metadata=metadata)
540-
elseif f == (*)
541-
Mul(T, makemul(1, args...)...; metadata=metadata)
542-
elseif f == (/)
543-
@assert length(args) == 2
544-
Div{T}(args...; metadata=metadata)
545-
elseif f == (^) && length(args) == 2
546-
Pow{T}(makepow(args...)...; metadata=metadata)
538+
if f in (+, *) || (f in (/, ^) && length(args) == 2)
539+
res = f(args...)
540+
if res isa Symbolic
541+
@set! res.metadata = metadata
542+
end
543+
return res
547544
else
548545
Term{T}(f, args, metadata=metadata)
549546
end

test/basics.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,18 @@ end
191191
@syms a b c
192192
@test isequal(SymbolicUtils.similarterm((b + c), +, [a, (b+c)]).dict, Dict(a=>1,b=>1,c=>1))
193193
@test isequal(SymbolicUtils.similarterm(b^2, ^, [b^2, 1//2]), b)
194+
195+
# test that similarterm doesn't hard-code BasicSymbolic subtype
196+
# and is consistent with BasicSymbolic arithmetic operations
197+
@test isequal(SymbolicUtils.similarterm(a / b, *, [a / b, c]), (a / b) * c)
198+
@test isequal(SymbolicUtils.similarterm(a * b, *, [0, c]), 0)
199+
@test isequal(SymbolicUtils.similarterm(a^b, ^, [a * b, 3]), (a * b)^3)
200+
201+
# test that similarterm sets metadata correctly
202+
metadata = Base.ImmutableDict{DataType, Any}(Ctx1, "meta_1")
203+
s = SymbolicUtils.similarterm(a^b, ^, [a * b, 3]; metadata = metadata)
204+
@test hasmetadata(s, Ctx1)
205+
@test getmetadata(s, Ctx1) == "meta_1"
194206
end
195207

196208
toterm(t) = Term{symtype(t)}(operation(t), arguments(t))

0 commit comments

Comments
 (0)