Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "Forw
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
DifferentiationInterfaceSparseMatrixColoringsExt = ["SparseMatrixColorings", "SparseArrays"]
DifferentiationInterfaceStaticArraysExt = "StaticArrays"
DifferentiationInterfaceSymbolicsExt = "Symbolics"
DifferentiationInterfaceTrackerExt = "Tracker"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,10 @@ end

## Jacobian

struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SIG}
struct FastDifferentiationOneArgJacobianPrep{SIG,Y,P,E1,E1!} <: DI.SparseJacobianPrep{SIG}
_sig::Val{SIG}
y_prototype::Y
sparsity::P
jac_exe::E1
jac_exe!::E1!
end
Expand All @@ -376,14 +377,18 @@ function DI.prepare_jacobian_nokwarg(
x_vec_var = myvec(x_var)
context_vec_vars = map(myvec, context_vars)
y_vec_var = myvec(y_var)
jac_var = if backend isa AutoSparse
sparse_jacobian(y_vec_var, x_vec_var)
if backend isa AutoSparse
jac_var = sparse_jacobian(y_vec_var, x_vec_var)
sparsity = DI.get_pattern(jac_var)
else
jacobian(y_vec_var, x_vec_var)
jac_var = jacobian(y_vec_var, x_vec_var)
sparsity = nothing
end
jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false)
jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true)
return FastDifferentiationOneArgJacobianPrep(_sig, y_prototype, jac_exe, jac_exe!)
return FastDifferentiationOneArgJacobianPrep(
_sig, y_prototype, sparsity, jac_exe, jac_exe!
)
end

function DI.jacobian(
Expand Down Expand Up @@ -626,9 +631,10 @@ end

## Hessian

struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG}
struct FastDifferentiationHessianPrep{SIG,G,P,E2,E2!} <: DI.SparseHessianPrep{SIG}
_sig::Val{SIG}
gradient_prep::G
sparsity::P
hess_exe::E2
hess_exe!::E2!
end
Expand All @@ -648,18 +654,22 @@ function DI.prepare_hessian_nokwarg(
x_vec_var = myvec(x_var)
context_vec_vars = map(myvec, context_vars)

hess_var = if backend isa AutoSparse
sparse_hessian(y_var, x_vec_var)
if backend isa AutoSparse
hess_var = sparse_hessian(y_var, x_vec_var)
sparsity = DI.get_pattern(hess_var)
else
hessian(y_var, x_vec_var)
hess_var = hessian(y_var, x_vec_var)
sparsity = nothing
end
hess_exe = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=false)
hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true)

gradient_prep = DI.prepare_gradient_nokwarg(
strict, f, dense_ad(backend), x, contexts...
)
return FastDifferentiationHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!)
return FastDifferentiationHessianPrep(
_sig, gradient_prep, sparsity, hess_exe, hess_exe!
)
end

function DI.hessian(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,9 @@ end

## Jacobian

struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG}
struct FastDifferentiationTwoArgJacobianPrep{SIG,P,E1,E1!} <: DI.SparseJacobianPrep{SIG}
_sig::Val{SIG}
sparsity::P
jac_exe::E1
jac_exe!::E1!
end
Expand All @@ -312,14 +313,16 @@ function DI.prepare_jacobian_nokwarg(
x_vec_var = myvec(x_var)
context_vec_vars = map(myvec, context_vars)
y_vec_var = myvec(y_var)
jac_var = if backend isa AutoSparse
sparse_jacobian(y_vec_var, x_vec_var)
if backend isa AutoSparse
jac_var = sparse_jacobian(y_vec_var, x_vec_var)
sparsity = DI.get_pattern(jac_var)
else
jacobian(y_vec_var, x_vec_var)
jac_var = jacobian(y_vec_var, x_vec_var)
sparsity = nothing
end
jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false)
jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true)
return FastDifferentiationTwoArgJacobianPrep(_sig, jac_exe, jac_exe!)
return FastDifferentiationTwoArgJacobianPrep(_sig, sparsity, jac_exe, jac_exe!)
end

function DI.value_and_jacobian(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@ module DifferentiationInterfaceSparseArraysExt
using ADTypes: ADTypes
using DifferentiationInterface
import DifferentiationInterface as DI
using SparseArrays: sparse
using SparseArrays: SparseMatrixCSC, sparse, nonzeros

function DI.get_pattern(M::SparseMatrixCSC)
S = similar(M, Bool)
nonzeros(S) .= true
return S
end

include("sparsity_detector.jl")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@
decompress!
import SparseMatrixColorings as SMC

abstract type SparseJacobianPrep{SIG} <: DI.JacobianPrep{SIG} end
## SMC overloads

SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result)
SMC.column_colors(prep::SparseJacobianPrep) = column_colors(prep.coloring_result)
SMC.column_groups(prep::SparseJacobianPrep) = column_groups(prep.coloring_result)
SMC.row_colors(prep::SparseJacobianPrep) = row_colors(prep.coloring_result)
SMC.row_groups(prep::SparseJacobianPrep) = row_groups(prep.coloring_result)
SMC.ncolors(prep::SparseJacobianPrep) = ncolors(prep.coloring_result)
abstract type SMCSparseJacobianPrep{SIG} <: DI.SparseJacobianPrep{SIG} end

SMC.sparsity_pattern(prep::DI.SparseJacobianPrep) = prep.sparsity
SMC.column_colors(prep::DI.SparseJacobianPrep) = column_colors(prep.coloring_result)
SMC.column_groups(prep::DI.SparseJacobianPrep) = column_groups(prep.coloring_result)
SMC.row_colors(prep::DI.SparseJacobianPrep) = row_colors(prep.coloring_result)
SMC.row_groups(prep::DI.SparseJacobianPrep) = row_groups(prep.coloring_result)
SMC.ncolors(prep::DI.SparseJacobianPrep) = ncolors(prep.coloring_result)

Check warning on line 27 in DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl#L23-L27

Added lines #L23 - L27 were not covered by tests

SMC.sparsity_pattern(prep::DI.SparseHessianPrep) = prep.sparsity
SMC.column_colors(prep::DI.SparseHessianPrep) = column_colors(prep.coloring_result)
SMC.column_groups(prep::DI.SparseHessianPrep) = column_groups(prep.coloring_result)
SMC.ncolors(prep::DI.SparseHessianPrep) = ncolors(prep.coloring_result)

Check warning on line 32 in DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl#L30-L32

Added lines #L30 - L32 were not covered by tests

include("jacobian.jl")
include("jacobian_mixed.jl")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
struct SparseHessianPrep{
struct SMCSparseHessianPrep{
SIG,
BS<:DI.BatchSizeSettings,
P<:AbstractMatrix,
C<:AbstractColoringResult{:symmetric,:column},
M<:AbstractMatrix{<:Number},
S<:AbstractVector{<:NTuple},
R<:AbstractVector{<:NTuple},
E2<:DI.HVPPrep,
E1<:DI.GradientPrep,
} <: DI.HessianPrep{SIG}
} <: DI.SparseHessianPrep{SIG}
_sig::Val{SIG}
batch_size_settings::BS
sparsity::P
coloring_result::C
compressed_matrix::M
batched_seeds::S
Expand All @@ -18,11 +20,6 @@ struct SparseHessianPrep{
gradient_prep::E1
end

SMC.sparsity_pattern(prep::SparseHessianPrep) = sparsity_pattern(prep.coloring_result)
SMC.column_colors(prep::SparseHessianPrep) = column_colors(prep.coloring_result)
SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result)
SMC.ncolors(prep::SparseHessianPrep) = ncolors(prep.coloring_result)

## Hessian, one argument

function DI.prepare_hessian_nokwarg(
Expand All @@ -39,13 +36,14 @@ function DI.prepare_hessian_nokwarg(
N = length(column_groups(coloring_result))
batch_size_settings = DI.pick_batchsize(DI.outer(dense_backend), N)
return _prepare_sparse_hessian_aux(
strict, batch_size_settings, coloring_result, f, backend, x, contexts...
strict, batch_size_settings, sparsity, coloring_result, f, backend, x, contexts...
)
end

function _prepare_sparse_hessian_aux(
strict::Val,
batch_size_settings::DI.BatchSizeSettings{B},
sparsity::AbstractMatrix,
coloring_result::AbstractColoringResult{:symmetric,:column},
f::F,
backend::AutoSparse,
Expand All @@ -68,9 +66,10 @@ function _prepare_sparse_hessian_aux(
gradient_prep = DI.prepare_gradient_nokwarg(
strict, f, DI.inner(dense_backend), x, contexts...
)
return SparseHessianPrep(
return SMCSparseHessianPrep(
_sig,
batch_size_settings,
sparsity,
coloring_result,
compressed_matrix,
batched_seeds,
Expand All @@ -83,7 +82,7 @@ end
function DI.hessian!(
f::F,
hess,
prep::SparseHessianPrep{SIG,<:DI.BatchSizeSettings{B}},
prep::SMCSparseHessianPrep{SIG,<:DI.BatchSizeSettings{B}},
backend::AutoSparse,
x,
contexts::Vararg{DI.Context,C},
Expand Down Expand Up @@ -128,7 +127,7 @@ function DI.hessian!(
end

function DI.hessian(
f::F, prep::SparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
f::F, prep::SMCSparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
) where {F,C}
DI.check_prep(f, prep, backend, x, contexts...)
hess = similar(sparsity_pattern(prep), eltype(x))
Expand All @@ -139,7 +138,7 @@ function DI.value_gradient_and_hessian!(
f::F,
grad,
hess,
prep::SparseHessianPrep,
prep::SMCSparseHessianPrep,
backend::AutoSparse,
x,
contexts::Vararg{DI.Context,C},
Expand All @@ -153,7 +152,7 @@ function DI.value_gradient_and_hessian!(
end

function DI.value_gradient_and_hessian(
f::F, prep::SparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
f::F, prep::SMCSparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
) where {F,C}
DI.check_prep(f, prep, backend, x, contexts...)
y, grad = DI.value_and_gradient(
Expand Down
Loading
Loading