Skip to content

Commit 99ce3cd

Browse files
committed
Simplify
1 parent ffb58d6 commit 99ce3cd

File tree

6 files changed

+36
-98
lines changed

6 files changed

+36
-98
lines changed

src/coloring.jl

Lines changed: 17 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ function _hessian_color_preprocess(
6767
)
6868
tree_result = SMC.coloring(S, problem, algo)
6969
result = ColoringResult(tree_result, Int[])
70-
return I, J, result
70+
return ones(length(local_indices)), I, J, result
7171
end
7272

7373
global_to_local_idx = seen_idx.nzidx # steal for storage
@@ -93,53 +93,15 @@ function _hessian_color_preprocess(
9393
)
9494
tree_result = SMC.coloring(S, problem, algo)
9595

96-
# Reconstruct I and J from the tree structure (matching original _indirect_recover_structure)
97-
# First add all diagonal elements
98-
N = length(local_indices)
99-
100-
# Count off-diagonal elements from tree structure
101-
(; reverse_bfs_orders, tree_edge_indices, nt) = tree_result
102-
nnz_offdiag = 0
103-
for tree_idx in 1:nt
104-
first = tree_edge_indices[tree_idx]
105-
last = tree_edge_indices[tree_idx+1] - 1
106-
nnz_offdiag += (last - first + 1)
107-
end
108-
109-
I_new = Vector{Int}(undef, N + nnz_offdiag)
110-
J_new = Vector{Int}(undef, N + nnz_offdiag)
111-
k = 0
112-
113-
# Add all diagonal elements
114-
for i in 1:N
115-
k += 1
116-
I_new[k] = local_indices[i]
117-
J_new[k] = local_indices[i]
118-
end
119-
120-
# Then add off-diagonal elements from the tree structure
121-
for tree_idx in 1:nt
122-
first = tree_edge_indices[tree_idx]
123-
last = tree_edge_indices[tree_idx+1] - 1
124-
for pos in first:last
125-
(i_local, j_local) = reverse_bfs_orders[pos]
126-
# Convert from local to global indices and normalize (lower triangle)
127-
i_global = local_indices[i_local]
128-
j_global = local_indices[j_local]
129-
if j_global > i_global
130-
i_global, j_global = j_global, i_global
131-
end
132-
k += 1
133-
I_new[k] = i_global
134-
J_new[k] = j_global
135-
end
136-
end
137-
138-
@assert k == length(I_new)
139-
14096
# Wrap result with local_indices
14197
result = ColoringResult(tree_result, local_indices)
142-
return I_new, J_new, result
98+
99+
# SparseMatrixColorings assumes that `I` and `J` are CSC-ordered
100+
B = SMC.compress(S, tree_result)
101+
C = SMC.decompress(B, tree_result)
102+
I_sorted, J_sorted = SparseArrays.findnz(C)
103+
104+
return C.colptr, I_sorted, J_sorted, result
143105
end
144106

145107
"""
@@ -174,6 +136,7 @@ end
174136

175137
"""
176138
_recover_from_matmat!(
139+
colptr::AbstractVector,
177140
V::AbstractVector{T},
178141
R::AbstractMatrix{T},
179142
result::ColoringResult,
@@ -185,59 +148,20 @@ R is the result of H*R_seed where R_seed is the seed matrix.
185148
`stored_values` is a temporary vector.
186149
"""
187150
function _recover_from_matmat!(
151+
colptr::AbstractVector,
188152
V::AbstractVector{T},
189153
R::AbstractMatrix{T},
190154
result::ColoringResult,
191155
stored_values::AbstractVector{T},
192156
) where {T}
193157
tree_result = result.result
194-
color = SMC.column_colors(tree_result)
195158
N = length(result.local_indices)
196-
# Compute number of off-diagonal nonzeros from the length of V
197-
# V contains N diagonal elements + nnz_offdiag off-diagonal elements
198-
@assert length(stored_values) >= N
199-
200-
# Recover diagonal elements
201-
k = 0
202-
for i in 1:N
203-
k += 1
204-
if color[i] > 0
205-
V[k] = R[i, color[i]]
206-
else
207-
V[k] = zero(T)
208-
end
209-
end
210-
211-
# Recover off-diagonal elements using tree structure
212-
(; reverse_bfs_orders, tree_edge_indices, nt) = tree_result
213-
fill!(stored_values, zero(T))
214-
215-
for tree_idx in 1:nt
216-
first = tree_edge_indices[tree_idx]
217-
last = tree_edge_indices[tree_idx+1] - 1
218-
219-
# Reset stored_values for vertices in this tree
220-
for pos in first:last
221-
(vertex, _) = reverse_bfs_orders[pos]
222-
stored_values[vertex] = zero(T)
223-
end
224-
(_, root) = reverse_bfs_orders[last]
225-
stored_values[root] = zero(T)
226-
227-
# Recover edge values
228-
for pos in first:last
229-
(i, j) = reverse_bfs_orders[pos]
230-
if color[j] > 0
231-
value = R[i, color[j]] - stored_values[i]
232-
else
233-
value = zero(T)
234-
end
235-
stored_values[j] += value
236-
k += 1
237-
V[k] = value
238-
end
239-
end
240-
241-
@assert k == length(V)
159+
S = _SparseMatrixValuesCSC(
160+
N,
161+
N,
162+
colptr,
163+
V,
164+
)
165+
SMC.decompress!(S, R, tree_result)
242166
return
243167
end

src/forward_over_reverse.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ const MAX_CHUNK = 10
2121
f::_FunctionStorage,
2222
H::AbstractVector{Float64},
2323
λ::Float64,
24-
offset::Int,
24+
nzcount::Int,
2525
)::Int
2626
2727
Evaluate the hessian matrix of the function `f` and store the result, scaled by
@@ -65,12 +65,13 @@ function _eval_hessian(
6565
)
6666
end
6767
# TODO(odow): consider reverting to a view.
68+
N = size(ex.seed_matrix, 1)
6869
output_slice = _UnsafeVectorView{Float64}(
6970
nzcount,
7071
length(ex.hess_I),
7172
pointer(H),
7273
)::_UnsafeVectorView{Float64}
73-
_recover_from_matmat!(output_slice, ex.seed_matrix, ex.rinfo, d.output_ϵ)
74+
_recover_from_matmat!(ex.hess_colptr, output_slice, ex.seed_matrix, ex.rinfo, d.output_ϵ)
7475
for i in 1:length(output_slice)
7576
output_slice[i] *= scale
7677
end

src/types.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ struct _FunctionStorage{R<:SMC.AbstractColoringResult}
6767
expr::_SubexpressionStorage
6868
grad_sparsity::Vector{Int}
6969
# Nonzero pattern of Hessian matrix
70+
hess_colptr::Vector{Int}
7071
hess_I::Vector{Int}
7172
hess_J::Vector{Int}
7273
rinfo::Union{Nothing,ColoringResult{R}}
@@ -107,7 +108,7 @@ struct _FunctionStorage{R<:SMC.AbstractColoringResult}
107108
subexpression_edgelist,
108109
subexpression_variables,
109110
)
110-
hess_I, hess_J, rinfo = _hessian_color_preprocess(
111+
hess_colptr, hess_I, hess_J, rinfo = _hessian_color_preprocess(
111112
edgelist,
112113
num_variables,
113114
coloring_algorithm,
@@ -117,6 +118,7 @@ struct _FunctionStorage{R<:SMC.AbstractColoringResult}
117118
return new{R}(
118119
expr,
119120
grad_sparsity,
121+
hess_colptr,
120122
hess_I,
121123
hess_J,
122124
rinfo,

src/utils.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,13 @@ function _reinterpret_unsafe(::Type{T}, x::Vector{R}) where {T,R}
203203
p = reinterpret(Ptr{T}, pointer(x))
204204
return _UnsafeVectorView(0, div(len, sizeof(T)), p)
205205
end
206+
207+
struct _SparseMatrixValuesCSC{Tv,Ti<:Integer,CT<:AbstractVector{Ti},VT<:AbstractVector{Tv}} <: SparseArrays.AbstractSparseMatrixCSC{Tv,Ti}
208+
m::Int # Number of rows
209+
n::Int # Number of columns
210+
colptr::CT # Column i is in colptr[i]:(colptr[i+1]-1)
211+
nzval::VT # Stored values, typically nonzeros
212+
end
213+
214+
Base.size(A::_SparseMatrixValuesCSC) = (A.m, A.n)
215+
SparseArrays.nonzeros(A::_SparseMatrixValuesCSC) = A.nzval

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ ArrayDiff = "c45fa1ca-6901-44ac-ae5b-5513a4852d50"
33
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
44
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
55
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
6+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
67
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/ReverseAD.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ function test_coloring_end_to_end_hessian_coloring_and_recovery()
484484
ArrayDiff.SMC.GreedyColoringAlgorithm(;
485485
decompression = :substitution,
486486
)
487-
I, J, rinfo = ArrayDiff._hessian_color_preprocess(
487+
colptr, I, J, rinfo = ArrayDiff._hessian_color_preprocess(
488488
Set([(1, 2)]),
489489
2,
490490
coloring_algorithm,
@@ -498,7 +498,7 @@ function test_coloring_end_to_end_hessian_coloring_and_recovery()
498498
hess = [3.4 2.1; 2.1 1.3]
499499
matmat = hess * R
500500
V = zeros(3)
501-
ArrayDiff._recover_from_matmat!(V, matmat, rinfo, zeros(3))
501+
ArrayDiff._recover_from_matmat!(colptr, V, matmat, rinfo, zeros(3))
502502
@test V == [3.4, 1.3, 2.1]
503503
return
504504
end

0 commit comments

Comments
 (0)