Skip to content

Commit 1a3c1f9

Browse files
authored
Merge pull request #27 from blegat/sl/broadcasting
Add broadcasted multiplication
2 parents ca1dd1d + 48736cb commit 1a3c1f9

File tree

11 files changed

+477
-229
lines changed

11 files changed

+477
-229
lines changed

src/evaluator.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ function eval_univariate_hessian(
143143
x::T,
144144
) where {T}
145145
if id <= registry.univariate_user_operator_start
146-
ret = Nonlinear._eval_univariate_2nd_deriv(id, x)
146+
ret = _eval_univariate_2nd_deriv(id, x)
147147
if ret === nothing
148148
op = registry.univariate_operators[id]
149149
error("Hessian is not defined for operator $op")
@@ -154,3 +154,30 @@ function eval_univariate_hessian(
154154
operator = registry.registered_univariate_operators[offset]
155155
return eval_univariate_hessian(operator, x)
156156
end
157+
158+
"""
159+
adjacency_matrix(nodes::Vector{Node})
160+
161+
Compute the sparse adjacency matrix describing the parent-child relationships in
162+
`nodes`.
163+
164+
The element `(i, j)` is `true` if there is an edge *from* `node[j]` to
165+
`node[i]`. Since we get a column-oriented matrix, this gives us a fast way to
166+
look up the edges leaving any node (that is, the children).
167+
"""
168+
function adjacency_matrix(nodes::Vector{Node})
169+
N = length(nodes)
170+
I, J = Vector{Int}(undef, N), Vector{Int}(undef, N)
171+
numnz = 0
172+
for (i, node) in enumerate(nodes)
173+
if node.parent < 0
174+
continue
175+
end
176+
numnz += 1
177+
I[numnz] = i
178+
J[numnz] = node.parent
179+
end
180+
resize!(I, numnz)
181+
resize!(J, numnz)
182+
return SparseArrays.sparse(I, J, ones(Bool, numnz), N, N)
183+
end

src/forward_over_reverse.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -199,16 +199,16 @@ function _forward_eval_ϵ(
199199
for k in length(ex.nodes):-1:1
200200
node = ex.nodes[k]
201201
partials_storage_ϵ[k] = zero_ϵ
202-
if node.type == Nonlinear.NODE_VARIABLE
202+
if node.type == NODE_VARIABLE
203203
storage_ϵ[k] = x_values_ϵ[node.index]
204-
elseif node.type == Nonlinear.NODE_VALUE
204+
elseif node.type == NODE_VALUE
205205
storage_ϵ[k] = zero_ϵ
206-
elseif node.type == Nonlinear.NODE_SUBEXPRESSION
206+
elseif node.type == NODE_SUBEXPRESSION
207207
storage_ϵ[k] = subexpression_values_ϵ[node.index]
208-
elseif node.type == Nonlinear.NODE_PARAMETER
208+
elseif node.type == NODE_PARAMETER
209209
storage_ϵ[k] = zero_ϵ
210210
else
211-
@assert node.type != Nonlinear.NODE_MOI_VARIABLE
211+
@assert node.type != NODE_MOI_VARIABLE
212212
ϵtmp = zero_ϵ
213213
@inbounds children_idx = SparseArrays.nzrange(ex.adj, k)
214214
for c_idx in children_idx
@@ -223,7 +223,7 @@ function _forward_eval_ϵ(
223223
ϵtmp += storage_val * ex.partials_storage[ix]
224224
end
225225
storage_ϵ[k] = ϵtmp
226-
if node.type == Nonlinear.NODE_CALL_MULTIVARIATE
226+
if node.type == NODE_CALL_MULTIVARIATE
227227
# TODO(odow): consider how to refactor this into Nonlinear.
228228
op = node.index
229229
n_children = length(children_idx)
@@ -349,7 +349,7 @@ function _forward_eval_ϵ(
349349
partials_storage_ϵ[i] = dual
350350
end
351351
end
352-
elseif node.type == Nonlinear.NODE_CALL_UNIVARIATE
352+
elseif node.type == NODE_CALL_UNIVARIATE
353353
@inbounds child_idx = children_arr[ex.adj.colptr[k]]
354354
f′′ = eval_univariate_hessian(
355355
d.data.operators,
@@ -378,10 +378,10 @@ function _reverse_eval_ϵ(
378378
_reinterpret_unsafe(ForwardDiff.Partials{N,T}, ex.partials_storage_ϵ)
379379
@assert length(reverse_storage_ϵ) >= length(ex.nodes)
380380
@assert length(partials_storage_ϵ) >= length(ex.nodes)
381-
if ex.nodes[1].type == Nonlinear.NODE_VARIABLE
381+
if ex.nodes[1].type == NODE_VARIABLE
382382
@inbounds output_ϵ[ex.nodes[1].index] += scale_ϵ
383383
return
384-
elseif ex.nodes[1].type == Nonlinear.NODE_SUBEXPRESSION
384+
elseif ex.nodes[1].type == NODE_SUBEXPRESSION
385385
@inbounds subexpression_output[ex.nodes[1].index] +=
386386
scale * ex.reverse_storage[1]
387387
@inbounds subexpression_output_ϵ[ex.nodes[1].index] += scale_ϵ
@@ -390,10 +390,10 @@ function _reverse_eval_ϵ(
390390
reverse_storage_ϵ[1] = scale_ϵ
391391
for k in 2:length(ex.nodes)
392392
@inbounds node = ex.nodes[k]
393-
if node.type == Nonlinear.NODE_VALUE ||
394-
node.type == Nonlinear.NODE_LOGIC ||
395-
node.type == Nonlinear.NODE_COMPARISON ||
396-
node.type == Nonlinear.NODE_PARAMETER
393+
if node.type == NODE_VALUE ||
394+
node.type == NODE_LOGIC ||
395+
node.type == NODE_COMPARISON ||
396+
node.type == NODE_PARAMETER
397397
continue
398398
end
399399
parent_value = scale * ex.reverse_storage[node.parent]
@@ -407,9 +407,9 @@ function _reverse_eval_ϵ(
407407
ex.partials_storage[k],
408408
)
409409
end
410-
if node.type == Nonlinear.NODE_VARIABLE
410+
if node.type == NODE_VARIABLE
411411
@inbounds output_ϵ[node.index] += reverse_storage_ϵ[k]
412-
elseif node.type == Nonlinear.NODE_SUBEXPRESSION
412+
elseif node.type == NODE_SUBEXPRESSION
413413
@inbounds subexpression_output[node.index] +=
414414
scale * ex.reverse_storage[k]
415415
@inbounds subexpression_output_ϵ[node.index] += reverse_storage_ϵ[k]

0 commit comments

Comments
 (0)