Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,10 @@ end

## Jacobian

struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SIG}
struct FastDifferentiationOneArgJacobianPrep{SIG,Y,P,E1,E1!} <: DI.SparseJacobianPrep{SIG}
_sig::Val{SIG}
y_prototype::Y
sparsity::P
jac_exe::E1
jac_exe!::E1!
end
Expand All @@ -376,14 +377,18 @@ function DI.prepare_jacobian_nokwarg(
x_vec_var = myvec(x_var)
context_vec_vars = map(myvec, context_vars)
y_vec_var = myvec(y_var)
jac_var = if backend isa AutoSparse
sparse_jacobian(y_vec_var, x_vec_var)
if backend isa AutoSparse
jac_var = sparse_jacobian(y_vec_var, x_vec_var)
sparsity = DI.get_pattern(jac_var)
else
jacobian(y_vec_var, x_vec_var)
jac_var = jacobian(y_vec_var, x_vec_var)
sparsity = nothing
end
jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false)
jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true)
return FastDifferentiationOneArgJacobianPrep(_sig, y_prototype, jac_exe, jac_exe!)
return FastDifferentiationOneArgJacobianPrep(
_sig, y_prototype, sparsity, jac_exe, jac_exe!
)
end

function DI.jacobian(
Expand Down Expand Up @@ -626,9 +631,10 @@ end

## Hessian

struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG}
struct FastDifferentiationHessianPrep{SIG,G,P,E2,E2!} <: DI.SparseHessianPrep{SIG}
_sig::Val{SIG}
gradient_prep::G
sparsity::P
hess_exe::E2
hess_exe!::E2!
end
Expand All @@ -648,18 +654,22 @@ function DI.prepare_hessian_nokwarg(
x_vec_var = myvec(x_var)
context_vec_vars = map(myvec, context_vars)

hess_var = if backend isa AutoSparse
sparse_hessian(y_var, x_vec_var)
if backend isa AutoSparse
hess_var = sparse_hessian(y_var, x_vec_var)
sparsity = DI.get_pattern(hess_var)
else
hessian(y_var, x_vec_var)
hess_var = hessian(y_var, x_vec_var)
sparsity = nothing
end
hess_exe = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=false)
hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true)

gradient_prep = DI.prepare_gradient_nokwarg(
strict, f, dense_ad(backend), x, contexts...
)
return FastDifferentiationHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!)
return FastDifferentiationHessianPrep(
_sig, gradient_prep, sparsity, hess_exe, hess_exe!
)
end

function DI.hessian(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,9 @@ end

## Jacobian

struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG}
struct FastDifferentiationTwoArgJacobianPrep{SIG,P,E1,E1!} <: DI.SparseJacobianPrep{SIG}
_sig::Val{SIG}
sparsity::P
jac_exe::E1
jac_exe!::E1!
end
Expand All @@ -312,14 +313,16 @@ function DI.prepare_jacobian_nokwarg(
x_vec_var = myvec(x_var)
context_vec_vars = map(myvec, context_vars)
y_vec_var = myvec(y_var)
jac_var = if backend isa AutoSparse
sparse_jacobian(y_vec_var, x_vec_var)
if backend isa AutoSparse
jac_var = sparse_jacobian(y_vec_var, x_vec_var)
sparsity = DI.get_pattern(jac_var)
else
jacobian(y_vec_var, x_vec_var)
jac_var = jacobian(y_vec_var, x_vec_var)
sparsity = nothing
end
jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false)
jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true)
return FastDifferentiationTwoArgJacobianPrep(_sig, jac_exe, jac_exe!)
return FastDifferentiationTwoArgJacobianPrep(_sig, sparsity, jac_exe, jac_exe!)
end

function DI.value_and_jacobian(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@ module DifferentiationInterfaceSparseArraysExt
using ADTypes: ADTypes
using DifferentiationInterface
import DifferentiationInterface as DI
using SparseArrays: sparse
using SparseArrays: SparseMatrixCSC, sparse, nonzeros

function DI.get_pattern(M::SparseMatrixCSC)
S = similar(M, Bool)
nonzeros(S) .= true
return S
end

include("sparsity_detector.jl")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@ using SparseMatrixColorings:
decompress!
import SparseMatrixColorings as SMC

abstract type SparseJacobianPrep{SIG} <: DI.JacobianPrep{SIG} end
## SMC overloads

SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result)
SMC.column_colors(prep::SparseJacobianPrep) = column_colors(prep.coloring_result)
SMC.column_groups(prep::SparseJacobianPrep) = column_groups(prep.coloring_result)
SMC.row_colors(prep::SparseJacobianPrep) = row_colors(prep.coloring_result)
SMC.row_groups(prep::SparseJacobianPrep) = row_groups(prep.coloring_result)
SMC.ncolors(prep::SparseJacobianPrep) = ncolors(prep.coloring_result)
abstract type SMCSparseJacobianPrep{SIG} <: DI.SparseJacobianPrep{SIG} end

SMC.sparsity_pattern(prep::DI.SparseJacobianPrep) = prep.sparsity
SMC.column_colors(prep::DI.SparseJacobianPrep) = column_colors(prep.coloring_result)
SMC.column_groups(prep::DI.SparseJacobianPrep) = column_groups(prep.coloring_result)
SMC.row_colors(prep::DI.SparseJacobianPrep) = row_colors(prep.coloring_result)
SMC.row_groups(prep::DI.SparseJacobianPrep) = row_groups(prep.coloring_result)
SMC.ncolors(prep::DI.SparseJacobianPrep) = ncolors(prep.coloring_result)

SMC.sparsity_pattern(prep::DI.SparseHessianPrep) = prep.sparsity
SMC.column_colors(prep::DI.SparseHessianPrep) = column_colors(prep.coloring_result)
SMC.column_groups(prep::DI.SparseHessianPrep) = column_groups(prep.coloring_result)
SMC.ncolors(prep::DI.SparseHessianPrep) = ncolors(prep.coloring_result)

include("jacobian.jl")
include("jacobian_mixed.jl")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
struct SparseHessianPrep{
struct SMCSparseHessianPrep{
SIG,
BS<:DI.BatchSizeSettings,
P<:AbstractMatrix,
C<:AbstractColoringResult{:symmetric,:column},
M<:AbstractMatrix{<:Number},
S<:AbstractVector{<:NTuple},
R<:AbstractVector{<:NTuple},
E2<:DI.HVPPrep,
E1<:DI.GradientPrep,
} <: DI.HessianPrep{SIG}
} <: DI.SparseHessianPrep{SIG}
_sig::Val{SIG}
batch_size_settings::BS
sparsity::P
coloring_result::C
compressed_matrix::M
batched_seeds::S
Expand All @@ -18,11 +20,6 @@ struct SparseHessianPrep{
gradient_prep::E1
end

SMC.sparsity_pattern(prep::SparseHessianPrep) = sparsity_pattern(prep.coloring_result)
SMC.column_colors(prep::SparseHessianPrep) = column_colors(prep.coloring_result)
SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result)
SMC.ncolors(prep::SparseHessianPrep) = ncolors(prep.coloring_result)

## Hessian, one argument

function DI.prepare_hessian_nokwarg(
Expand All @@ -39,13 +36,14 @@ function DI.prepare_hessian_nokwarg(
N = length(column_groups(coloring_result))
batch_size_settings = DI.pick_batchsize(DI.outer(dense_backend), N)
return _prepare_sparse_hessian_aux(
strict, batch_size_settings, coloring_result, f, backend, x, contexts...
strict, batch_size_settings, sparsity, coloring_result, f, backend, x, contexts...
)
end

function _prepare_sparse_hessian_aux(
strict::Val,
batch_size_settings::DI.BatchSizeSettings{B},
sparsity::AbstractMatrix,
coloring_result::AbstractColoringResult{:symmetric,:column},
f::F,
backend::AutoSparse,
Expand All @@ -68,9 +66,10 @@ function _prepare_sparse_hessian_aux(
gradient_prep = DI.prepare_gradient_nokwarg(
strict, f, DI.inner(dense_backend), x, contexts...
)
return SparseHessianPrep(
return SMCSparseHessianPrep(
_sig,
batch_size_settings,
sparsity,
coloring_result,
compressed_matrix,
batched_seeds,
Expand All @@ -83,7 +82,7 @@ end
function DI.hessian!(
f::F,
hess,
prep::SparseHessianPrep{SIG,<:DI.BatchSizeSettings{B}},
prep::SMCSparseHessianPrep{SIG,<:DI.BatchSizeSettings{B}},
backend::AutoSparse,
x,
contexts::Vararg{DI.Context,C},
Expand Down Expand Up @@ -128,7 +127,7 @@ function DI.hessian!(
end

function DI.hessian(
f::F, prep::SparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
f::F, prep::SMCSparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
) where {F,C}
DI.check_prep(f, prep, backend, x, contexts...)
hess = similar(sparsity_pattern(prep), eltype(x))
Expand All @@ -139,7 +138,7 @@ function DI.value_gradient_and_hessian!(
f::F,
grad,
hess,
prep::SparseHessianPrep,
prep::SMCSparseHessianPrep,
backend::AutoSparse,
x,
contexts::Vararg{DI.Context,C},
Expand All @@ -153,7 +152,7 @@ function DI.value_gradient_and_hessian!(
end

function DI.value_gradient_and_hessian(
f::F, prep::SparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
f::F, prep::SMCSparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
) where {F,C}
DI.check_prep(f, prep, backend, x, contexts...)
y, grad = DI.value_and_gradient(
Expand Down
Loading