Skip to content

Commit c0db446

Browse files
authored
add SymPyPythonCall extension; test extensions (#398)
* add SymPyPythonCall extension; test extensions * test SymPyPythonCall * learn to spell * adjust compat bounds * punt * oops * another tweak
1 parent 7433442 commit c0db446

File tree

4 files changed

+84
-6
lines changed

4 files changed

+84
-6
lines changed

Project.toml

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,29 @@ CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
88
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
99
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1010

11+
12+
[weakdeps]
13+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
14+
IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807"
15+
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
16+
SymPyPythonCall = "bc8888f7-b21e-4b7c-a06a-5d9c9496438c"
17+
1118
[compat]
1219
ChainRulesCore = "1"
1320
CommonSolve = "0.1, 0.2"
1421
ForwardDiff = "0.10"
1522
IntervalRootFinding = "0.5"
1623
SymPy = "1"
24+
SymPyPythonCall = "0.1,1"
1725
Setfield = "0.7, 0.8, 1"
1826
julia = "1.0"
1927

28+
2029
[extensions]
2130
RootsForwardDiffExt = "ForwardDiff"
2231
RootsIntervalRootFindingExt = "IntervalRootFinding"
2332
RootsSymPyExt = "SymPy"
33+
RootsSymPyPythonCallExt = "SymPyPythonCall"
2434

2535
[extras]
2636
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
@@ -33,14 +43,10 @@ Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
3343
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3444
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3545
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
46+
SymPyPythonCall = "bc8888f7-b21e-4b7c-a06a-5d9c9496438c"
3647
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3748
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
3849
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3950

4051
[targets]
41-
test = ["Aqua", "ChainRulesTestUtils", "JSON", "SpecialFunctions", "Statistics", "Test", "BenchmarkTools", "ForwardDiff", "Polynomials", "SymPy", "Unitful", "Zygote"]
42-
43-
[weakdeps]
44-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
45-
IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807"
46-
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
52+
test = ["Aqua", "ChainRulesTestUtils", "JSON", "SpecialFunctions", "Statistics", "Test", "BenchmarkTools", "ForwardDiff", "Polynomials", "Unitful", "Zygote", "IntervalRootFinding"]

ext/RootsSymPyPythonCallExt.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module RootsSymPyPythonCallExt
2+
3+
using Roots
4+
using SymPyPythonCall
5+
6+
## Allow equations to specify a problem to solve
7+
function Roots.Callable_Function(M::Roots.AbstractUnivariateZeroMethod, f::SymPyPythonCall.Sym, p=nothing)
8+
if f.is_Equality == true
9+
f = lhs(f) - rhs(f)
10+
end
11+
Roots.Callable_Function(M, lambdify(f), p)
12+
end
13+
14+
15+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ include("./test_simple.jl")
3232

3333
include("./test_composable.jl")
3434
VERSION >= v"1.6.0" && include("./test_allocations.jl")
35+
VERSION >= v"1.9.0" && include("./test_extensions.jl")
36+
3537

3638
#include("./runbenchmarks.jl")
3739
#include("./test_derivative_free_interactive.jl")

test/test_extensions.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#=
2+
using SymPy
3+
@testset "SymPy" begin
4+
SymPy.@syms x
5+
@test find_zero(cos(x) ~ 1/2, (0, pi/2)) ≈ find_zero(x -> cos(x) - 1/2, (0, pi/2))
6+
@test find_zero(1/2 ~ cos(x), (0, pi/2)) ≈ find_zero(x -> 1/2 - cos(x), (0, pi/2))
7+
@test find_zero(cos(x) ~ x/2, (0, pi/2)) ≈ find_zero(x -> cos(x) - x/2, (0, pi/2))
8+
end
9+
=#
10+
11+
#=
12+
using SymPyPythonCall
13+
@testset "SymPythonCall" begin
14+
SymPyPythonCall.@syms x
15+
@test find_zero(cos(x) ~ 1/2, (0, pi/2)) ≈ find_zero(x -> cos(x) - 1/2, (0, pi/2))
16+
@test find_zero(1/2 ~ cos(x), (0, pi/2)) ≈ find_zero(x -> 1/2 - cos(x), (0, pi/2))
17+
@test find_zero(cos(x) ~ x/2, (0, pi/2)) ≈ find_zero(x -> cos(x) - x/2, (0, pi/2))
18+
end
19+
=#
20+
21+
using ForwardDiff
22+
@testset "ForwardDiff" begin
23+
f(x, p) = x^2 - p
24+
Z = ZeroProblem(f, (0, 1000))
25+
F(p) = solve(Z, Roots.Bisection(), p)
26+
for p (3,5,7,11)
27+
@test F(p) sqrt(p)
28+
@test ForwardDiff.derivative(F, p) 1 / (2sqrt(p))
29+
end
30+
31+
# Hessian is *broken*
32+
f(x, p) = x^2 - sum(p.^2)
33+
Z = ZeroProblem(f, (0, 1000))
34+
F(p) = solve(Z, Roots.Bisection(), p)
35+
Z = ZeroProblem(f, (0, 1000))
36+
F(p) = solve(Z, Roots.Bisection(), p)
37+
hess(f, p) = ForwardDiff.jacobian(p -> ForwardDiff.gradient(F, p), p)
38+
for p ([1,2], [1,3], [1,4])
39+
@test F(p) sqrt(sum(p.^2))
40+
@test_throws DimensionMismatch ForwardDiff.hessian(F, p)
41+
a, b = p
42+
n = sqrt(a^2 + b^2)^3
43+
@test hess(F, p) [b^2 -a*b; -a*b a^2]/n
44+
end
45+
46+
end
47+
48+
using IntervalRootFinding
49+
@testset "IntervalRootFinding" begin
50+
f(x) = sin(x + sin(x + sin(x)))
51+
@test find_zeros(f, (-5, 5)) [-pi, 0, pi]
52+
out = find_zeros(f, -5..5, Roots.Newton())
53+
@test sort(out.zeros) sort([-pi,0,pi])
54+
@test isempty(out.unknown)
55+
end

0 commit comments

Comments
 (0)