
Conversation

jClugstor
Member

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
  • Any new documentation only uses public API

Additional context

We should check to make sure that the ForwardDiff overloads don't apply to StaticArrays.
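
A spot check along those lines might look roughly like this (a hypothetical sketch, not a test added in this PR): build a static problem whose entries are already Duals and verify it still resolves to the static fast path rather than the Dual overloads.

using LinearSolve, StaticArrays, ForwardDiff

# A static problem whose eltype is a ForwardDiff.Dual matches both the "static"
# and the "dual" problem shapes; the property to check is that it still takes
# the static branch (e.g. lu(A) \ b) instead of the Dual machinery.
A = ForwardDiff.Dual.(@SMatrix(rand(4, 4)), 1.0)
b = ForwardDiff.Dual.(@SVector(rand(4)), 1.0)
prob = LinearProblem(A, b)
sol = solve(prob, LUFactorization())
sol.u isa SVector  # expected to hold if the static branch was taken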

@ChrisRackauckas
Member

Tests fail

@jClugstor
Member Author

My bad, forgot to import

@jClugstor
Member Author

This won't work until I can find a way to make init(::StaticLinearProblem) go through the normal init instead of the Dual init, even if the StaticLinearProblem is also a DualLinearProblem.
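
To illustrate the clash with stand-in types (StandInProblem and the path names here are hypothetical, not LinearSolve internals): an SMatrix of Duals satisfies both aliases, so both init methods could match, which is exactly the problem.

using StaticArrays, ForwardDiff

struct StandInProblem{T}
    A::T
end

# Stand-ins for the StaticLinearProblem / DualLinearProblem aliases
const StaticStandIn = StandInProblem{<:SMatrix}
const DualStandIn = StandInProblem{<:AbstractMatrix{<:ForwardDiff.Dual}}

my_init(::StaticStandIn) = :static_path
my_init(::DualStandIn) = :dual_path

Adual = ForwardDiff.Dual.(@SMatrix(rand(2, 2)), 1.0)
StandInProblem(Adual) isa StaticStandIn  # true
StandInProblem(Adual) isa DualStandIn    # true, so both methods are candidates for dispatch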

@oscardssmith
Member

Why do we need special problem types at all? For OrdinaryDiffEq, we don't have a StaticODEProblem/DualADProblem; we just handle the static and dual cases through the normal problem type.

@jClugstor
Member Author

OK, so this is just blatant code duplication, but it fixes the problem. Is there a better way to do this? I just need to make sure that init on a StaticLinearProblem doesn't go through the init for DualLinearProblem. I know there's invoke etc., but that doesn't seem like a good idea either?
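
For reference, the invoke route would look roughly like this sketch (the signature of the generic init method being invoked is an assumption on my part):

# Sketch only, not what this PR does: force dispatch to the generic LinearProblem
# init instead of the Dual-specialized one.
function SciMLBase.init(prob::StaticLinearProblem,
        alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
    return invoke(SciMLBase.init,
        Tuple{LinearProblem, SciMLLinearSolveAlgorithm, Vararg{Any}},
        prob, alg, args...; kwargs...)
end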

@jClugstor
Member Author

@oscardssmith
What would be the alternative here?
For the StaticLinearProblem stuff, it needs to be able to bypass the other solver algorithms and just use the LinearAlgebra functions, with the other algorithms as a fallback.
So the alternative would probably be to check if A or b are a static array, and then send it to a special solve or something? That's basically what it does anyway, though, right?

const StaticLinearProblem = LinearProblem{uType, iip, <:SMatrix,
    <:Union{<:SMatrix, <:SVector}} where {uType, iip}
    
function SciMLBase.solve(prob::StaticLinearProblem,
        alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
    if alg === nothing || alg isa DirectLdiv!
        u = prob.A \ prob.b
    elseif alg isa LUFactorization
        u = lu(prob.A) \ prob.b
    elseif alg isa QRFactorization
        u = qr(prob.A) \ prob.b
    elseif alg isa CholeskyFactorization
        u = cholesky(prob.A) \ prob.b
    elseif alg isa NormalCholeskyFactorization
        u = cholesky(Symmetric(prob.A' * prob.A)) \ (prob.A' * prob.b)
    elseif alg isa SVDFactorization
        u = svd(prob.A) \ prob.b
    else
        # Slower Path but handles all cases
        cache = init(prob, alg, args...; kwargs...)
        return solve!(cache)
    end
    return SciMLBase.build_linear_solution(
        alg, u, nothing, prob; retcode = ReturnCode.Success)
end

For the DualLinear stuff I just kind of followed the lead of NonlinearSolve, where they have a separate Dual nonlinear problem and Dual nonlinear cache. I think it makes sense in this context, though, because in order to split up the primal and dual linear solves we need a cache that can hold the primal linear solve cache plus the information needed for the dual part of the solve.
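
Roughly, the shape I mean is something like this sketch (field names are illustrative, not the actual DualLinearCache added in this PR):

# Illustrative only: wrap the primal LinearSolve cache together with the partials
# pulled out of the Dual-valued A and b, so the dual part of the solution can be
# reconstructed after the primal solve.
struct DualCacheSketch{C, PA, PB}
    primal_cache::C  # ordinary cache for solving value.(A) * u = value.(b)
    partials_A::PA   # partials extracted from the Dual entries of A
    partials_b::PB   # partials extracted from the Dual entries of b
end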

@ChrisRackauckas merged commit b8fdd8d into SciML:main Jul 30, 2025
35 of 39 checks passed
@ChrisRackauckas
Member

Shit, that was a misclick; you should never init a static problem.

@jClugstor
Member Author

> you should never init a static problem.

But that's what happens in the fallback case?

cache = init(prob, alg, args...; kwargs...)

If we want to completely disallow init for a StaticLinearProblem, I can set up an error for that fallback case, and an error for init(::StaticLinearProblem ...), but wouldn't that be breaking?
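
Concretely, the disallowing option would be something like this sketch (not what the PR currently does):

# Sketch of the "disallow init" idea; whether shipping this counts as breaking
# is exactly the question.
function SciMLBase.init(prob::StaticLinearProblem,
        alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
    error("init is not supported for fully static LinearProblems; call solve(prob, alg) directly.")
end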

@ChrisRackauckas
Member

How could you ever init it? The dispatches all go straight to solving, so it never makes a mutable cache, and that's on purpose. It would only ever happen if you specifically choose to init it, which wouldn't make much sense for most use cases.

@jClugstor
Member Author

If you do something like this (not saying this is particularly smart or common):

using LinearSolve
using StaticArrays
using Krylov

A = @SMatrix rand(4,4)
b = @SVector rand(4)

prob = LinearProblem(A,b)

sol = solve(prob, KrylovJL_GMRES())

it goes through init because of the fallback here:

function SciMLBase.solve(prob::StaticLinearProblem,
        alg::SciMLLinearSolveAlgorithm, args...; kwargs...)
    if alg === nothing || alg isa DirectLdiv!
        u = prob.A \ prob.b
    elseif alg isa LUFactorization
        u = lu(prob.A) \ prob.b
    elseif alg isa QRFactorization
        u = qr(prob.A) \ prob.b
    elseif alg isa CholeskyFactorization
        u = cholesky(prob.A) \ prob.b
    elseif alg isa NormalCholeskyFactorization
        u = cholesky(Symmetric(prob.A' * prob.A)) \ (prob.A' * prob.b)
    elseif alg isa SVDFactorization
        u = svd(prob.A) \ prob.b
    else
        # Slower Path but handles all cases
        cache = init(prob, alg, args...; kwargs...)
        return solve!(cache)
    end

So if you try to differentiate something like this:

using ForwardDiff

function f(p)
    A = rand(4,4) .+ p
    b = rand(4) .- p
    A = SMatrix{4,4}(A)
    b = SVector{4}(b)
    prob = LinearProblem(A,b)
    
    sol = solve(prob, KrylovJL_GMRES())
end

ForwardDiff.jacobian(f, [2.0])

it ends up going through init, creating a DualLinearCache, and then failing during solve!. I mean, this doesn't work on previous versions either, so I guess it probably didn't break anything.

@ChrisRackauckas
Member

Krylov doesn't make sense with static arrays. Any other case?
