@@ -67,7 +67,7 @@ function _hessian_color_preprocess(
6767 )
6868 tree_result = SMC. coloring (S, problem, algo)
6969 result = ColoringResult (tree_result, Int[])
70- return I, J, result
70+ return ones ( length (local_indices)), I, J, result
7171 end
7272
7373 global_to_local_idx = seen_idx. nzidx # steal for storage
@@ -93,53 +93,15 @@ function _hessian_color_preprocess(
9393 )
9494 tree_result = SMC. coloring (S, problem, algo)
9595
96- # Reconstruct I and J from the tree structure (matching original _indirect_recover_structure)
97- # First add all diagonal elements
98- N = length (local_indices)
99-
100- # Count off-diagonal elements from tree structure
101- (; reverse_bfs_orders, tree_edge_indices, nt) = tree_result
102- nnz_offdiag = 0
103- for tree_idx in 1 : nt
104- first = tree_edge_indices[tree_idx]
105- last = tree_edge_indices[tree_idx+ 1 ] - 1
106- nnz_offdiag += (last - first + 1 )
107- end
108-
109- I_new = Vector {Int} (undef, N + nnz_offdiag)
110- J_new = Vector {Int} (undef, N + nnz_offdiag)
111- k = 0
112-
113- # Add all diagonal elements
114- for i in 1 : N
115- k += 1
116- I_new[k] = local_indices[i]
117- J_new[k] = local_indices[i]
118- end
119-
120- # Then add off-diagonal elements from the tree structure
121- for tree_idx in 1 : nt
122- first = tree_edge_indices[tree_idx]
123- last = tree_edge_indices[tree_idx+ 1 ] - 1
124- for pos in first: last
125- (i_local, j_local) = reverse_bfs_orders[pos]
126- # Convert from local to global indices and normalize (lower triangle)
127- i_global = local_indices[i_local]
128- j_global = local_indices[j_local]
129- if j_global > i_global
130- i_global, j_global = j_global, i_global
131- end
132- k += 1
133- I_new[k] = i_global
134- J_new[k] = j_global
135- end
136- end
137-
138- @assert k == length (I_new)
139-
14096 # Wrap result with local_indices
14197 result = ColoringResult (tree_result, local_indices)
142- return I_new, J_new, result
98+
99+ # SparseMatrixColorings assumes that `I` and `J` are CSC-ordered
100+ B = SMC. compress (S, tree_result)
101+ C = SMC. decompress (B, tree_result)
102+ I_sorted, J_sorted = SparseArrays. findnz (C)
103+
104+ return C. colptr, I_sorted, J_sorted, result
143105end
144106
145107"""
174136
175137"""
176138 _recover_from_matmat!(
139+ colptr::AbstractVector,
177140 V::AbstractVector{T},
178141 R::AbstractMatrix{T},
179142 result::ColoringResult,
@@ -185,59 +148,20 @@ R is the result of H*R_seed where R_seed is the seed matrix.
185148`stored_values` is a temporary vector.
186149"""
187150function _recover_from_matmat! (
151+ colptr:: AbstractVector ,
188152 V:: AbstractVector{T} ,
189153 R:: AbstractMatrix{T} ,
190154 result:: ColoringResult ,
191155 stored_values:: AbstractVector{T} ,
192156) where {T}
193157 tree_result = result. result
194- color = SMC. column_colors (tree_result)
195158 N = length (result. local_indices)
196- # Compute number of off-diagonal nonzeros from the length of V
197- # V contains N diagonal elements + nnz_offdiag off-diagonal elements
198- @assert length (stored_values) >= N
199-
200- # Recover diagonal elements
201- k = 0
202- for i in 1 : N
203- k += 1
204- if color[i] > 0
205- V[k] = R[i, color[i]]
206- else
207- V[k] = zero (T)
208- end
209- end
210-
211- # Recover off-diagonal elements using tree structure
212- (; reverse_bfs_orders, tree_edge_indices, nt) = tree_result
213- fill! (stored_values, zero (T))
214-
215- for tree_idx in 1 : nt
216- first = tree_edge_indices[tree_idx]
217- last = tree_edge_indices[tree_idx+ 1 ] - 1
218-
219- # Reset stored_values for vertices in this tree
220- for pos in first: last
221- (vertex, _) = reverse_bfs_orders[pos]
222- stored_values[vertex] = zero (T)
223- end
224- (_, root) = reverse_bfs_orders[last]
225- stored_values[root] = zero (T)
226-
227- # Recover edge values
228- for pos in first: last
229- (i, j) = reverse_bfs_orders[pos]
230- if color[j] > 0
231- value = R[i, color[j]] - stored_values[i]
232- else
233- value = zero (T)
234- end
235- stored_values[j] += value
236- k += 1
237- V[k] = value
238- end
239- end
240-
241- @assert k == length (V)
159+ S = _SparseMatrixValuesCSC (
160+ N,
161+ N,
162+ colptr,
163+ V,
164+ )
165+ SMC. decompress! (S, R, tree_result)
242166 return
243167end
0 commit comments