-
-
Notifications
You must be signed in to change notification settings - Fork 72
Add test for ForwardDiff with StaticArrays #650
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add test for ForwardDiff with StaticArrays #650
Conversation
Tests fail |
My bad, forgot to import |
This won't work until I can find a way to make |
why do we need special problem types at all? For OrdinaryDiffEq, we don't have |
Ok so this is just blatant code duplication, but it fixes the problem. Is there a better way to do this? I just need to make sure that |
@oscardsmith const StaticLinearProblem = LinearProblem{uType, iip, <:SMatrix,
<:Union{<:SMatrix, <:SVector}} where {uType, iip}
function SciMLBase.solve(prob::StaticLinearProblem,
alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
if alg === nothing || alg isa DirectLdiv!
u = prob.A \ prob.b
elseif alg isa LUFactorization
u = lu(prob.A) \ prob.b
elseif alg isa QRFactorization
u = qr(prob.A) \ prob.b
elseif alg isa CholeskyFactorization
u = cholesky(prob.A) \ prob.b
elseif alg isa NormalCholeskyFactorization
u = cholesky(Symmetric(prob.A' * prob.A)) \ (prob.A' * prob.b)
elseif alg isa SVDFactorization
u = svd(prob.A) \ prob.b
else
# Slower Path but handles all cases
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end
return SciMLBase.build_linear_solution(
alg, u, nothing, prob; retcode = ReturnCode.Success)
end For the DualLinear stuff I just kind of followed the lead of NonlinearSolve, where they have a seperate Dual nonlinear problem and Dual nonlinear cache. I think it makes sense in this context though |
Shit that was a misclick, you should never |
But that's what happens in the fallback case? Line 334 in 51ff7a1
If we want to completely disallow |
How could you ever |
If you do something like this (not saying this is particularly smart or common) : using LinearSolve
using StaticArrays
using Krylov
A = @SMatrix rand(4,4)
b = @SVector rand(4)
prob = LinearProblem(A,b)
sol = solve(prob, KrylovJL_GMRES()) it goes through Lines 318 to 336 in 51ff7a1
So if you try to differentiate something like this: function f(p)
A = rand(4,4) .+ p
b = rand(4) .- p
A = SMatrix{4,4}(A)
b = SVector{4}(b)
prob = LinearProblem(A,b)
sol = solve(prob, KrylovJL_GMRES())
end
ForwardDiff.jacobian(f, [2.0]) it ends up going through |
Krylov doesn't make sense with static arrays. Any other case? |
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
We should check to make sure that the ForwardDiff overloads don't apply to StaticArrays.