Skip to content

Commit f70139f

Browse files
authored
Symmetric coloring and decompression for Hessians (#272)
* Symmetric Hessian coloring * Bump SMC compat
1 parent b262365 commit f70139f

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ PolyesterForwardDiff = "0.1.1"
6161
ReverseDiff = "1.15.1"
6262
SparseArrays = "<0.0.1,1"
6363
SparseConnectivityTracer = "0.4.2"
64-
SparseMatrixColorings = "0.3.1"
64+
SparseMatrixColorings = "0.3.2"
6565
Symbolics = "5.27.1"
6666
Tapir = "0.2.4"
6767
Tracker = "0.2.33"

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse
3737
using SparseMatrixColorings:
3838
GreedyColoringAlgorithm,
3939
color_groups,
40-
decompress_columns!,
4140
decompress_columns,
41+
decompress_columns!,
42+
decompress_rows,
4243
decompress_rows!,
43-
decompress_rows
44+
decompress_symmetric,
45+
decompress_symmetric!
4446

4547
abstract type Extras end
4648

DifferentiationInterface/src/sparse/hessian.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919
function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
2020
initial_sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
2121
sparsity = col_major(initial_sparsity)
22-
colors = column_coloring(sparsity, coloring_algorithm(backend)) # no star coloring
22+
colors = symmetric_coloring(sparsity, coloring_algorithm(backend))
2323
groups = color_groups(colors)
2424
seeds = map(groups) do group
2525
seed = zero(x)
@@ -41,7 +41,7 @@ function hessian!(f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtra
4141
hvp!(f, products[k], backend, x, seeds[k], hvp_extras_same)
4242
copyto!(view(compressed, :, k), vec(products[k]))
4343
end
44-
decompress_columns!(hess, sparsity, compressed, colors)
44+
decompress_symmetric!(hess, sparsity, compressed, colors)
4545
return hess
4646
end
4747

@@ -51,5 +51,5 @@ function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras) wher
5151
compressed = stack(eachindex(seeds, products); dims=2) do k
5252
vec(hvp(f, backend, x, seeds[k], hvp_extras_same))
5353
end
54-
return decompress_columns(sparsity, compressed, colors)
54+
return decompress_symmetric(sparsity, compressed, colors)
5555
end

0 commit comments

Comments
 (0)