Skip to content

Commit a964611

Browse files
committed
refined hessian sparsity detection
1 parent fda0a14 commit a964611

File tree

4 files changed

+170
-121
lines changed

4 files changed

+170
-121
lines changed

src/Nonlinear/ReverseAD/graph_tools.jl

Lines changed: 162 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,76 @@ function _compute_gradient_sparsity!(
175175
return
176176
end
177177

178+
"""
179+
_get_nonlinear_child_interactions(
180+
nod::Nonlinear.Node,
181+
num_children::Int,
182+
)
183+
184+
Get the list of nonlinear child interaction pairs for a node.
185+
Returns empty list of tuples `(i, j)` where `i` and `j` are child indices (1-indexed)
186+
that have nonlinear interactions.
187+
188+
For example, for `*` with 2 children, the result is `[(1, 2)]` because children 1
189+
and 2 interact nonlinearly, but children 1 and 1, or 2 and 2, do not.
190+
191+
For functions like `+` or `-`, the result is `[]` since there are no nonlinear
192+
interactions between children.
193+
"""
194+
function _get_nonlinear_child_interactions(
195+
nod::Nonlinear.Node,
196+
num_children::Int,
197+
)::Vector{Tuple{Int,Int}}
198+
if nod.type == Nonlinear.NODE_CALL_UNIVARIATE
199+
@assert num_children == 1
200+
op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, nod.index, nothing)
201+
# Univariate operators :+ and :- don't create interactions
202+
if op in (:+, :-)
203+
return Tuple{Int,Int}[]
204+
else
205+
return [(1, 1)]
206+
end
207+
elseif nod.type == Nonlinear.NODE_CALL_MULTIVARIATE
208+
op = get(Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS, nod.index, nothing)
209+
210+
if op in (:+, :-, :ifelse, :min, :max)
211+
# No nonlinear interactions between children
212+
return Tuple{Int,Int}[]
213+
elseif op == :*
214+
# All pairs of distinct children interact nonlinearly
215+
result = Tuple{Int,Int}[]
216+
for i in 1:num_children
217+
for j in 1:(i-1)
218+
push!(result, (j, i))
219+
end
220+
end
221+
return result
222+
elseif op == :/
223+
@assert num_children == 2
224+
# The numerator doesn't have a nonlinear interaction with itself.
225+
return [(1, 2), (2, 2)]
226+
else
227+
# Conservative: assume all pairs interact
228+
result = Tuple{Int,Int}[]
229+
for i in 1:num_children
230+
for j in 1:i
231+
push!(result, (j, i))
232+
end
233+
end
234+
return result
235+
end
236+
else
237+
# Logic and comparison nodes don't generate hessian terms.
238+
# Subexpression nodes are special cased.
239+
return Tuple{Int,Int}[]
240+
end
241+
end
242+
178243
"""
179244
_compute_hessian_sparsity(
180245
nodes::Vector{Nonlinear.Node},
181246
adj,
182247
input_linearity::Vector{Linearity},
183-
indexedset::Coloring.IndexedSet,
184248
subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
185249
subexpression_variables::Vector{Vector{Int}},
186250
)
@@ -193,142 +257,129 @@ Compute the sparsity pattern the Hessian of an expression.
193257
* `subexpression_variables` is the list of all variables which appear in a
194258
subexpression (including recursively).
195259
196-
Idea: consider the (non)linearity of a node *with respect to the output*. The
197-
children of any node which is nonlinear with respect to the output should have
198-
nonlinear interactions, hence nonzeros in the hessian. This is not true in
199-
general, but holds for everything we consider.
200-
201-
A counter example is `f(x, y, z) = x + y * z`, but we don't have any functions
202-
like that. By "nonlinear with respect to the output", we mean that the output
203-
depends nonlinearly on the value of the node, regardless of how the node itself
204-
depends on the input.
260+
Returns a `Set{Tuple{Int,Int}}` containing the nonzero entries of the Hessian.
205261
"""
206262
function _compute_hessian_sparsity(
207263
nodes::Vector{Nonlinear.Node},
208264
adj,
209265
input_linearity::Vector{Linearity},
210-
indexedset::Coloring.IndexedSet,
211266
subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
212267
subexpression_variables::Vector{Vector{Int}},
213268
)
214-
# So start at the root of the tree and classify the linearity wrt the output.
215-
# For each nonlinear node, do a mini DFS and collect the list of children.
216-
# Add a nonlinear interaction between all children of a nonlinear node.
217269
edge_list = Set{Tuple{Int,Int}}()
218-
nonlinear_wrt_output = fill(false, length(nodes))
219270
children_arr = SparseArrays.rowvals(adj)
220-
stack = Int[]
221-
stack_ignore = Bool[]
222-
nonlinear_group = indexedset
223-
if length(nodes) == 1 && nodes[1].type == Nonlinear.NODE_SUBEXPRESSION
224-
# Subexpression comes in linearly, so append edge_list
225-
for ij in subexpression_edgelist[nodes[1].index]
226-
push!(edge_list, ij)
227-
end
228-
end
229-
for k in 2:length(nodes)
271+
272+
# Stack entry: (node_index, child_group_index)
273+
stack = Tuple{Int,Int}[]
274+
# Map from child_group_index to variable indices
275+
child_group_variables = Dict{Int,Set{Int}}()
276+
277+
for k in 1:length(nodes)
230278
nod = nodes[k]
231279
@assert nod.type != Nonlinear.NODE_MOI_VARIABLE
232-
if nonlinear_wrt_output[k]
233-
continue # already seen this node one way or another
234-
elseif input_linearity[k] == CONSTANT
235-
continue # definitely not nonlinear
280+
281+
if input_linearity[k] == CONSTANT
282+
continue # No hessian contribution from constant nodes
236283
end
237-
@assert !nonlinear_wrt_output[nod.parent]
238-
# check if the parent depends nonlinearly on the value of this node
239-
par = nodes[nod.parent]
240-
if par.type == Nonlinear.NODE_CALL_UNIVARIATE
241-
op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, par.index, nothing)
242-
if op === nothing || (op != :+ && op != :-)
243-
nonlinear_wrt_output[k] = true
284+
285+
# Check if this node has nonlinear child interactions
286+
children_idx = SparseArrays.nzrange(adj, k)
287+
num_children = length(children_idx)
288+
interactions = _get_nonlinear_child_interactions(nod, num_children)
289+
290+
if !isempty(interactions)
291+
# This node has nonlinear child interactions, so collect variables from its children
292+
empty!(child_group_variables)
293+
294+
# DFS from all children, tracking child index
295+
for (child_position, cidx) in enumerate(children_idx)
296+
child_node_idx = children_arr[cidx]
297+
push!(stack, (child_node_idx, child_position))
244298
end
245-
elseif par.type == Nonlinear.NODE_CALL_MULTIVARIATE
246-
op = get(
247-
Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS,
248-
par.index,
249-
nothing,
250-
)
251-
if op === nothing
252-
nonlinear_wrt_output[k] = true
253-
elseif op in (:+, :-, :ifelse)
254-
# pass
255-
elseif op == :*
256-
# check if all siblings are constant
257-
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
258-
if !all(
259-
i ->
260-
input_linearity[children_arr[i]] == CONSTANT ||
261-
children_arr[i] == k,
262-
sibling_idx,
263-
)
264-
# at least one sibling isn't constant
265-
nonlinear_wrt_output[k] = true
299+
300+
while length(stack) > 0
301+
r, child_group_idx = pop!(stack)
302+
303+
# Don't traverse into logical conditions or comparisons
304+
if nodes[r].type == Nonlinear.NODE_LOGIC ||
305+
nodes[r].type == Nonlinear.NODE_COMPARISON
306+
continue
266307
end
267-
elseif op == :/
268-
# check if denominator is nonconstant
269-
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
270-
if input_linearity[children_arr[last(sibling_idx)]] != CONSTANT
271-
nonlinear_wrt_output[k] = true
308+
309+
r_children_idx = SparseArrays.nzrange(adj, r)
310+
for cidx in r_children_idx
311+
push!(stack, (children_arr[cidx], child_group_idx))
312+
end
313+
314+
if nodes[r].type == Nonlinear.NODE_VARIABLE
315+
if !haskey(child_group_variables, child_group_idx)
316+
child_group_variables[child_group_idx] = Set{Int}()
317+
end
318+
push!(
319+
child_group_variables[child_group_idx],
320+
nodes[r].index,
321+
)
322+
elseif nodes[r].type == Nonlinear.NODE_SUBEXPRESSION
323+
sub_vars = subexpression_variables[nodes[r].index]
324+
if !haskey(child_group_variables, child_group_idx)
325+
child_group_variables[child_group_idx] = Set{Int}()
326+
end
327+
union!(child_group_variables[child_group_idx], sub_vars)
272328
end
273-
else
274-
nonlinear_wrt_output[k] = true
275329
end
276-
end
277-
if nod.type == Nonlinear.NODE_SUBEXPRESSION && !nonlinear_wrt_output[k]
278-
# subexpression comes in linearly, so append edge_list
330+
331+
println(child_group_variables)
332+
_add_hessian_edges!(edge_list, interactions, child_group_variables)
333+
elseif nod.type == Nonlinear.NODE_SUBEXPRESSION
279334
for ij in subexpression_edgelist[nod.index]
280335
push!(edge_list, ij)
281336
end
282337
end
283-
if !nonlinear_wrt_output[k]
284-
continue
285-
end
286-
# do a DFS from here, including all children
287-
@assert isempty(stack)
288-
@assert isempty(stack_ignore)
289-
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
290-
for sidx in sibling_idx
291-
push!(stack, children_arr[sidx])
292-
push!(stack_ignore, false)
293-
end
294-
empty!(nonlinear_group)
295-
while length(stack) > 0
296-
r = pop!(stack)
297-
should_ignore = pop!(stack_ignore)
298-
nonlinear_wrt_output[r] = true
299-
if nodes[r].type == Nonlinear.NODE_LOGIC ||
300-
nodes[r].type == Nonlinear.NODE_COMPARISON
301-
# don't count the nonlinear interactions inside
302-
# logical conditions or comparisons
303-
should_ignore = true
304-
end
305-
children_idx = SparseArrays.nzrange(adj, r)
306-
for cidx in children_idx
307-
push!(stack, children_arr[cidx])
308-
push!(stack_ignore, should_ignore)
309-
end
310-
if should_ignore
311-
continue
312-
end
313-
if nodes[r].type == Nonlinear.NODE_VARIABLE
314-
push!(nonlinear_group, nodes[r].index)
315-
elseif nodes[r].type == Nonlinear.NODE_SUBEXPRESSION
316-
# append all variables in subexpression
317-
union!(nonlinear_group, subexpression_variables[nodes[r].index])
338+
end
339+
return edge_list
340+
end
341+
342+
"""
343+
_add_hessian_edges!(
344+
edge_list::Set{Tuple{Int,Int}},
345+
interactions::Vector{Tuple{Int,Int}},
346+
child_variables::Dict{Int,Set{Int}},
347+
)
348+
349+
Add hessian edges based on the operator's nonlinear interaction pattern.
350+
"""
351+
function _add_hessian_edges!(
352+
edge_list::Set{Tuple{Int,Int}},
353+
interactions::Vector{Tuple{Int,Int}},
354+
child_variables::Dict{Int,Set{Int}},
355+
)
356+
for (child_i, child_j) in interactions
357+
if child_i == child_j
358+
# Within-child interactions: add all pairs from a single child
359+
if haskey(child_variables, child_i)
360+
vars = child_variables[child_i]
361+
for vi in vars
362+
for vj in vars
363+
i, j = minmax(vi, vj)
364+
push!(edge_list, (j, i))
365+
end
366+
end
318367
end
319-
end
320-
for i_ in 1:nonlinear_group.nnz
321-
i = nonlinear_group.nzidx[i_]
322-
for j_ in 1:nonlinear_group.nnz
323-
j = nonlinear_group.nzidx[j_]
324-
if j > i
325-
continue # Only lower triangle.
368+
else
369+
# Between-child interactions: add pairs from different children
370+
if haskey(child_variables, child_i) &&
371+
haskey(child_variables, child_j)
372+
vars_i = child_variables[child_i]
373+
vars_j = child_variables[child_j]
374+
for vi in vars_i
375+
for vj in vars_j
376+
i, j = minmax(vi, vj)
377+
push!(edge_list, (j, i))
378+
end
326379
end
327-
push!(edge_list, (i, j))
328380
end
329381
end
330382
end
331-
return edge_list
332383
end
333384

334385
"""

src/Nonlinear/ReverseAD/mathoptinterface_api.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
9393
subex.nodes,
9494
subex.adj,
9595
linearity,
96-
coloring_storage,
9796
subexpression_edgelist,
9897
subexpression_variables,
9998
)

src/Nonlinear/ReverseAD/types.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ struct _FunctionStorage
9191
nodes,
9292
adj,
9393
linearity,
94-
coloring_storage,
9594
subexpression_edgelist,
9695
subexpression_variables,
9796
)

test/Nonlinear/ReverseAD.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,6 @@ function test_linearity()
561561
nodes,
562562
adj,
563563
ret,
564-
indexed_set,
565564
Set{Tuple{Int,Int}}[],
566565
Vector{Int}[],
567566
)
@@ -585,12 +584,7 @@ function test_linearity()
585584
[1, 2],
586585
)
587586
_test_linearity(:(3 * 4 * ($x + $y)), ReverseAD.LINEAR)
588-
_test_linearity(
589-
:($z * $y),
590-
ReverseAD.NONLINEAR,
591-
Set([(3, 2), (3, 3), (2, 2)]),
592-
[2, 3],
593-
)
587+
_test_linearity(:($z * $y), ReverseAD.NONLINEAR, Set([(3, 2)]), [2, 3])
594588
_test_linearity(:(3 + 4), ReverseAD.CONSTANT)
595589
_test_linearity(:(sin(3) + $x), ReverseAD.LINEAR)
596590
_test_linearity(
@@ -635,6 +629,12 @@ function test_linearity()
635629
Set([(1, 1)]),
636630
[1],
637631
)
632+
_test_linearity(
633+
:(($x + $y)/$z),
634+
ReverseAD.NONLINEAR,
635+
Set([(3, 3), (3, 2), (3, 1)]),
636+
[1, 2, 3],
637+
)
638638
return
639639
end
640640

@@ -1416,7 +1416,7 @@ function test_hessian_reinterpret_unsafe()
14161416
x_v = ones(5)
14171417
MOI.eval_hessian_lagrangian(evaluator, H, x_v, 0.0, [1.0, 1.0])
14181418
@test count(isapprox.(H, 1.0; atol = 1e-8)) == 3
1419-
@test count(isapprox.(H, 0.0; atol = 1e-8)) == 6
1419+
@test count(isapprox.(H, 0.0; atol = 1e-8)) == 5
14201420
@test sort(H_s[round.(Bool, H)]) == [(3, 1), (3, 2), (5, 4)]
14211421
return
14221422
end

0 commit comments

Comments
 (0)