Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion src/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,28 @@ function _forward_eval(
tmp_dot += v1 * v2
end
@s f.forward_storage[k] = tmp_dot
elseif node.index == 12 # hcat
idx1, idx2 = children_indices
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
nb_cols1 = f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
col_size = f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
for j in _eachindex(f.sizes, ix1)
@j f.partials_storage[ix1] = one(T)
val = @j f.forward_storage[ix1]
@j f.forward_storage[k] = val
end
for j in _eachindex(f.sizes, ix2)
@j f.partials_storage[ix2] = one(T)
val = @j f.forward_storage[ix2]
_setindex!(
f.forward_storage,
val,
f.sizes,
k,
j + nb_cols1 * col_size,
)
end
elseif node.index == 14 # norm
ix = children_arr[children_indices[1]]
tmp_norm_squared = zero(T)
Expand Down Expand Up @@ -395,6 +417,50 @@ function _reverse_eval(f::_SubexpressionStorage)
end
end
continue
elseif op == :hcat
idx1, idx2 = children_indices
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
nb_cols1 =
f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
col_size =
f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
for j in _eachindex(f.sizes, ix1)
partial = @j f.partials_storage[ix1]
val = ifelse(
_getindex(f.reverse_storage, f.sizes, k, j) ==
0.0 && !isfinite(partial),
_getindex(f.reverse_storage, f.sizes, k, j),
_getindex(f.reverse_storage, f.sizes, k, j) *
partial,
)
@j f.reverse_storage[ix1] = val
end
for j in _eachindex(f.sizes, ix2)
partial = @j f.partials_storage[ix2]
val = ifelse(
_getindex(
f.reverse_storage,
f.sizes,
k,
j + nb_cols1 * col_size,
) == 0.0 && !isfinite(partial),
_getindex(
f.reverse_storage,
f.sizes,
k,
j + nb_cols1 * col_size,
),
_getindex(
f.reverse_storage,
f.sizes,
k,
j + nb_cols1 * col_size,
) * partial,
)
@j f.reverse_storage[ix2] = val
end
continue
elseif op == :norm
# Node `k` is scalar, the jacobian w.r.t. the vectorized input
# child is a row vector whose entries are stored in `f.partials_storage`
Expand All @@ -408,7 +474,7 @@ function _reverse_eval(f::_SubexpressionStorage)
rev_parent,
rev_parent * partial,
)
@j f.reverse_storage[ix] = val
@j f.reverse_storage[ix] = val
end
continue
end
Expand Down
19 changes: 19 additions & 0 deletions src/sizes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,25 @@ function _infer_sizes(
elseif op == :+ || op == :-
# TODO assert all arguments have same size
_copy_size!(sizes, k, children_arr[first(children_indices)])
elseif op == :hcat
total_cols = 0
for c_idx in children_indices
total_cols +=
sizes.ndims[children_arr[c_idx]] <= 1 ? 1 :
_size(sizes, children_arr[c_idx], 2)
end
if sizes.ndims[children_arr[first(children_indices)]] == 0
shape = (1, total_cols)
else
@assert sizes.ndims[children_arr[first(
children_indices,
)]] <= 2 "Hcat with ndims > 2 is not supported yet"
shape = (
_size(sizes, children_arr[first(children_indices)], 1),
total_cols,
)
end
_add_size!(sizes, k, tuple(shape...))
elseif op == :*
# TODO assert compatible sizes and all ndims should be 0 or 2
first_matrix = findfirst(children_indices) do i
Expand Down
57 changes: 56 additions & 1 deletion test/ArrayDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,61 @@ function test_objective_dot_bivariate()
return
end

function test_objective_hcat_0dim()
model = Nonlinear.Model()
x1 = MOI.VariableIndex(1)
x2 = MOI.VariableIndex(2)
x3 = MOI.VariableIndex(3)
x4 = MOI.VariableIndex(4)
Nonlinear.set_objective(model, :(dot([$x1 $x3], [$x2 $x4])))
evaluator = Nonlinear.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4])
MOI.initialize(evaluator, [:Grad])
sizes = evaluator.backend.objective.expr.sizes
@test sizes.ndims == [0, 2, 0, 0, 2, 0, 0]
@test sizes.size_offset == [0, 2, 0, 0, 0, 0, 0]
@test sizes.size == [1, 2, 1, 2]
@test sizes.storage_offset == [0, 1, 3, 4, 5, 7, 8, 9]
x1 = 1.0
x2 = 2.0
x3 = 3.0
x4 = 4.0
println(MOI.eval_objective(evaluator, [x1, x2, x3, x4]))
@test MOI.eval_objective(evaluator, [x1, x2, x3, x4]) == 14.0
g = ones(4)
MOI.eval_objective_gradient(evaluator, g, [x1, x2, x3, x4])
@test g == [2.0, 1.0, 4.0, 3.0]
return
end

function test_objective_hcat_1dim()
model = Nonlinear.Model()
x1 = MOI.VariableIndex(1)
x2 = MOI.VariableIndex(2)
x3 = MOI.VariableIndex(3)
x4 = MOI.VariableIndex(4)
Nonlinear.set_objective(
model,
:(dot(hcat([$x1], [$x3]), hcat([$x2], [$x4]))),
)
evaluator = Nonlinear.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4])
MOI.initialize(evaluator, [:Grad])
sizes = evaluator.backend.objective.expr.sizes
@test sizes.ndims == [0, 2, 1, 0, 1, 0, 2, 1, 0, 1, 0]
@test sizes.size_offset == [0, 6, 5, 0, 4, 0, 2, 1, 0, 0, 0]
@test sizes.size == [1, 1, 1, 2, 1, 1, 1, 2]
@test sizes.storage_offset == [0, 1, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13]
x1 = 1.0
x2 = 2.0
x3 = 3.0
x4 = 4.0
println(MOI.eval_objective(evaluator, [x1, x2, x3, x4]))
@test MOI.eval_objective(evaluator, [x1, x2, x3, x4]) == 14.0
g = ones(4)
MOI.eval_objective_gradient(evaluator, g, [x1, x2, x3, x4])
@test g == [2.0, 1.0, 4.0, 3.0]
return
end

function test_objective_norm_univariate()
model = Nonlinear.Model()
x = MOI.VariableIndex(1)
Expand Down Expand Up @@ -110,4 +165,4 @@ end

end # module

TestArrayDiff.runtests()
TestArrayDiff.runtests()
Loading