@@ -175,12 +175,76 @@ function _compute_gradient_sparsity!(
175175 return
176176end
177177
"""
    _get_nonlinear_child_interactions(
        nod::Nonlinear.Node,
        num_children::Int,
    )

Return the pairs of children of `nod` that interact nonlinearly.

Each entry is a tuple `(i, j)` of 1-indexed child positions with `i <= j`. A
pair `(i, i)` means the node is nonlinear in child `i` by itself.

## Examples

 * `*` with 2 children gives `[(1, 2)]`: the two factors interact with each
   other, but neither factor interacts with itself.
 * `/` gives `[(1, 2), (2, 2)]`: the numerator does not interact with itself.
 * `+`, `-`, `ifelse`, `min`, and `max` give an empty list.
 * Unary `+` and `-` give an empty list; every other univariate call gives
   `[(1, 1)]`.
 * Any other multivariate call conservatively gives every pair, including
   each child with itself.
 * Nodes that are not calls give an empty list.
"""
function _get_nonlinear_child_interactions(
    nod::Nonlinear.Node,
    num_children::Int,
)::Vector{Tuple{Int,Int}}
    if nod.type == Nonlinear.NODE_CALL_UNIVARIATE
        @assert num_children == 1
        op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, nod.index, nothing)
        # Unary :+ and :- don't create interactions. Any other operator,
        # including one missing from the default table, is treated as
        # nonlinear in its single argument.
        return op in (:+, :-) ? Tuple{Int,Int}[] : [(1, 1)]
    elseif nod.type != Nonlinear.NODE_CALL_MULTIVARIATE
        # Logic and comparison nodes don't generate hessian terms.
        # Subexpression nodes are special cased.
        return Tuple{Int,Int}[]
    end
    op = get(Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS, nod.index, nothing)
    if op in (:+, :-, :ifelse, :min, :max)
        # No nonlinear interactions between children.
        return Tuple{Int,Int}[]
    elseif op == :*
        # All pairs of distinct children interact nonlinearly.
        return Tuple{Int,Int}[(j, i) for i in 1:num_children for j in 1:(i-1)]
    elseif op == :/
        @assert num_children == 2
        # The numerator doesn't have a nonlinear interaction with itself.
        return [(1, 2), (2, 2)]
    end
    # Conservative: assume all pairs interact, each child with itself included.
    return Tuple{Int,Int}[(j, i) for i in 1:num_children for j in 1:i]
end
242+
178243"""
179244 _compute_hessian_sparsity(
180245 nodes::Vector{Nonlinear.Node},
181246 adj,
182247 input_linearity::Vector{Linearity},
183- indexedset::Coloring.IndexedSet,
184248 subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
185249 subexpression_variables::Vector{Vector{Int}},
186250 )
@@ -193,142 +257,129 @@ Compute the sparsity pattern the Hessian of an expression.
193257 * `subexpression_variables` is the list of all variables which appear in a
194258 subexpression (including recursively).
195259
196- Idea: consider the (non)linearity of a node *with respect to the output*. The
197- children of any node which is nonlinear with respect to the output should have
198- nonlinear interactions, hence nonzeros in the hessian. This is not true in
199- general, but holds for everything we consider.
200-
201- A counter example is `f(x, y, z) = x + y * z`, but we don't have any functions
202- like that. By "nonlinear with respect to the output", we mean that the output
203- depends nonlinearly on the value of the node, regardless of how the node itself
204- depends on the input.
260+ Returns a `Set{Tuple{Int,Int}}` containing the nonzero entries of the Hessian.
205261"""
function _compute_hessian_sparsity(
    nodes::Vector{Nonlinear.Node},
    adj,
    input_linearity::Vector{Linearity},
    subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
    subexpression_variables::Vector{Vector{Int}},
)
    # Overview: visit every node. If the node's operator has nonlinear
    # interactions between (or within) its children, collect the variables
    # reachable below each child and add one edge per interacting variable
    # pair. Otherwise, if the node is a subexpression, copy that
    # subexpression's edges into the result.
    edge_list = Set{Tuple{Int,Int}}()
    children_arr = SparseArrays.rowvals(adj)
    # Stack entry: (node_index, child_group_index). The group index is the
    # position, among the children of the node being processed, of the child
    # whose subtree contains `node_index`.
    stack = Tuple{Int,Int}[]
    # Map from child_group_index to the variable indices found in that
    # child's subtree. Reused across nodes to avoid reallocating.
    child_group_variables = Dict{Int,Set{Int}}()
    for k in eachindex(nodes)
        nod = nodes[k]
        @assert nod.type != Nonlinear.NODE_MOI_VARIABLE
        if input_linearity[k] == CONSTANT
            continue  # No hessian contribution from constant nodes
        end
        children_idx = SparseArrays.nzrange(adj, k)
        interactions =
            _get_nonlinear_child_interactions(nod, length(children_idx))
        if isempty(interactions)
            if nod.type == Nonlinear.NODE_SUBEXPRESSION
                union!(edge_list, subexpression_edgelist[nod.index])
            end
            continue
        end
        # This node has nonlinear child interactions, so collect the
        # variables below each of its children with a DFS.
        empty!(child_group_variables)
        for (child_position, cidx) in enumerate(children_idx)
            push!(stack, (children_arr[cidx], child_position))
        end
        while !isempty(stack)
            r, child_group_idx = pop!(stack)
            node_r = nodes[r]
            # Don't traverse into logical conditions or comparisons
            if node_r.type == Nonlinear.NODE_LOGIC ||
               node_r.type == Nonlinear.NODE_COMPARISON
                continue
            end
            for cidx in SparseArrays.nzrange(adj, r)
                push!(stack, (children_arr[cidx], child_group_idx))
            end
            if node_r.type == Nonlinear.NODE_VARIABLE
                group = get!(Set{Int}, child_group_variables, child_group_idx)
                push!(group, node_r.index)
            elseif node_r.type == Nonlinear.NODE_SUBEXPRESSION
                group = get!(Set{Int}, child_group_variables, child_group_idx)
                # A subexpression contributes every variable it contains.
                union!(group, subexpression_variables[node_r.index])
            end
        end
        # The stack is empty again here, ready for the next node.
        _add_hessian_edges!(edge_list, interactions, child_group_variables)
    end
    return edge_list
end
341+
"""
    _add_hessian_edges!(
        edge_list::Set{Tuple{Int,Int}},
        interactions::Vector{Tuple{Int,Int}},
        child_variables::Dict{Int,Set{Int}},
    )

Add hessian edges to `edge_list` following the operator's nonlinear
interaction pattern.

For each child pair `(a, b)` in `interactions`, every variable recorded for
child `a` is paired with every variable recorded for child `b`. Each edge is
stored as `(larger_index, smaller_index)`. A pair is skipped when either child
has no entry in `child_variables`.
"""
function _add_hessian_edges!(
    edge_list::Set{Tuple{Int,Int}},
    interactions::Vector{Tuple{Int,Int}},
    child_variables::Dict{Int,Set{Int}},
)
    for (a, b) in interactions
        vars_a = get(child_variables, a, nothing)
        vars_b = get(child_variables, b, nothing)
        if vars_a === nothing || vars_b === nothing
            continue
        end
        # When `a == b` both sets are the same object, so this also covers
        # the within-child case, diagonal entries included.
        for (u, v) in Iterators.product(vars_a, vars_b)
            push!(edge_list, (max(u, v), min(u, v)))
        end
    end
    return
end
333384
334385"""
0 commit comments