Skip to content

Commit 268b9e4

Browse files
authored
Merge pull request #17 from blegat/sl/hcat
Implement dot and norm on vectors constructed with hcat, row, vcat, vect
2 parents 9f1a085 + f0fad53 commit 268b9e4

File tree

4 files changed

+579
-9
lines changed

4 files changed

+579
-9
lines changed

src/Coloring/Coloring.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ function acyclic_coloring(g::UndirectedGraph)
206206
firstVisitToTree = fill(_Edge(0, 0, 0), _num_edges(g))
207207
color = fill(0, _num_vertices(g))
208208
# disjoint set forest of edges in the graph
209-
S = DataStructures.IntDisjointSets(_num_edges(g))
209+
S = DataStructures.IntDisjointSet{Int}(_num_edges(g))
210210
@inbounds for v in 1:_num_vertices(g)
211211
n_neighbor = _num_neighbors(v, g)
212212
start_neighbor = _start_neighbors(v, g)

src/reverse_mode.jl

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,83 @@ function _forward_eval(
247247
tmp_dot += v1 * v2
248248
end
249249
@s f.forward_storage[k] = tmp_dot
250+
elseif node.index == 12 # hcat
251+
idx1, idx2 = children_indices
252+
ix1 = children_arr[idx1]
253+
ix2 = children_arr[idx2]
254+
nb_cols1 = f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
255+
col_size = f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
256+
for j in _eachindex(f.sizes, ix1)
257+
@j f.partials_storage[ix1] = one(T)
258+
val = @j f.forward_storage[ix1]
259+
@j f.forward_storage[k] = val
260+
end
261+
for j in _eachindex(f.sizes, ix2)
262+
@j f.partials_storage[ix2] = one(T)
263+
val = @j f.forward_storage[ix2]
264+
_setindex!(
265+
f.forward_storage,
266+
val,
267+
f.sizes,
268+
k,
269+
j + nb_cols1 * col_size,
270+
)
271+
end
272+
elseif node.index == 13 # vcat
273+
idx1, idx2 = children_indices
274+
ix1 = children_arr[idx1]
275+
ix2 = children_arr[idx2]
276+
nb_rows1 = f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 1)
277+
nb_rows2 = f.sizes.ndims[ix2] <= 1 ? 1 : _size(f.sizes, ix2, 1)
278+
nb_rows = nb_rows1 + nb_rows2
279+
for j in _eachindex(f.sizes, ix1)
280+
@j f.partials_storage[ix1] = one(T)
281+
val = @j f.forward_storage[ix1]
282+
_setindex!(
283+
f.forward_storage,
284+
val,
285+
f.sizes,
286+
k,
287+
div(j-1, nb_rows1) * nb_rows + 1 + (j-1) % nb_rows1,
288+
)
289+
end
290+
for j in _eachindex(f.sizes, ix2)
291+
@j f.partials_storage[ix2] = one(T)
292+
val = @j f.forward_storage[ix2]
293+
_setindex!(
294+
f.forward_storage,
295+
val,
296+
f.sizes,
297+
k,
298+
div(j-1, nb_rows1) * nb_rows +
299+
1 +
300+
(j-1) % nb_rows1 +
301+
nb_rows1,
302+
)
303+
end
304+
elseif node.index == 14 # norm
305+
ix = children_arr[children_indices[1]]
306+
tmp_norm_squared = zero(T)
307+
for j in _eachindex(f.sizes, ix)
308+
v = @j f.forward_storage[ix]
309+
tmp_norm_squared += v * v
310+
end
311+
@s f.forward_storage[k] = sqrt(tmp_norm_squared)
312+
for j in _eachindex(f.sizes, ix)
313+
v = @j f.forward_storage[ix]
314+
if tmp_norm_squared == 0
315+
@j f.partials_storage[ix] = zero(T)
316+
else
317+
@j f.partials_storage[ix] = v / @s f.forward_storage[k]
318+
end
319+
end
320+
elseif node.index == 16 # row
321+
for j in _eachindex(f.sizes, k)
322+
ix = children_arr[children_indices[j]]
323+
@s f.partials_storage[ix] = one(T)
324+
val = @s f.forward_storage[ix]
325+
@j f.forward_storage[k] = val
326+
end
250327
else # atan, min, max
251328
f_input = _UnsafeVectorView(d.jac_storage, N)
252329
∇f = _UnsafeVectorView(d.user_output_buffer, N)
@@ -380,6 +457,149 @@ function _reverse_eval(f::_SubexpressionStorage)
380457
end
381458
end
382459
continue
460+
elseif op == :hcat
461+
idx1, idx2 = children_indices
462+
ix1 = children_arr[idx1]
463+
ix2 = children_arr[idx2]
464+
nb_cols1 =
465+
f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
466+
col_size =
467+
f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
468+
for j in _eachindex(f.sizes, ix1)
469+
partial = @j f.partials_storage[ix1]
470+
val = ifelse(
471+
_getindex(f.reverse_storage, f.sizes, k, j) ==
472+
0.0 && !isfinite(partial),
473+
_getindex(f.reverse_storage, f.sizes, k, j),
474+
_getindex(f.reverse_storage, f.sizes, k, j) *
475+
partial,
476+
)
477+
@j f.reverse_storage[ix1] = val
478+
end
479+
for j in _eachindex(f.sizes, ix2)
480+
partial = @j f.partials_storage[ix2]
481+
val = ifelse(
482+
_getindex(
483+
f.reverse_storage,
484+
f.sizes,
485+
k,
486+
j + nb_cols1 * col_size,
487+
) == 0.0 && !isfinite(partial),
488+
_getindex(
489+
f.reverse_storage,
490+
f.sizes,
491+
k,
492+
j + nb_cols1 * col_size,
493+
),
494+
_getindex(
495+
f.reverse_storage,
496+
f.sizes,
497+
k,
498+
j + nb_cols1 * col_size,
499+
) * partial,
500+
)
501+
@j f.reverse_storage[ix2] = val
502+
end
503+
continue
504+
elseif op == :vcat
505+
idx1, idx2 = children_indices
506+
ix1 = children_arr[idx1]
507+
ix2 = children_arr[idx2]
508+
nb_rows1 =
509+
f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 1)
510+
nb_rows2 =
511+
f.sizes.ndims[ix2] <= 1 ? 1 : _size(f.sizes, ix2, 1)
512+
nb_rows = nb_rows1 + nb_rows2
513+
row_size =
514+
f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 2)
515+
for j in _eachindex(f.sizes, ix1)
516+
partial = @j f.partials_storage[ix1]
517+
val = ifelse(
518+
_getindex(
519+
f.reverse_storage,
520+
f.sizes,
521+
k,
522+
div(j-1, nb_rows1) * nb_rows +
523+
1 +
524+
(j-1) % nb_rows1,
525+
) == 0.0 && !isfinite(partial),
526+
_getindex(
527+
f.reverse_storage,
528+
f.sizes,
529+
k,
530+
div(j-1, nb_rows1) * nb_rows +
531+
1 +
532+
(j-1) % nb_rows1,
533+
),
534+
_getindex(
535+
f.reverse_storage,
536+
f.sizes,
537+
k,
538+
div(j-1, nb_rows1) * nb_rows +
539+
1 +
540+
(j-1) % nb_rows1,
541+
) * partial,
542+
)
543+
@j f.reverse_storage[ix1] = val
544+
end
545+
for j in _eachindex(f.sizes, ix2)
546+
partial = @j f.partials_storage[ix2]
547+
val = ifelse(
548+
_getindex(
549+
f.reverse_storage,
550+
f.sizes,
551+
k,
552+
div(j-1, nb_rows1) * nb_rows +
553+
1 +
554+
(j-1) % nb_rows1 +
555+
nb_rows1,
556+
) == 0.0 && !isfinite(partial),
557+
_getindex(
558+
f.reverse_storage,
559+
f.sizes,
560+
k,
561+
div(j-1, nb_rows1) * nb_rows +
562+
1 +
563+
(j-1) % nb_rows1 +
564+
nb_rows1,
565+
),
566+
_getindex(
567+
f.reverse_storage,
568+
f.sizes,
569+
k,
570+
div(j-1, nb_rows1) * nb_rows +
571+
1 +
572+
(j-1) % nb_rows1 +
573+
nb_rows1,
574+
) * partial,
575+
)
576+
@j f.reverse_storage[ix2] = val
577+
end
578+
continue
579+
elseif op == :norm
580+
# Node `k` is scalar, the jacobian w.r.t. the vectorized input
581+
# child is a row vector whose entries are stored in `f.partials_storage`
582+
rev_parent = @s f.reverse_storage[k]
583+
for j in
584+
_eachindex(f.sizes, children_arr[children_indices[1]])
585+
ix = children_arr[children_indices[1]]
586+
partial = @j f.partials_storage[ix]
587+
val = ifelse(
588+
rev_parent == 0.0 && !isfinite(partial),
589+
rev_parent,
590+
rev_parent * partial,
591+
)
592+
@j f.reverse_storage[ix] = val
593+
end
594+
continue
595+
elseif op == :row
596+
for j in _eachindex(f.sizes, k)
597+
ix = children_arr[children_indices[j]]
598+
rev_parent_j = @j f.reverse_storage[k]
599+
# partial is 1 so we can ignore it
600+
@s f.reverse_storage[ix] = rev_parent_j
601+
end
602+
continue
383603
end
384604
end
385605
elseif node.type != MOI.Nonlinear.NODE_CALL_UNIVARIATE

src/sizes.jl

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,59 @@ function _infer_sizes(
179179
op,
180180
)
181181
_add_size!(sizes, k, (N,))
182+
elseif op == :row
183+
_assert_scalar_children(
184+
sizes,
185+
children_arr,
186+
children_indices,
187+
op,
188+
)
189+
_add_size!(sizes, k, (1, N))
182190
elseif op == :dot
183191
# TODO assert all arguments have same size
192+
elseif op == :norm
193+
# TODO actually norm should be moved to univariate
184194
elseif op == :+ || op == :-
185195
# TODO assert all arguments have same size
186196
_copy_size!(sizes, k, children_arr[first(children_indices)])
197+
elseif op == :hcat
198+
total_cols = 0
199+
for c_idx in children_indices
200+
total_cols +=
201+
sizes.ndims[children_arr[c_idx]] <= 1 ? 1 :
202+
_size(sizes, children_arr[c_idx], 2)
203+
end
204+
if sizes.ndims[children_arr[first(children_indices)]] == 0
205+
shape = (1, total_cols)
206+
else
207+
@assert sizes.ndims[children_arr[first(
208+
children_indices,
209+
)]] <= 2 "Hcat with ndims > 2 is not supported yet"
210+
shape = (
211+
_size(sizes, children_arr[first(children_indices)], 1),
212+
total_cols,
213+
)
214+
end
215+
_add_size!(sizes, k, tuple(shape...))
216+
elseif op == :vcat
217+
total_rows = 0
218+
for c_idx in children_indices
219+
total_rows +=
220+
sizes.ndims[children_arr[c_idx]] <= 1 ? 1 :
221+
_size(sizes, children_arr[c_idx], 1)
222+
end
223+
if sizes.ndims[children_arr[first(children_indices)]] == 0
224+
shape = (total_rows, 1)
225+
else
226+
@assert sizes.ndims[children_arr[first(
227+
children_indices,
228+
)]] <= 2 "Hcat with ndims > 2 is not supported yet"
229+
shape = (
230+
total_rows,
231+
_size(sizes, children_arr[first(children_indices)], 2),
232+
)
233+
end
234+
_add_size!(sizes, k, tuple(shape...))
187235
elseif op == :*
188236
# TODO assert compatible sizes and all ndims should be 0 or 2
189237
first_matrix = findfirst(children_indices) do i
@@ -193,14 +241,24 @@ function _infer_sizes(
193241
last_matrix = findfirst(children_indices) do i
194242
return !iszero(sizes.ndims[children_arr[i]])
195243
end
196-
_add_size!(
197-
sizes,
198-
k,
199-
(
200-
_size(sizes, first_matrix, 1),
201-
_size(sizes, last_matrix, sizes.ndims[last_matrix]),
202-
),
203-
)
244+
if sizes.ndims[last_matrix] == 0 ||
245+
sizes.ndims[first_matrix] == 0
246+
_add_size!(sizes, k, (1, 1))
247+
continue
248+
else
249+
_add_size!(
250+
sizes,
251+
k,
252+
(
253+
_size(sizes, first_matrix, 1),
254+
_size(
255+
sizes,
256+
last_matrix,
257+
sizes.ndims[last_matrix],
258+
),
259+
),
260+
)
261+
end
204262
end
205263
elseif op == :^ || op == :/
206264
@assert N == 2

0 commit comments

Comments
 (0)