Skip to content

fix: fix type-instability of get_root_indp #1074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
28 changes: 23 additions & 5 deletions src/solutions/save_idxs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,30 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t)
return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1))
end

function get_root_indp(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
(sc = symbolic_container(indp)) !== indp
return get_root_indp(sc)
function get_root_indp(prob::AbstractSciMLProblem)
get_root_indp(prob.f)
end

function get_root_indp(f::T) where {T <: AbstractSciMLFunction}
if hasfield(T, :sys)
return f.sys
elseif hasfield(T, :f) && f.f isa AbstractSciMLFunction
return get_root_indp(f.f)
else
return nothing
end
return indp
end

function get_root_indp(prob::LinearProblem)
get_root_indp(prob.f)
end

get_root_indp(prob::AbstractJumpProblem) = get_root_indp(prob.prob)

get_root_indp(x) = x

function get_root_indp(f::SymbolicLinearInterface)
get_root_indp(f.sys)
end

# Everything from this point on is public API
Expand Down
26 changes: 26 additions & 0 deletions test/JET.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using NonlinearSolve
using LinearSolve
using LinearAlgebra
using ADTypes
using JET
const LS = LinearSolve

function f(u, p)
L, U = cholesky(p.Σ)
rhs = (u .* u .- p.λ)
linprob = LinearProblem(Matrix(L), rhs)
alg = LS.GenericLUFactorization()
sol = LinearSolve.solve(linprob, alg)
return sol.u
end

function minimize(λ=1.0)
ps = (; λ, Σ=hermitianpart(rand(2,2) + 2*I))
u₀ = rand(2)
prob = NonlinearLeastSquaresProblem{false}(f, u₀, ps)
autodiff = AutoForwardDiff(; chunksize=1)
sol = solve(prob, SimpleTrustRegion(; autodiff))
return sol.u
end

@test_opt minimize()
4 changes: 4 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Expand Down
4 changes: 2 additions & 2 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ push!(syss, nsys)
push!(probs, NonlinearProblem(nsys, [u0; p], jac = true))

rate₁ = β * x * y
affect₁ = [x ~ x - σ, y ~ y + σ]
affect₁ = [x ~ Pre(x) - σ, y ~ Pre(y) + σ]
rate₂ = ρ * y
affect₂ = [y ~ y - 1, z ~ z + 1]
affect₂ = [y ~ Pre(y) - 1, z ~ Pre(z) + 1]
j₁ = ConstantRateJump(rate₁, affect₁)
j₂ = ConstantRateJump(rate₂, affect₂)
j₃ = MassActionJump(2 * β + ρ, [z => 1], [x => 1, z => -1])
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ end
@time @safetestset "Aqua" begin
include("aqua.jl")
end
activate_downstream_env()
@time @safetestset "JET" begin
include("JET.jl")
end
end
if GROUP == "Core" || GROUP == "All"
@time @safetestset "Display" begin
Expand Down
Loading