Skip to content

Commit 2016467

Browse files
committed
Fix formatting
1 parent f0c7b88 commit 2016467

File tree

1 file changed

+14
-24
lines changed

1 file changed

+14
-24
lines changed

src/Nonlinear/ReverseAD/graph_tools.jl

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ end
177177

178178
"""
179179
_get_nonlinear_child_interactions(
180-
nod::Nonlinear.Node,
180+
node::Nonlinear.Node,
181181
num_children::Int,
182182
)
183183
@@ -192,21 +192,20 @@ For functions like `+` or `-`, the result is `[]` since there are no nonlinear
192192
interactions between children.
193193
"""
194194
function _get_nonlinear_child_interactions(
195-
nod::Nonlinear.Node,
195+
node::Nonlinear.Node,
196196
num_children::Int,
197197
)::Vector{Tuple{Int,Int}}
198-
if nod.type == Nonlinear.NODE_CALL_UNIVARIATE
198+
if node.type == Nonlinear.NODE_CALL_UNIVARIATE
199199
@assert num_children == 1
200-
op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, nod.index, nothing)
200+
op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, node.index, nothing)
201201
# Univariate operators :+ and :- don't create interactions
202202
if op in (:+, :-)
203203
return Tuple{Int,Int}[]
204204
else
205205
return [(1, 1)]
206206
end
207-
elseif nod.type == Nonlinear.NODE_CALL_MULTIVARIATE
208-
op = get(Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS, nod.index, nothing)
209-
207+
elseif node.type == Nonlinear.NODE_CALL_MULTIVARIATE
208+
op = get(Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS, node.index, nothing)
210209
if op in (:+, :-, :ifelse, :min, :max)
211210
# No nonlinear interactions between children
212211
return Tuple{Int,Int}[]
@@ -268,49 +267,39 @@ function _compute_hessian_sparsity(
268267
)
269268
edge_list = Set{Tuple{Int,Int}}()
270269
children_arr = SparseArrays.rowvals(adj)
271-
272270
# Stack entry: (node_index, child_group_index)
273271
stack = Tuple{Int,Int}[]
274272
# Map from child_group_index to variable indices
275273
child_group_variables = Dict{Int,Set{Int}}()
276-
277-
for k in 1:length(nodes)
278-
nod = nodes[k]
279-
@assert nod.type != Nonlinear.NODE_MOI_VARIABLE
280-
274+
for (k, node) in enumerate(nodes)
275+
@assert node.type != Nonlinear.NODE_MOI_VARIABLE
281276
if input_linearity[k] == CONSTANT
282277
continue # No hessian contribution from constant nodes
283278
end
284-
285279
# Check if this node has nonlinear child interactions
286280
children_idx = SparseArrays.nzrange(adj, k)
287281
num_children = length(children_idx)
288-
interactions = _get_nonlinear_child_interactions(nod, num_children)
289-
282+
interactions = _get_nonlinear_child_interactions(node, num_children)
290283
if !isempty(interactions)
291-
# This node has nonlinear child interactions, so collect variables from its children
284+
# This node has nonlinear child interactions, so collect variables
285+
# from its children
292286
empty!(child_group_variables)
293-
294287
# DFS from all children, tracking child index
295288
for (child_position, cidx) in enumerate(children_idx)
296289
child_node_idx = children_arr[cidx]
297290
push!(stack, (child_node_idx, child_position))
298291
end
299-
300292
while length(stack) > 0
301293
r, child_group_idx = pop!(stack)
302-
303294
# Don't traverse into logical conditions or comparisons
304295
if nodes[r].type == Nonlinear.NODE_LOGIC ||
305296
nodes[r].type == Nonlinear.NODE_COMPARISON
306297
continue
307298
end
308-
309299
r_children_idx = SparseArrays.nzrange(adj, r)
310300
for cidx in r_children_idx
311301
push!(stack, (children_arr[cidx], child_group_idx))
312302
end
313-
314303
if nodes[r].type == Nonlinear.NODE_VARIABLE
315304
if !haskey(child_group_variables, child_group_idx)
316305
child_group_variables[child_group_idx] = Set{Int}()
@@ -328,8 +317,8 @@ function _compute_hessian_sparsity(
328317
end
329318
end
330319
_add_hessian_edges!(edge_list, interactions, child_group_variables)
331-
elseif nod.type == Nonlinear.NODE_SUBEXPRESSION
332-
for ij in subexpression_edgelist[nod.index]
320+
elseif node.type == Nonlinear.NODE_SUBEXPRESSION
321+
for ij in subexpression_edgelist[node.index]
333322
push!(edge_list, ij)
334323
end
335324
end
@@ -378,6 +367,7 @@ function _add_hessian_edges!(
378367
end
379368
end
380369
end
370+
return
381371
end
382372

383373
"""

0 commit comments

Comments
 (0)