Skip to content

Commit 9c09e7e

Browse files
committed
Merge remote-tracking branch 'origin/master'
2 parents 7a968a1 + 423b640 commit 9c09e7e

File tree

4 files changed

+39
-21
lines changed

4 files changed

+39
-21
lines changed

page/interface.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ for `simplify` to work. Other required methods are `operation` and `istree`
4040

4141
In addition, the methods for `Base.hash` and `Base.isequal` should also be implemented by the types for the purposes of substitution and equality matching respectively.
4242

43-
### Optional
44-
45-
#### `similarterm(t::MyType, f, args)`
43+
#### `similarterm(t::MyType, f, args[, T])`
4644

47-
Construct a new term with the operation `f` and arguments `args`, the term should be similar to `t` in type. if `t` is a `Term` object a new Term is created with the same symtype as `t`. If not, the result is computed as `f(args...)`. Defining this method for your term type will reduce any performance loss in performing `f(args...)` (esp. the splatting, and redundant type computation).
45+
Construct a new term with the operation `f` and arguments `args`, the term should be similar to `t` in type. if `t` is a `Term` object a new Term is created with the same symtype as `t`. If not, the result is computed as `f(args...)`. Defining this method for your term type will reduce any performance loss in performing `f(args...)` (esp. the splatting, and redundant type computation). T is the symtype of the output term. You can use `promote_symtype` to infer this type.
4846

4947
The below two functions are internal to SymbolicUtils
5048

49+
### Optional
50+
5151
#### `symtype(x)`
5252

5353
The supposed type of values in the domain of x. Tracing tools can use this type to

src/abstractalgebra.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ function labels!(dicts, t)
2323
return t
2424
elseif istree(t) && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
2525
tt = arguments(t)
26-
return similarterm(t, operation(t), map(x->labels!(dicts, x), tt))
26+
return similarterm(t, operation(t), map(x->labels!(dicts, x), tt), symtype(t))
2727
elseif istree(t) && operation(t) == (^) && length(arguments(t)) > 1 && isnonnegint(arguments(t)[2])
28-
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)))
28+
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)), symtype(t))
2929
else
3030
sym2term, term2sym = dicts
3131
if haskey(term2sym, t)
@@ -36,7 +36,8 @@ function labels!(dicts, t)
3636
sym = Sym{symtype(t)}(gensym(nameof(operation(t))))
3737
dicts2 = _dicts(dicts[2])
3838
sym2term[sym] = similarterm(t, operation(t),
39-
map(x->to_mpoly(x, dicts)[1], arguments(t)))
39+
map(x->to_mpoly(x, dicts)[1], arguments(t)),
40+
symtype(t))
4041
else
4142
sym = Sym{symtype(t)}(gensym("literal"))
4243
sym2term[sym] = t
@@ -110,7 +111,7 @@ function _to_term(reference, x::MPoly, dict, syms)
110111
elseif length(monics) == 0
111112
return 1
112113
else
113-
return similarterm(reference, *, monics)
114+
return similarterm(reference, *, monics, symtype(reference))
114115
end
115116
end
116117

@@ -123,15 +124,16 @@ function _to_term(reference, x::MPoly, dict, syms)
123124
t = similarterm(reference,
124125
+,
125126
map((x,y)->isone(y) ? x : Int(y)*x,
126-
monoms, x.coeffs[1:length(monoms)]))
127+
monoms, x.coeffs[1:length(monoms)]),
128+
symtype(reference))
127129
end
128130

129131
substitute(t, dict, fold=false)
130132
end
131133

132134
function _to_term(reference, x, dict, vars)
133135
if istree(x)
134-
t=similarterm(x, operation(x), _to_term.((reference,), arguments(x), (dict,), (vars,)))
136+
t=similarterm(x, operation(x), _to_term.((reference,), arguments(x), (dict,), (vars,)), symtype(x))
135137
else
136138
if haskey(dict, x)
137139
return dict[x]

src/types.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ end
7474
function to_symbolic(x)
7575
Base.depwarn("`to_symbolic(x)` is deprecated, define the interface for your " *
7676
"symbolic structure using `istree(x)`, `operation(x)`, `arguments(x)` " *
77-
"and `similarterm(::YourType, f, args)`", :to_symbolic, force=true)
77+
"and `similarterm(::YourType, f, args, symtype)`", :to_symbolic, force=true)
7878

7979
x
8080
end
@@ -319,13 +319,23 @@ function term(f, args...; type = nothing)
319319
end
320320

321321
"""
322-
similarterm(t, f, args)
322+
similarterm(t, f, args, symtype)
323323
324-
Create a term that is similar in type to `t` such that `symtype(similarterm(f,
325-
args...)) === symtype(f(args...))`.
324+
Create a term that is similar in type to `t`. Extending this function allows packages
325+
using their own expression types with SymbolicUtils to define how new terms should
326+
be created.
327+
328+
## Arguments
329+
330+
- `t` the reference term to use to create similar terms
331+
- `f` is the operation of the term
332+
- `args` is the arguments
333+
- The `symtype` of the resulting term. Best effort will be made to set the symtype of the
334+
resulting similar term to this type.
326335
"""
327-
similarterm(t, f, args) = f(args...)
328-
similarterm(::Term, f, args) = term(f, args...)
336+
similarterm(t, f, args, symtype) = f(args...)
337+
similarterm(t, f, args) = similarterm(t, f, args, _promote_symtype(f, args))
338+
similarterm(::Term, f, args, symtype=nothing) = term(f, args...; type=symtype)
329339

330340
node_count(t) = istree(t) ? reduce(+, node_count(x) for x in arguments(t), init=0) + 1 : 1
331341

@@ -757,15 +767,16 @@ function mapvalues(f, d1::AbstractDict)
757767
d
758768
end
759769

760-
function similarterm(p::Union{Mul, Add, Pow}, f, args)
761-
if f === (+)
770+
function similarterm(p::Union{Mul, Add, Pow}, f, args, T=nothing)
771+
if T === nothing
762772
T = _promote_symtype(f, args)
773+
end
774+
if f === (+)
763775
Add(T, makeadd(1, 0, args...)...)
764776
elseif f == (*)
765-
T = _promote_symtype(f, args)
766777
Mul(T, makemul(1, args...)...)
767778
elseif f == (^) && length(args) == 2
768-
Pow(args...)
779+
Pow{T, typeof.(args)...}(args...)
769780
else
770781
f(args...)
771782
end

test/nf.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
using SymbolicUtils, Test
2-
using SymbolicUtils: polynormalize, Term
2+
using SymbolicUtils: polynormalize, Term, symtype
33
@testset "polyform" begin
44
@syms a b c d
55
@test polynormalize(a * (b + -1 * c) + -1 * (b * a + -1 * c * a)) == 0
66
@eqtest polynormalize(sin(a+b)+sin(c+d)) == sin(a+b) + sin(c+d)
77
@eqtest simplify(polynormalize(sin((a+b)^2)^2)) == simplify(sin(a^2+2*(b*a)+b^2)^2)
88
@test simplify(polynormalize(sin((a+b)^2)^2 + cos((a+b)^2)^2)) == 1
9+
@syms x1::Real f(::Real)::Real
10+
11+
# issue 193
12+
@test isequal(polynormalize(f(x1 + 2.0)), f(2.0 + x1))
13+
@test symtype(polynormalize(f(x1 + 2.0))) == Real
914

1015
# cleanup rules
1116
@test polynormalize(Term{Number}(identity, 0)) == 0

0 commit comments

Comments
 (0)