diff --git a/src/solutions/save_idxs.jl b/src/solutions/save_idxs.jl index 6e4afbd33..7a2db44e6 100644 --- a/src/solutions/save_idxs.jl +++ b/src/solutions/save_idxs.jl @@ -44,12 +44,30 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t) return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1)) end -function get_root_indp(indp) - if hasmethod(symbolic_container, Tuple{typeof(indp)}) && - (sc = symbolic_container(indp)) !== indp - return get_root_indp(sc) +function get_root_indp(prob::AbstractSciMLProblem) + get_root_indp(prob.f) +end + +function get_root_indp(f::T) where {T <: AbstractSciMLFunction} + if hasfield(T, :sys) + return f.sys + elseif hasfield(T, :f) && f.f isa AbstractSciMLFunction + return get_root_indp(f.f) + else + return nothing end - return indp +end + +function get_root_indp(prob::LinearProblem) + get_root_indp(prob.f) +end + +get_root_indp(prob::AbstractJumpProblem) = get_root_indp(prob.prob) + +get_root_indp(x) = x + +function get_root_indp(f::SymbolicLinearInterface) + get_root_indp(f.sys) end # Everything from this point on is public API diff --git a/test/JET.jl b/test/JET.jl new file mode 100644 index 000000000..d6a4c13da --- /dev/null +++ b/test/JET.jl @@ -0,0 +1,26 @@ +using NonlinearSolve +using LinearSolve +using LinearAlgebra +using ADTypes +using JET +const LS = LinearSolve + +function f(u, p) + L, U = cholesky(p.Σ) + rhs = (u .* u .- p.λ) + linprob = LinearProblem(Matrix(L), rhs) + alg = LS.GenericLUFactorization() + sol = LinearSolve.solve(linprob, alg) + return sol.u +end + +function minimize(λ=1.0) + ps = (; λ, Σ=hermitianpart(rand(2,2) + 2*I)) + u₀ = rand(2) + prob = NonlinearLeastSquaresProblem{false}(f, u₀, ps) + autodiff = AutoForwardDiff(; chunksize=1) + sol = solve(prob, SimpleTrustRegion(; autodiff)) + return sol.u +end + +@test_opt minimize() diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 4a6cb0368..202570276 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,10 +1,14 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index f96935e81..a0fd0ab52 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -50,9 +50,9 @@ push!(syss, nsys) push!(probs, NonlinearProblem(nsys, [u0; p], jac = true)) rate₁ = β * x * y -affect₁ = [x ~ x - σ, y ~ y + σ] +affect₁ = [x ~ Pre(x) - σ, y ~ Pre(y) + σ] rate₂ = ρ * y -affect₂ = [y ~ y - 1, z ~ z + 1] +affect₂ = [y ~ Pre(y) - 1, z ~ Pre(z) + 1] j₁ = ConstantRateJump(rate₁, affect₁) j₂ = ConstantRateJump(rate₂, affect₂) j₃ = MassActionJump(2 * β + ρ, [z => 1], [x => 1, z => -1]) diff --git a/test/runtests.jl b/test/runtests.jl index 3d0983005..02e5f2890 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,10 @@ end @time @safetestset "Aqua" begin include("aqua.jl") end + activate_downstream_env() + @time @safetestset "JET" begin + include("JET.jl") + end end if GROUP == "Core" || GROUP == "All" @time @safetestset "Display" begin