Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion src/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,28 @@ function _forward_eval(
tmp_dot += v1 * v2
end
@s f.forward_storage[k] = tmp_dot
elseif node.index == 12 # hcat
idx1, idx2 = children_indices
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
nb_cols1 = f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
col_size = f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
for j in _eachindex(f.sizes, ix1)
@j f.partials_storage[ix1] = one(T)
val = @j f.forward_storage[ix1]
@j f.forward_storage[k] = val
end
for j in _eachindex(f.sizes, ix2)
@j f.partials_storage[ix2] = one(T)
val = @j f.forward_storage[ix2]
_setindex!(
f.forward_storage,
val,
f.sizes,
k,
j + nb_cols1 * col_size,
)
end
elseif node.index == 14 # norm
ix = children_arr[children_indices[1]]
tmp_norm_squared = zero(T)
Expand Down Expand Up @@ -339,6 +361,18 @@ function _forward_eval(
f.partials_storage[rhs] = zero(T)
end
end
# This function is written assuming that the final output is scalar.
# Therefore it cannot return the matrix, so for now only its first entry is returned,
# at least until sum or matrix-vector products are implemented.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, we should throw an error in this case, but let's do a separate PR for that

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#15

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. In principle though the functions we define take scalar values by definition, just as in the current version in ReverseAD.


#println("Last node ", f.nodes[1].index)
#if f.nodes[1].index == 12
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this if doing ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry it should have been deleted. It will be soon, thanks.
(I was checking that hcat vectors of length more than one was fine too. I think it was ok but was constructing a matrix and I had no further operation implemented to deal with it. I'll check this case after allowing norm of a matrix.)

# mtx = reshape(
# f.forward_storage[_storage_range(f.sizes, 1)],
# f.sizes.size[1:f.sizes.ndims[1]]...,
# )
# return mtx
#end
return f.forward_storage[1]
end

Expand Down Expand Up @@ -395,6 +429,50 @@ function _reverse_eval(f::_SubexpressionStorage)
end
end
continue
elseif op == :hcat
idx1, idx2 = children_indices
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
nb_cols1 =
f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2)
col_size =
f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1)
for j in _eachindex(f.sizes, ix1)
partial = @j f.partials_storage[ix1]
val = ifelse(
_getindex(f.reverse_storage, f.sizes, k, j) ==
0.0 && !isfinite(partial),
_getindex(f.reverse_storage, f.sizes, k, j),
_getindex(f.reverse_storage, f.sizes, k, j) *
partial,
)
@j f.reverse_storage[ix1] = val
end
for j in _eachindex(f.sizes, ix2)
partial = @j f.partials_storage[ix2]
val = ifelse(
_getindex(
f.reverse_storage,
f.sizes,
k,
j + nb_cols1 * col_size,
) == 0.0 && !isfinite(partial),
_getindex(
f.reverse_storage,
f.sizes,
k,
j + nb_cols1 * col_size,
),
_getindex(
f.reverse_storage,
f.sizes,
k,
j + nb_cols1 * col_size,
) * partial,
)
@j f.reverse_storage[ix2] = val
end
continue
elseif op == :norm
# Node `k` is scalar, the jacobian w.r.t. the vectorized input
# child is a row vector whose entries are stored in `f.partials_storage`
Expand All @@ -408,7 +486,7 @@ function _reverse_eval(f::_SubexpressionStorage)
rev_parent,
rev_parent * partial,
)
@j f.reverse_storage[ix] = val
@j f.reverse_storage[ix] = val
end
continue
end
Expand Down
19 changes: 19 additions & 0 deletions src/sizes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,25 @@ function _infer_sizes(
elseif op == :+ || op == :-
# TODO assert all arguments have same size
_copy_size!(sizes, k, children_arr[first(children_indices)])
elseif op == :hcat
total_cols = 0
for c_idx in children_indices
total_cols +=
sizes.ndims[children_arr[c_idx]] <= 1 ? 1 :
_size(sizes, children_arr[c_idx], 2)
end
if sizes.ndims[children_arr[first(children_indices)]] == 0
shape = (1, total_cols)
else
@assert sizes.ndims[children_arr[first(
children_indices,
)]] <= 2 "Hcat with ndims > 2 is not supported yet"
shape = (
_size(sizes, children_arr[first(children_indices)], 1),
total_cols,
)
end
_add_size!(sizes, k, tuple(shape...))
elseif op == :*
# TODO assert compatible sizes and all ndims should be 0 or 2
first_matrix = findfirst(children_indices) do i
Expand Down
57 changes: 56 additions & 1 deletion test/ArrayDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,61 @@ function test_objective_dot_bivariate()
return
end

function test_objective_hcat_0dim()
    # Objective: dot([x1 x3], [x2 x4]) — the `[a b]` literal hcats four
    # scalar (0-dim) children into 1×2 row vectors, then takes their dot.
    model = Nonlinear.Model()
    x1 = MOI.VariableIndex(1)
    x2 = MOI.VariableIndex(2)
    x3 = MOI.VariableIndex(3)
    x4 = MOI.VariableIndex(4)
    Nonlinear.set_objective(model, :(dot([$x1 $x3], [$x2 $x4])))
    evaluator = Nonlinear.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4])
    MOI.initialize(evaluator, [:Grad])
    # Check the size metadata inferred for the expression tree: the two hcat
    # nodes are 2-dimensional (1×2), every other node is scalar.
    sizes = evaluator.backend.objective.expr.sizes
    @test sizes.ndims == [0, 2, 0, 0, 2, 0, 0]
    @test sizes.size_offset == [0, 2, 0, 0, 0, 0, 0]
    @test sizes.size == [1, 2, 1, 2]
    @test sizes.storage_offset == [0, 1, 3, 4, 5, 7, 8, 9]
    # dot([1 3], [2 4]) = 1*2 + 3*4 = 14.
    # NOTE: use a separate value vector instead of rebinding x1…x4 (which are
    # VariableIndex objects above) to Float64 — avoids confusing shadowing.
    v = [1.0, 2.0, 3.0, 4.0]
    @test MOI.eval_objective(evaluator, v) == 14.0
    # Gradient of dot is the "other" operand, interleaved per variable order.
    g = ones(4)
    MOI.eval_objective_gradient(evaluator, g, v)
    @test g == [2.0, 1.0, 4.0, 3.0]
    return
end

function test_objective_hcat_1dim()
    # Objective: dot(hcat([x1], [x3]), hcat([x2], [x4])) — hcat of 1-dim
    # (length-1 vector) children into 1×2 matrices, then their dot product.
    # Exercises the `hcat` branch with vector (ndims == 1) arguments, in
    # contrast to test_objective_hcat_0dim's scalar arguments.
    model = Nonlinear.Model()
    x1 = MOI.VariableIndex(1)
    x2 = MOI.VariableIndex(2)
    x3 = MOI.VariableIndex(3)
    x4 = MOI.VariableIndex(4)
    Nonlinear.set_objective(
        model,
        :(dot(hcat([$x1], [$x3]), hcat([$x2], [$x4]))),
    )
    evaluator = Nonlinear.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4])
    MOI.initialize(evaluator, [:Grad])
    # Check the size metadata: the vect nodes are 1-dim, the hcat results 2-dim.
    sizes = evaluator.backend.objective.expr.sizes
    @test sizes.ndims == [0, 2, 1, 0, 1, 0, 2, 1, 0, 1, 0]
    @test sizes.size_offset == [0, 6, 5, 0, 4, 0, 2, 1, 0, 0, 0]
    @test sizes.size == [1, 1, 1, 2, 1, 1, 1, 2]
    @test sizes.storage_offset == [0, 1, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13]
    # dot([1 3], [2 4]) = 1*2 + 3*4 = 14.
    # NOTE: use a separate value vector instead of rebinding x1…x4 (which are
    # VariableIndex objects above) to Float64 — avoids confusing shadowing.
    v = [1.0, 2.0, 3.0, 4.0]
    @test MOI.eval_objective(evaluator, v) == 14.0
    # Gradient of dot is the "other" operand, interleaved per variable order.
    g = ones(4)
    MOI.eval_objective_gradient(evaluator, g, v)
    @test g == [2.0, 1.0, 4.0, 3.0]
    return
end

function test_objective_norm_univariate()
model = Nonlinear.Model()
x = MOI.VariableIndex(1)
Expand Down Expand Up @@ -110,4 +165,4 @@ end

end # module

TestArrayDiff.runtests()
TestArrayDiff.runtests()
Loading