Skip to content

Commit ee8c7df

Browse files
authored
fix: store sparsity pattern for symbolic backends (#764)
* fix: store sparsity pattern for symbolic backends * Fix constructors * Up * Version * Stale import * Fix * Apply suggestions from code review
1 parent 51c56e8 commit ee8c7df

File tree

14 files changed

+195
-80
lines changed

14 files changed

+195
-80
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,10 @@ end
353353

354354
## Jacobian
355355

356-
struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SIG}
356+
struct FastDifferentiationOneArgJacobianPrep{SIG,Y,P,E1,E1!} <: DI.SparseJacobianPrep{SIG}
357357
_sig::Val{SIG}
358358
y_prototype::Y
359+
sparsity::P
359360
jac_exe::E1
360361
jac_exe!::E1!
361362
end
@@ -376,14 +377,18 @@ function DI.prepare_jacobian_nokwarg(
376377
x_vec_var = myvec(x_var)
377378
context_vec_vars = map(myvec, context_vars)
378379
y_vec_var = myvec(y_var)
379-
jac_var = if backend isa AutoSparse
380-
sparse_jacobian(y_vec_var, x_vec_var)
380+
if backend isa AutoSparse
381+
jac_var = sparse_jacobian(y_vec_var, x_vec_var)
382+
sparsity = DI.get_pattern(jac_var)
381383
else
382-
jacobian(y_vec_var, x_vec_var)
384+
jac_var = jacobian(y_vec_var, x_vec_var)
385+
sparsity = nothing
383386
end
384387
jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false)
385388
jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true)
386-
return FastDifferentiationOneArgJacobianPrep(_sig, y_prototype, jac_exe, jac_exe!)
389+
return FastDifferentiationOneArgJacobianPrep(
390+
_sig, y_prototype, sparsity, jac_exe, jac_exe!
391+
)
387392
end
388393

389394
function DI.jacobian(
@@ -626,9 +631,10 @@ end
626631

627632
## Hessian
628633

629-
struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG}
634+
struct FastDifferentiationHessianPrep{SIG,G,P,E2,E2!} <: DI.SparseHessianPrep{SIG}
630635
_sig::Val{SIG}
631636
gradient_prep::G
637+
sparsity::P
632638
hess_exe::E2
633639
hess_exe!::E2!
634640
end
@@ -648,18 +654,22 @@ function DI.prepare_hessian_nokwarg(
648654
x_vec_var = myvec(x_var)
649655
context_vec_vars = map(myvec, context_vars)
650656

651-
hess_var = if backend isa AutoSparse
652-
sparse_hessian(y_var, x_vec_var)
657+
if backend isa AutoSparse
658+
hess_var = sparse_hessian(y_var, x_vec_var)
659+
sparsity = DI.get_pattern(hess_var)
653660
else
654-
hessian(y_var, x_vec_var)
661+
hess_var = hessian(y_var, x_vec_var)
662+
sparsity = nothing
655663
end
656664
hess_exe = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=false)
657665
hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true)
658666

659667
gradient_prep = DI.prepare_gradient_nokwarg(
660668
strict, f, dense_ad(backend), x, contexts...
661669
)
662-
return FastDifferentiationHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!)
670+
return FastDifferentiationHessianPrep(
671+
_sig, gradient_prep, sparsity, hess_exe, hess_exe!
672+
)
663673
end
664674

665675
function DI.hessian(

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,9 @@ end
289289

290290
## Jacobian
291291

292-
struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG}
292+
struct FastDifferentiationTwoArgJacobianPrep{SIG,P,E1,E1!} <: DI.SparseJacobianPrep{SIG}
293293
_sig::Val{SIG}
294+
sparsity::P
294295
jac_exe::E1
295296
jac_exe!::E1!
296297
end
@@ -312,14 +313,16 @@ function DI.prepare_jacobian_nokwarg(
312313
x_vec_var = myvec(x_var)
313314
context_vec_vars = map(myvec, context_vars)
314315
y_vec_var = myvec(y_var)
315-
jac_var = if backend isa AutoSparse
316-
sparse_jacobian(y_vec_var, x_vec_var)
316+
if backend isa AutoSparse
317+
jac_var = sparse_jacobian(y_vec_var, x_vec_var)
318+
sparsity = DI.get_pattern(jac_var)
317319
else
318-
jacobian(y_vec_var, x_vec_var)
320+
jac_var = jacobian(y_vec_var, x_vec_var)
321+
sparsity = nothing
319322
end
320323
jac_exe = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=false)
321324
jac_exe! = make_function(jac_var, x_vec_var, context_vec_vars...; in_place=true)
322-
return FastDifferentiationTwoArgJacobianPrep(_sig, jac_exe, jac_exe!)
325+
return FastDifferentiationTwoArgJacobianPrep(_sig, sparsity, jac_exe, jac_exe!)
323326
end
324327

325328
function DI.value_and_jacobian(

DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/DifferentiationInterfaceSparseArraysExt.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@ module DifferentiationInterfaceSparseArraysExt
33
using ADTypes: ADTypes
44
using DifferentiationInterface
55
import DifferentiationInterface as DI
6-
using SparseArrays: sparse
6+
using SparseArrays: SparseMatrixCSC, sparse, nonzeros
7+
8+
function DI.get_pattern(M::SparseMatrixCSC)
9+
S = similar(M, Bool)
10+
nonzeros(S) .= true
11+
return S
12+
end
713

814
include("sparsity_detector.jl")
915

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,21 @@ using SparseMatrixColorings:
1515
decompress!
1616
import SparseMatrixColorings as SMC
1717

18-
abstract type SparseJacobianPrep{SIG} <: DI.JacobianPrep{SIG} end
18+
## SMC overloads
1919

20-
SMC.sparsity_pattern(prep::SparseJacobianPrep) = sparsity_pattern(prep.coloring_result)
21-
SMC.column_colors(prep::SparseJacobianPrep) = column_colors(prep.coloring_result)
22-
SMC.column_groups(prep::SparseJacobianPrep) = column_groups(prep.coloring_result)
23-
SMC.row_colors(prep::SparseJacobianPrep) = row_colors(prep.coloring_result)
24-
SMC.row_groups(prep::SparseJacobianPrep) = row_groups(prep.coloring_result)
25-
SMC.ncolors(prep::SparseJacobianPrep) = ncolors(prep.coloring_result)
20+
abstract type SMCSparseJacobianPrep{SIG} <: DI.SparseJacobianPrep{SIG} end
21+
22+
SMC.sparsity_pattern(prep::DI.SparseJacobianPrep) = prep.sparsity
23+
SMC.column_colors(prep::DI.SparseJacobianPrep) = column_colors(prep.coloring_result)
24+
SMC.column_groups(prep::DI.SparseJacobianPrep) = column_groups(prep.coloring_result)
25+
SMC.row_colors(prep::DI.SparseJacobianPrep) = row_colors(prep.coloring_result)
26+
SMC.row_groups(prep::DI.SparseJacobianPrep) = row_groups(prep.coloring_result)
27+
SMC.ncolors(prep::DI.SparseJacobianPrep) = ncolors(prep.coloring_result)
28+
29+
SMC.sparsity_pattern(prep::DI.SparseHessianPrep) = prep.sparsity
30+
SMC.column_colors(prep::DI.SparseHessianPrep) = column_colors(prep.coloring_result)
31+
SMC.column_groups(prep::DI.SparseHessianPrep) = column_groups(prep.coloring_result)
32+
SMC.ncolors(prep::DI.SparseHessianPrep) = ncolors(prep.coloring_result)
2633

2734
include("jacobian.jl")
2835
include("jacobian_mixed.jl")

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
struct SparseHessianPrep{
1+
struct SMCSparseHessianPrep{
22
SIG,
33
BS<:DI.BatchSizeSettings,
4+
P<:AbstractMatrix,
45
C<:AbstractColoringResult{:symmetric,:column},
56
M<:AbstractMatrix{<:Number},
67
S<:AbstractVector{<:NTuple},
78
R<:AbstractVector{<:NTuple},
89
E2<:DI.HVPPrep,
910
E1<:DI.GradientPrep,
10-
} <: DI.HessianPrep{SIG}
11+
} <: DI.SparseHessianPrep{SIG}
1112
_sig::Val{SIG}
1213
batch_size_settings::BS
14+
sparsity::P
1315
coloring_result::C
1416
compressed_matrix::M
1517
batched_seeds::S
@@ -18,11 +20,6 @@ struct SparseHessianPrep{
1820
gradient_prep::E1
1921
end
2022

21-
SMC.sparsity_pattern(prep::SparseHessianPrep) = sparsity_pattern(prep.coloring_result)
22-
SMC.column_colors(prep::SparseHessianPrep) = column_colors(prep.coloring_result)
23-
SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result)
24-
SMC.ncolors(prep::SparseHessianPrep) = ncolors(prep.coloring_result)
25-
2623
## Hessian, one argument
2724

2825
function DI.prepare_hessian_nokwarg(
@@ -39,13 +36,14 @@ function DI.prepare_hessian_nokwarg(
3936
N = length(column_groups(coloring_result))
4037
batch_size_settings = DI.pick_batchsize(DI.outer(dense_backend), N)
4138
return _prepare_sparse_hessian_aux(
42-
strict, batch_size_settings, coloring_result, f, backend, x, contexts...
39+
strict, batch_size_settings, sparsity, coloring_result, f, backend, x, contexts...
4340
)
4441
end
4542

4643
function _prepare_sparse_hessian_aux(
4744
strict::Val,
4845
batch_size_settings::DI.BatchSizeSettings{B},
46+
sparsity::AbstractMatrix,
4947
coloring_result::AbstractColoringResult{:symmetric,:column},
5048
f::F,
5149
backend::AutoSparse,
@@ -68,9 +66,10 @@ function _prepare_sparse_hessian_aux(
6866
gradient_prep = DI.prepare_gradient_nokwarg(
6967
strict, f, DI.inner(dense_backend), x, contexts...
7068
)
71-
return SparseHessianPrep(
69+
return SMCSparseHessianPrep(
7270
_sig,
7371
batch_size_settings,
72+
sparsity,
7473
coloring_result,
7574
compressed_matrix,
7675
batched_seeds,
@@ -83,7 +82,7 @@ end
8382
function DI.hessian!(
8483
f::F,
8584
hess,
86-
prep::SparseHessianPrep{SIG,<:DI.BatchSizeSettings{B}},
85+
prep::SMCSparseHessianPrep{SIG,<:DI.BatchSizeSettings{B}},
8786
backend::AutoSparse,
8887
x,
8988
contexts::Vararg{DI.Context,C},
@@ -128,7 +127,7 @@ function DI.hessian!(
128127
end
129128

130129
function DI.hessian(
131-
f::F, prep::SparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
130+
f::F, prep::SMCSparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
132131
) where {F,C}
133132
DI.check_prep(f, prep, backend, x, contexts...)
134133
hess = similar(sparsity_pattern(prep), eltype(x))
@@ -139,7 +138,7 @@ function DI.value_gradient_and_hessian!(
139138
f::F,
140139
grad,
141140
hess,
142-
prep::SparseHessianPrep,
141+
prep::SMCSparseHessianPrep,
143142
backend::AutoSparse,
144143
x,
145144
contexts::Vararg{DI.Context,C},
@@ -153,7 +152,7 @@ function DI.value_gradient_and_hessian!(
153152
end
154153

155154
function DI.value_gradient_and_hessian(
156-
f::F, prep::SparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
155+
f::F, prep::SMCSparseHessianPrep, backend::AutoSparse, x, contexts::Vararg{DI.Context,C}
157156
) where {F,C}
158157
DI.check_prep(f, prep, backend, x, contexts...)
159158
y, grad = DI.value_and_gradient(

0 commit comments

Comments
 (0)