diff --git a/src/factorization.jl b/src/factorization.jl index 4b3e946a9..d5254ca22 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -281,6 +281,18 @@ function QRFactorization(pivot::LinearAlgebra.PivotingStrategy, inplace::Bool = QRFactorization(pivot, 16, inplace) end +# Helper function to ensure QRPivoted has Int64 pivot indices +function _ensure_int64_pivot(fact::LinearAlgebra.QRPivoted) + if eltype(fact.p) !== Int64 + # Convert to Int64 if needed + return LinearAlgebra.QRPivoted(fact.factors, fact.τ, convert(Vector{Int64}, fact.p)) + end + return fact +end + +# For non-pivoted QR, just return as-is +_ensure_int64_pivot(fact) = fact + function do_factorization(alg::QRFactorization, A, b, u) A = convert(AbstractMatrix, A) if ArrayInterface.can_setindex(typeof(A)) @@ -296,22 +308,26 @@ function do_factorization(alg::QRFactorization, A, b, u) else fact = qr(A, alg.pivot) end - return fact + # Ensure Int64 pivot indices for QRPivoted + return _ensure_int64_pivot(fact) end function init_cacheval(alg::QRFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot) + fact = ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot) + return _ensure_int64_pivot(fact) end function init_cacheval(alg::QRFactorization, A::Symmetric{<:Number, <:Array}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - return qr(convert(AbstractMatrix, A), alg.pivot) + fact = qr(convert(AbstractMatrix, A), alg.pivot) + return _ensure_int64_pivot(fact) end -const PREALLOCATED_QR_ColumnNorm = ArrayInterface.qr_instance(rand(1, 1), ColumnNorm()) +# Create with Int64 pivot indices from the start +const PREALLOCATED_QR_ColumnNorm = _ensure_int64_pivot(ArrayInterface.qr_instance(rand(1, 1), ColumnNorm())) function init_cacheval(alg::QRFactorization{ColumnNorm}, A::Matrix{Float64}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) @@ -322,7 +338,8 @@ function init_cacheval( alg::QRFactorization, A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) A isa GPUArraysCore.AnyGPUArray && return qr(A) - return qr(A, alg.pivot) + fact = qr(A, alg.pivot) + return _ensure_int64_pivot(fact) end const PREALLOCATED_QR_NoPivot = ArrayInterface.qr_instance(rand(1, 1))