Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,11 @@ function SciMLBase.solve!(
end

# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val::AbstractArray)
# If the property is A or b, also update it in the LinearCache
if sym === :A || sym === :b || sym === :u
setproperty!(dc.linear_cache, sym, nodual_value(val))
prop = nodual_value!(getproperty(dc.linear_cache, sym), val) # Update in-place
setproperty!(dc.linear_cache, sym, prop) # Does additional invalidation logic etc.
elseif hasfield(DualLinearCache, sym)
setfield!(dc, sym, val)
elseif hasfield(LinearSolve.LinearCache, sym)
Expand All @@ -322,15 +323,15 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
# Update the partials and invalidate cache if setting A or b
if sym === :A
setfield!(dc, :dual_A, val)
setfield!(dc, :partials_A, partial_vals(val))
partial_vals!(getfield(dc, :partials_A), val) # Update in-place
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
elseif sym === :b
setfield!(dc, :dual_b, val)
setfield!(dc, :partials_b, partial_vals(val))
partial_vals!(getfield(dc, :partials_b), val) # Update in-place
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
elseif sym === :u
setfield!(dc, :dual_u, val)
setfield!(dc, :partials_u, partial_vals(val))
partial_vals!(getfield(dc, :partials_u), val) # Update in-place
end
end

Expand Down Expand Up @@ -360,30 +361,20 @@ partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.pa
partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P} = ForwardDiff.partials(x)
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
partial_vals(x) = nothing
partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place

# Add recursive handling for nested dual values
nodual_value(x) = x
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact

function nodual_value(x::AbstractArray{<:Dual})
# Create a similar array with the appropriate element type
T = typeof(nodual_value(first(x)))
result = similar(x, T)

# Fill the result array with values
for i in eachindex(x)
result[i] = nodual_value(x[i])
end

return result
end
nodual_value(x::AbstractArray{<:Dual}) = nodual_value!(similar(x, typeof(nodual_value(first(x)))), x)
nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place

function update_partials_list!(partial_matrix::AbstractVector{T}, list_cache) where {T}
p = eachindex(first(partial_matrix))
for i in p
for j in eachindex(partial_matrix)
list_cache[i][j] = partial_matrix[j][i]
@inbounds list_cache[i][j] = partial_matrix[j][i]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not required, but they gave me a non-negligible speedup, as I saw checkbounds showing up in profiling.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added it so the user should not suffer a performance penalty from library internals. I can remove them if you want to keep it safe.

This function is just shuffling data around. The optimal solution would be to avoid it altogether, but I am not sure if it's easily possible.

end
end
return list_cache
Expand All @@ -396,7 +387,7 @@ function update_partials_list!(partial_matrix, list_cache)
for k in 1:p
for i in 1:m
for j in 1:n
list_cache[k][i, j] = partial_matrix[i, j][k]
@inbounds list_cache[k][i, j] = partial_matrix[i, j][k]
end
end
end
Expand Down
Loading