Skip to content

Commit f643117

Browse files
committed
Merge branch 'master' into manual-smallfixes
2 parents e892f97 + 142ea18 commit f643117

File tree

9 files changed

+197
-47
lines changed

9 files changed

+197
-47
lines changed

.github/workflows/Downstream.yml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
2+
name: IntegrationTest
3+
on:
4+
push:
5+
branches: [master]
6+
tags: [v*]
7+
pull_request:
8+
9+
jobs:
10+
test:
11+
name: ${{ matrix.package.repo }}/${{ matrix.package.group }}
12+
runs-on: ${{ matrix.os }}
13+
env:
14+
GROUP: ${{ matrix.package.group }}
15+
strategy:
16+
fail-fast: false
17+
matrix:
18+
julia-version: [1]
19+
os: [ubuntu-latest]
20+
package:
21+
- {user: SciML, repo: ModelingToolkit.jl, group: All}
22+
- {user: SciML, repo: Catalyst.jl, group: All}
23+
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
24+
- {user: SciML, repo: DataDrivenDiffEq.jl, group: Standard}
25+
- {user: JuliaSymbolics, repo: Symbolics.jl, group: All}
26+
27+
steps:
28+
- uses: actions/checkout@v2
29+
- uses: julia-actions/setup-julia@v1
30+
with:
31+
version: ${{ matrix.julia-version }}
32+
arch: x64
33+
- uses: julia-actions/julia-buildpkg@latest
34+
- name: Clone Downstream
35+
uses: actions/checkout@v2
36+
with:
37+
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
38+
path: downstream
39+
- name: Load this and run the downstream tests
40+
shell: julia --color=yes --project=downstream {0}
41+
run: |
42+
using Pkg
43+
try
44+
# force it to use this PR's version of the package
45+
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
46+
Pkg.update()
47+
Pkg.test() # resolver may fail with test time deps
48+
catch err
49+
err isa Pkg.Resolve.ResolverError || rethrow()
50+
# If we can't resolve that means this is incompatible by SemVer and this is fine
51+
# It means we marked this as a breaking change, so we don't need to worry about
52+
# Mistakenly introducing a breaking change, as we have intentionally made one
53+
@info "Not compatible with this release. No problem." exception=err
54+
exit(0) # Exit immediately, as a success
55+
end

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SymbolicUtils"
22
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
33
authors = ["Shashi Gowda"]
4-
version = "0.8.4"
4+
version = "0.9.3"
55

66
[deps]
77
AbstractAlgebra = "c3fe647b-3220-5bb0-a1ea-a7954cac585d"
@@ -19,7 +19,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1919
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2020

2121
[compat]
22-
AbstractAlgebra = "0.9, 0.10, 0.11, 0.12, 0.13"
22+
AbstractAlgebra = "0.9, 0.10, 0.11, 0.12, 0.13, 0.14, 0.15"
2323
AbstractTrees = "0.3"
2424
Combinatorics = "1.0"
2525
ConstructionBase = "1.1"

src/SymbolicUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ export @syms, term, showraw, hasmetadata, getmetadata, setmetadata
77
using DataStructures
88
using Setfield
99
import Setfield: PropertyLens
10-
import Base: +, -, *, /, \, ^, ImmutableDict
10+
import Base: +, -, *, /, //, \, ^, ImmutableDict
1111
using ConstructionBase
1212
include("types.jl")
1313

src/api.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
"""
44
```julia
5-
simplify(x; rewriter=default_simplifier(),
5+
simplify(x; polynorm=false,
66
threaded=false,
7-
polynorm=true,
8-
thread_subtree_cutoff=100)
7+
thread_subtree_cutoff=100,
8+
rewriter=nothing)
99
```
1010
1111
Simplify an expression (`x`) by applying `rewriter` until there are no changes.
@@ -60,3 +60,24 @@ function substitute(expr, dict; fold=true)
6060
expr
6161
end
6262
end
63+
64+
"""
65+
occursin(needle::Symbolic, haystack::Symbolic)
66+
67+
Determine whether the second argument contains the first argument. Note that
68+
this function doesn't handle associativity, commutativity, or distributivity.
69+
"""
70+
Base.occursin(needle::Symbolic, haystack::Symbolic) = _occursin(needle, haystack)
71+
Base.occursin(needle, haystack::Symbolic) = _occursin(needle, haystack)
72+
Base.occursin(needle::Symbolic, haystack) = _occursin(needle, haystack)
73+
function _occursin(needle, haystack)
74+
isequal(needle, haystack) && return true
75+
76+
if istree(haystack)
77+
args = arguments(haystack)
78+
for arg in args
79+
occursin(needle, arg) && return true
80+
end
81+
end
82+
return false
83+
end

src/code.jl

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,39 @@ Base.convert(::Type{Assignment}, p::Pair) = Assignment(pair[1], pair[2])
9696

9797
toexpr(a::Assignment, st) = :($(toexpr(a.lhs, st)) = $(toexpr(a.rhs, st)))
9898

99-
function toexpr(O, st)
100-
!istree(O) && return O
101-
op = operation(O)
99+
function_to_expr(op, args, st) = nothing
100+
101+
function function_to_expr(::typeof(^), O, st)
102102
args = arguments(O)
103-
if op === (^) && length(args) == 2 && args[2] isa Number && args[2] < 0
103+
if length(args) == 2 && args[2] isa Number && args[2] < 0
104104
ex = args[1]
105105
if args[2] == -1
106106
return toexpr(Term{Any}(inv, [ex]), st)
107107
else
108108
return toexpr(Term{Any}(^, [Term{Any}(inv, [ex]), -args[2]]), st)
109109
end
110-
elseif op === (SymbolicUtils.ifelse)
111-
return :($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st)))
112-
elseif op isa Sym && haskey(st.symbolify, O)
113-
return st.symbolify[O]
114110
end
115-
return Expr(:call, toexpr(op, st), map(x->toexpr(x, st), args)...)
111+
return nothing
112+
end
113+
114+
function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)
115+
args = arguments(O)
116+
:($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st)))
117+
end
118+
119+
function_to_expr(::Sym, O, st) = get(st.symbolify, O, nothing)
120+
121+
function toexpr(O, st)
122+
!istree(O) && return O
123+
op = operation(O)
124+
expr′ = function_to_expr(op, O, st)
125+
if expr′ !== nothing
126+
return expr′
127+
else
128+
haskey(st.symbolify, O) && return st.symbolify[O]
129+
args = arguments(O)
130+
return Expr(:call, toexpr(op, st), map(x->toexpr(x, st), args)...)
131+
end
116132
end
117133

118134
# Call elements of vector arguments by their name.
@@ -464,23 +480,33 @@ end
464480

465481
struct Multithreaded end
466482
"""
467-
SpawnFetch{ParallelType}(exprs, reduce)
483+
SpawnFetch{ParallelType}(funcs [, args], reduce)
468484
469-
Run every expr in `exprs` in its own task, and use the `reduce`
470-
function to combine the results of executing `exprs`.
485+
Run every expression in `funcs` in its own task, the expression
486+
should be a `Func` object and is passed to `Threads.Task(f)`.
487+
If `Func` takes arguments, then the arguments must be passed in as `args`--a vector of vector of arguments to each function in `funcs`. We don't use `@spawn` in order to support RuntimeGeneratedFunctions which disallow closures, instead we interpolate these functions or closures as smaller RuntimeGeneratedFunctions.
488+
489+
`reduce` function is used to combine the results of executing `exprs`. A SpawnFetch expression returns the reduced result.
490+
491+
492+
Use `Symbolics.MultithreadedForm` ParallelType from the Symbolics.jl package to get the RuntimeGeneratedFunction version SpawnFetch.
471493
472494
`ParallelType` can be used to define more parallelism types
473495
SymbolicUtils supports `Multithreaded` type. Which spawns
474496
threaded tasks.
475497
"""
476498
struct SpawnFetch{Typ}
477499
exprs::Vector
500+
args::Union{Nothing, Vector}
478501
combine
479502
end
480503

504+
(::Type{SpawnFetch{T}})(exprs, combine) where {T} = SpawnFetch{T}(exprs, nothing, combine)
505+
481506
function toexpr(p::SpawnFetch{Multithreaded}, st)
482-
spawns = map(p.exprs) do thunk
483-
:(Base.Threads.@spawn $(toexpr(thunk, st)))
507+
args = isnothing(p.args) ? Iterators.repeated((), length(p.exprs)) : p.args
508+
spawns = map(p.exprs, args) do thunk, xs
509+
:(Base.Threads.@spawn $(toexpr(thunk, st))($(toexpr.(xs, (st,))...)))
484510
end
485511
quote
486512
$(toexpr(p.combine, st))(map(fetch, ($(spawns...),))...)

src/types.jl

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function getmetadata(s::Symbolic, ctx)
6969
end
7070

7171
function getmetadata(s::Symbolic, ctx, default)
72-
s.metadata isa Ref ? get(s.metadata[], ctx, default) : default
72+
s.metadata isa AbstractDict ? get(s.metadata, ctx, default) : default
7373
end
7474

7575
# pirated for Setfield purposes:
@@ -173,9 +173,9 @@ end
173173

174174
Base.nameof(s::Sym) = s.name
175175

176-
ConstructionBase.constructorof(s::Type{<:Sym{T}}) where {T} = Sym{T}
176+
ConstructionBase.constructorof(s::Type{<:Sym{T}}) where {T} = (n,m) -> Sym{T}(n, metadata=m)
177177

178-
function (::Type{Sym{T}})(name, metadata=NO_METADATA) where {T}
178+
function (::Type{Sym{T}})(name; metadata=NO_METADATA) where {T}
179179
Sym{T, typeof(metadata)}(name, metadata)
180180
end
181181

@@ -327,14 +327,14 @@ function ConstructionBase.constructorof(s::Type{<:Term{T}}) where {T}
327327
end
328328
end
329329

330-
function (::Type{Term{T}})(f, args, metadata=NO_METADATA) where {T}
330+
function (::Type{Term{T}})(f, args; metadata=NO_METADATA) where {T}
331331
Term{T, typeof(metadata)}(f, args, metadata, Ref{UInt}(0))
332332
end
333333

334334
istree(t::Term) = true
335335

336-
function Term(f, args, metadata=NO_METADATA)
337-
Term{_promote_symtype(f, args)}(f, args, metadata)
336+
function Term(f, args; metadata=NO_METADATA)
337+
Term{_promote_symtype(f, args)}(f, args, metadata=metadata)
338338
end
339339

340340
operation(x::Term) = getfield(x, :f)
@@ -422,13 +422,20 @@ end
422422
setargs(t, args) = Term{symtype(t)}(operation(t), args)
423423
cdrargs(args) = setargs(t, cdr(args))
424424

425-
print_arg(io, x::Union{Complex, Rational}) = print(io, "(", x, ")")
426-
print_arg(io, x) = print(io, x)
427-
print_arg(io, f::typeof(^), x) = print_arg(IOContext(io, :paren=>true), x)
425+
print_arg(io, x::Union{Complex, Rational}; paren=true) = print(io, "(", x, ")")
426+
isbinop(f) = istree(f) && Base.isbinaryoperator(nameof(operation(f)))
427+
function print_arg(io, x; paren=false)
428+
if paren && isbinop(x)
429+
print(io, "(", x, ")")
430+
else
431+
print(io, x)
432+
end
433+
end
434+
print_arg(io, s::String; paren=true) = show(io, s)
428435
function print_arg(io, f, x)
429436
f !== (*) && return print_arg(io, x)
430-
if istree(x) && Base.isbinaryoperator(nameof(operation(x)))
431-
print_arg(IOContext(io, :paren=>true), x)
437+
if Base.isbinaryoperator(nameof(f))
438+
print_arg(io, x, paren=true)
432439
else
433440
print_arg(io, x)
434441
end
@@ -447,11 +454,19 @@ function show_add(io, args)
447454
print_arg(io, -, t)
448455
else
449456
print(io, " - ")
450-
print_arg(IOContext(io, :paren=>true), +, -t)
457+
print_arg(io, -t, paren=true)
451458
end
452459
end
453460
end
454461

462+
function show_pow(io, args)
463+
base, ex = args
464+
465+
print_arg(io, base, paren=true)
466+
print(io, "^")
467+
print_arg(io, ex, paren=true)
468+
end
469+
455470
function show_mul(io, args)
456471
length(args) == 1 && return print_arg(io, *, args[1])
457472

@@ -480,8 +495,8 @@ function show_call(io, f, args)
480495
binary = Base.isbinaryoperator(fname)
481496
if binary
482497
for (i, t) in enumerate(args)
483-
i != 1 && print(io, fname == :^ ? fname : " $fname ")
484-
print_arg(io, (^), t)
498+
i != 1 && print(io, " $fname ")
499+
print_arg(io, t)
485500
end
486501
else
487502
if f isa Sym
@@ -506,15 +521,15 @@ function show_term(io::IO, t)
506521
f = operation(t)
507522
args = arguments(t)
508523

509-
get(io, :paren, false) && print(io, "(")
510524
if f === (+)
511525
show_add(io, args)
512526
elseif f === (*)
513527
show_mul(io, args)
528+
elseif f === (^)
529+
show_pow(io, args)
514530
else
515531
show_call(io, f, args)
516532
end
517-
get(io, :paren, false) && print(io, ")")
518533

519534
return nothing
520535
end
@@ -714,7 +729,7 @@ operation(a::Mul) = *
714729

715730
function arguments(a::Mul)
716731
a.sorted_args_cache[] !== nothing && return a.sorted_args_cache[]
717-
args = sort!([k^v for (k,v) in a.dict], lt=<ₑ)
732+
args = sort!([Pow(k, v) for (k,v) in a.dict], lt=<ₑ)
718733
a.sorted_args_cache[] = isone(a.coeff) ? args : vcat(a.coeff, args)
719734
end
720735

@@ -766,6 +781,10 @@ mul_t(a) = promote_symtype(*, symtype(a))
766781

767782
/(a::SN, b::Number) = inv(b) * a
768783

784+
//(a::Union{SN, Number}, b::SN) = a / b
785+
786+
//(a::SN, b::T) where {T <: Number} = (one(T) // b) * a
787+
769788
"""
770789
Pow(base, exp)
771790
@@ -887,7 +906,17 @@ struct TreePrint
887906
x
888907
end
889908
AbstractTrees.children(x::Term) = arguments(x)
890-
AbstractTrees.children(x::Union{Add, Mul}) = map(y->TreePrint(x isa Add ? (:*) : (:^), y), collect(pairs(x.dict)))
909+
function AbstractTrees.children(x::Union{Add, Mul})
910+
children = Any[x.coeff]
911+
for (key, coeff) in pairs(x.dict)
912+
if coeff == 1
913+
push!(children, key)
914+
else
915+
push!(children, TreePrint(x isa Add ? (:*) : (:^), (key, coeff)))
916+
end
917+
end
918+
return children
919+
end
891920
AbstractTrees.children(x::Union{Pow}) = [x.base, x.exp]
892921
AbstractTrees.children(x::TreePrint) = [x.x[1], x.x[2]]
893922

test/basics.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ end
112112
@test SymbolicUtils.promote_symtype(ifelse, Bool, Int, Bool) == Union{Int, Bool}
113113
@test_throws MethodError w < 0
114114
@test isequal(w == 0, Term{Bool}(==, [w, 0]))
115+
116+
@eqtest x // 5 == (1 // 5) * x
117+
@eqtest x // Int16(5) == Rational{Int16}(1, 5) * x
118+
@eqtest 5 // x == 5 / x
119+
@eqtest x // a == x / a
115120
end
116121

117122
@testset "err test" begin
@@ -127,6 +132,13 @@ end
127132
@test substitute(exp(a), Dict(a=>2)) exp(2)
128133
end
129134

135+
@testset "occursin" begin
136+
@syms a b c
137+
@test occursin(a, a + b)
138+
@test !occursin(sin(a), a + b + c)
139+
@test occursin(sin(a), a * b + c + sin(a^2 * sin(a)))
140+
end
141+
130142
@testset "printing" begin
131143
@syms a b c
132144
@test repr(a+b) == "a + b"

0 commit comments

Comments
 (0)