|
| 1 | +using ModelingToolkit, NonlinearSolve, SymbolicIndexingInterface |
| 2 | +using LinearAlgebra |
| 3 | +using Test |
| 4 | +import HomotopyContinuation |
| 5 | + |
| 6 | +@testset "No parameters" begin |
| 7 | + @variables x y z |
| 8 | + eqs = [0 ~ x^2 + y^2 + 2x * y |
| 9 | + 0 ~ x^2 + 4x + 4 |
| 10 | + 0 ~ y * z + 4x^2] |
| 11 | + @mtkbuild sys = NonlinearSystem(eqs) |
| 12 | + prob = HomotopyContinuationProblem(sys, [x => 1.0, y => 1.0, z => 1.0], []) |
| 13 | + @test prob[x] == prob[y] == prob[z] == 1.0 |
| 14 | + @test prob[x + y] == 2.0 |
| 15 | + sol = solve(prob; threading = false) |
| 16 | + @test SciMLBase.successful_retcode(sol) |
| 17 | + @test norm(sol.resid)≈0.0 atol=1e-10 |
| 18 | +end |
| 19 | + |
| 20 | +struct Wrapper |
| 21 | + x::Matrix{Float64} |
| 22 | +end |
| 23 | + |
| 24 | +@testset "Parameters" begin |
| 25 | + wrapper(w::Wrapper) = det(w.x) |
| 26 | + @register_symbolic wrapper(w::Wrapper) |
| 27 | + |
| 28 | + @variables x y z |
| 29 | + @parameters p q::Int r::Wrapper |
| 30 | + |
| 31 | + eqs = [0 ~ x^2 + y^2 + p * x * y |
| 32 | + 0 ~ x^2 + 4x + q |
| 33 | + 0 ~ y * z + 4x^2 + wrapper(r)] |
| 34 | + |
| 35 | + @mtkbuild sys = NonlinearSystem(eqs) |
| 36 | + prob = HomotopyContinuationProblem(sys, [x => 1.0, y => 1.0, z => 1.0], |
| 37 | + [p => 2.0, q => 4, r => Wrapper([1.0 1.0; 0.0 0.0])]) |
| 38 | + @test prob.ps[p] == 2.0 |
| 39 | + @test prob.ps[q] == 4 |
| 40 | + @test prob.ps[r].x == [1.0 1.0; 0.0 0.0] |
| 41 | + @test prob.ps[p * q] == 8.0 |
| 42 | + sol = solve(prob; threading = false) |
| 43 | + @test SciMLBase.successful_retcode(sol) |
| 44 | + @test norm(sol.resid)≈0.0 atol=1e-10 |
| 45 | +end |
| 46 | + |
| 47 | +@testset "Array variables" begin |
| 48 | + @variables x[1:3] |
| 49 | + @parameters p[1:3] |
| 50 | + _x = collect(x) |
| 51 | + eqs = collect(0 .~ vec(sum(_x * _x'; dims = 2)) + collect(p)) |
| 52 | + @mtkbuild sys = NonlinearSystem(eqs) |
| 53 | + prob = HomotopyContinuationProblem(sys, [x => ones(3)], [p => 1:3]) |
| 54 | + @test prob[x] == ones(3) |
| 55 | + @test prob[p + x] == [2, 3, 4] |
| 56 | + prob[x] = 2ones(3) |
| 57 | + @test prob[x] == 2ones(3) |
| 58 | + prob.ps[p] = [2, 3, 4] |
| 59 | + @test prob.ps[p] == [2, 3, 4] |
| 60 | + sol = @test_nowarn solve(prob; threading = false) |
| 61 | + @test sol.retcode == ReturnCode.ConvergenceFailure |
| 62 | +end |
0 commit comments