Skip to content

Commit da96225

Browse files
authored
Merge pull request #14 from blegat/sl/hcat
Merge hcat into norm
2 parents 65fa6ef + 8983f4d commit da96225

File tree

3 files changed

+142
-2
lines changed

3 files changed

+142
-2
lines changed

src/reverse_mode.jl

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,28 @@ function _forward_eval(
247247
tmp_dot += v1 * v2
248248
end
249249
@s f.forward_storage[k] = tmp_dot
250+
elseif node.index == 12 # hcat
251+
idx1, idx2 = children_indices
252+
ix1 = children_arr[idx1]
253+
ix2 = children_arr[idx2]
254+
nb_cols1 = f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
255+
col_size = f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
256+
for j in _eachindex(f.sizes, ix1)
257+
@j f.partials_storage[ix1] = one(T)
258+
val = @j f.forward_storage[ix1]
259+
@j f.forward_storage[k] = val
260+
end
261+
for j in _eachindex(f.sizes, ix2)
262+
@j f.partials_storage[ix2] = one(T)
263+
val = @j f.forward_storage[ix2]
264+
_setindex!(
265+
f.forward_storage,
266+
val,
267+
f.sizes,
268+
k,
269+
j + nb_cols1 * col_size,
270+
)
271+
end
250272
elseif node.index == 14 # norm
251273
ix = children_arr[children_indices[1]]
252274
tmp_norm_squared = zero(T)
@@ -395,6 +417,50 @@ function _reverse_eval(f::_SubexpressionStorage)
395417
end
396418
end
397419
continue
420+
elseif op == :hcat
421+
idx1, idx2 = children_indices
422+
ix1 = children_arr[idx1]
423+
ix2 = children_arr[idx2]
424+
nb_cols1 =
425+
f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
426+
col_size =
427+
f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
428+
for j in _eachindex(f.sizes, ix1)
429+
partial = @j f.partials_storage[ix1]
430+
val = ifelse(
431+
_getindex(f.reverse_storage, f.sizes, k, j) ==
432+
0.0 && !isfinite(partial),
433+
_getindex(f.reverse_storage, f.sizes, k, j),
434+
_getindex(f.reverse_storage, f.sizes, k, j) *
435+
partial,
436+
)
437+
@j f.reverse_storage[ix1] = val
438+
end
439+
for j in _eachindex(f.sizes, ix2)
440+
partial = @j f.partials_storage[ix2]
441+
val = ifelse(
442+
_getindex(
443+
f.reverse_storage,
444+
f.sizes,
445+
k,
446+
j + nb_cols1 * col_size,
447+
) == 0.0 && !isfinite(partial),
448+
_getindex(
449+
f.reverse_storage,
450+
f.sizes,
451+
k,
452+
j + nb_cols1 * col_size,
453+
),
454+
_getindex(
455+
f.reverse_storage,
456+
f.sizes,
457+
k,
458+
j + nb_cols1 * col_size,
459+
) * partial,
460+
)
461+
@j f.reverse_storage[ix2] = val
462+
end
463+
continue
398464
elseif op == :norm
399465
# Node `k` is scalar, the jacobian w.r.t. the vectorized input
400466
# child is a row vector whose entries are stored in `f.partials_storage`
@@ -408,7 +474,7 @@ function _reverse_eval(f::_SubexpressionStorage)
408474
rev_parent,
409475
rev_parent * partial,
410476
)
411-
@j f.reverse_storage[ix] = val
477+
@j f.reverse_storage[ix] = val
412478
end
413479
continue
414480
end

src/sizes.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,25 @@ function _infer_sizes(
186186
elseif op == :+ || op == :-
187187
# TODO assert all arguments have same size
188188
_copy_size!(sizes, k, children_arr[first(children_indices)])
189+
elseif op == :hcat
190+
total_cols = 0
191+
for c_idx in children_indices
192+
total_cols +=
193+
sizes.ndims[children_arr[c_idx]] <= 1 ? 1 :
194+
_size(sizes, children_arr[c_idx], 2)
195+
end
196+
if sizes.ndims[children_arr[first(children_indices)]] == 0
197+
shape = (1, total_cols)
198+
else
199+
@assert sizes.ndims[children_arr[first(
200+
children_indices,
201+
)]] <= 2 "Hcat with ndims > 2 is not supported yet"
202+
shape = (
203+
_size(sizes, children_arr[first(children_indices)], 1),
204+
total_cols,
205+
)
206+
end
207+
_add_size!(sizes, k, tuple(shape...))
189208
elseif op == :*
190209
# TODO assert compatible sizes and all ndims should be 0 or 2
191210
first_matrix = findfirst(children_indices) do i

test/ArrayDiff.jl

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,61 @@ function test_objective_dot_bivariate()
6464
return
6565
end
6666

67+
# Objective `dot([x1 x3], [x2 x4])`: both arguments are built with the
# space-separated bracket syntax (an `hcat` of scalar children). Checks the
# inferred sizes, the objective value and the gradient.
function test_objective_hcat_0dim()
    model = Nonlinear.Model()
    x1 = MOI.VariableIndex(1)
    x2 = MOI.VariableIndex(2)
    x3 = MOI.VariableIndex(3)
    x4 = MOI.VariableIndex(4)
    Nonlinear.set_objective(model, :(dot([$x1 $x3], [$x2 $x4])))
    evaluator = Nonlinear.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4])
    MOI.initialize(evaluator, [:Grad])
    sizes = evaluator.backend.objective.expr.sizes
    @test sizes.ndims == [0, 2, 0, 0, 2, 0, 0]
    @test sizes.size_offset == [0, 2, 0, 0, 0, 0, 0]
    @test sizes.size == [1, 2, 1, 2]
    @test sizes.storage_offset == [0, 1, 3, 4, 5, 7, 8, 9]
    # Evaluation point, kept separate from the `MOI.VariableIndex` bindings so
    # that `x1`..`x4` do not change type halfway through the function.
    x = [1.0, 2.0, 3.0, 4.0]
    # x1 * x2 + x3 * x4 = 1 * 2 + 3 * 4 = 14
    @test MOI.eval_objective(evaluator, x) == 14.0
    g = ones(4)
    MOI.eval_objective_gradient(evaluator, g, x)
    # Gradient is [x2, x1, x4, x3].
    @test g == [2.0, 1.0, 4.0, 3.0]
    return
end
92+
93+
# Objective `dot(hcat([x1], [x3]), hcat([x2], [x4]))`: each `hcat` call
# receives two one-element vectors. Checks the inferred sizes, the objective
# value and the gradient.
function test_objective_hcat_1dim()
    model = Nonlinear.Model()
    x1 = MOI.VariableIndex(1)
    x2 = MOI.VariableIndex(2)
    x3 = MOI.VariableIndex(3)
    x4 = MOI.VariableIndex(4)
    Nonlinear.set_objective(
        model,
        :(dot(hcat([$x1], [$x3]), hcat([$x2], [$x4]))),
    )
    evaluator = Nonlinear.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4])
    MOI.initialize(evaluator, [:Grad])
    sizes = evaluator.backend.objective.expr.sizes
    @test sizes.ndims == [0, 2, 1, 0, 1, 0, 2, 1, 0, 1, 0]
    @test sizes.size_offset == [0, 6, 5, 0, 4, 0, 2, 1, 0, 0, 0]
    @test sizes.size == [1, 1, 1, 2, 1, 1, 1, 2]
    @test sizes.storage_offset == [0, 1, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13]
    # Evaluation point, kept separate from the `MOI.VariableIndex` bindings so
    # that `x1`..`x4` do not change type halfway through the function.
    x = [1.0, 2.0, 3.0, 4.0]
    # x1 * x2 + x3 * x4 = 1 * 2 + 3 * 4 = 14
    @test MOI.eval_objective(evaluator, x) == 14.0
    g = ones(4)
    MOI.eval_objective_gradient(evaluator, g, x)
    # Gradient is [x2, x1, x4, x3].
    @test g == [2.0, 1.0, 4.0, 3.0]
    return
end
121+
67122
function test_objective_norm_univariate()
68123
model = Nonlinear.Model()
69124
x = MOI.VariableIndex(1)
@@ -110,4 +165,4 @@ end
110165

111166
end # module
112167

113-
TestArrayDiff.runtests()
168+
TestArrayDiff.runtests()

0 commit comments

Comments
 (0)