-
Notifications
You must be signed in to change notification settings - Fork 0
Merge hcat into norm #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
fe207a5
b59f8e6
0e90752
7842545
734758f
7576b51
5fe5264
bc3c924
b4770da
8e32bf0
8983f4d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -247,6 +247,28 @@ function _forward_eval( | |
| tmp_dot += v1 * v2 | ||
| end | ||
| @s f.forward_storage[k] = tmp_dot | ||
| elseif node.index == 12 # hcat | ||
| idx1, idx2 = children_indices | ||
| ix1 = children_arr[idx1] | ||
| ix2 = children_arr[idx2] | ||
| nb_cols1 = f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2) | ||
| col_size = f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1) | ||
| for j in _eachindex(f.sizes, ix1) | ||
| @j f.partials_storage[ix1] = one(T) | ||
| val = @j f.forward_storage[ix1] | ||
| @j f.forward_storage[k] = val | ||
| end | ||
| for j in _eachindex(f.sizes, ix2) | ||
| @j f.partials_storage[ix2] = one(T) | ||
| val = @j f.forward_storage[ix2] | ||
| _setindex!( | ||
| f.forward_storage, | ||
| val, | ||
| f.sizes, | ||
| k, | ||
| j + nb_cols1 * col_size, | ||
| ) | ||
| end | ||
| elseif node.index == 14 # norm | ||
| ix = children_arr[children_indices[1]] | ||
| tmp_norm_squared = zero(T) | ||
|
|
@@ -339,6 +361,18 @@ function _forward_eval( | |
| f.partials_storage[rhs] = zero(T) | ||
| end | ||
| end | ||
| # This function is written assuming that the final output is scalar. | ||
| # Therefore it cannot return the matrix, so I guess I return its first entry only, | ||
| # as long as sum or matrix-vector products are not implemented. | ||
|
|
||
| #println("Last node ", f.nodes[1].index) | ||
| #if f.nodes[1].index == 12 | ||
|
||
| # mtx = reshape( | ||
| # f.forward_storage[_storage_range(f.sizes, 1)], | ||
| # f.sizes.size[1:f.sizes.ndims[1]]..., | ||
| # ) | ||
| # return mtx | ||
| #end | ||
| return f.forward_storage[1] | ||
| end | ||
|
|
||
|
|
@@ -395,6 +429,50 @@ function _reverse_eval(f::_SubexpressionStorage) | |
| end | ||
| end | ||
| continue | ||
| elseif op == :hcat | ||
| idx1, idx2 = children_indices | ||
| ix1 = children_arr[idx1] | ||
| ix2 = children_arr[idx2] | ||
| nb_cols1 = | ||
| f.sizes.ndims[ix1] <= 1 ? 1 : _size(f.sizes, ix1, 2) | ||
| col_size = | ||
| f.sizes.ndims[ix1] == 0 ? 1 : _size(f.sizes, k, 1) | ||
| for j in _eachindex(f.sizes, ix1) | ||
| partial = @j f.partials_storage[ix1] | ||
| val = ifelse( | ||
| _getindex(f.reverse_storage, f.sizes, k, j) == | ||
| 0.0 && !isfinite(partial), | ||
| _getindex(f.reverse_storage, f.sizes, k, j), | ||
| _getindex(f.reverse_storage, f.sizes, k, j) * | ||
| partial, | ||
| ) | ||
| @j f.reverse_storage[ix1] = val | ||
| end | ||
| for j in _eachindex(f.sizes, ix2) | ||
| partial = @j f.partials_storage[ix2] | ||
| val = ifelse( | ||
| _getindex( | ||
| f.reverse_storage, | ||
| f.sizes, | ||
| k, | ||
| j + nb_cols1 * col_size, | ||
| ) == 0.0 && !isfinite(partial), | ||
| _getindex( | ||
| f.reverse_storage, | ||
| f.sizes, | ||
| k, | ||
| j + nb_cols1 * col_size, | ||
| ), | ||
| _getindex( | ||
| f.reverse_storage, | ||
| f.sizes, | ||
| k, | ||
| j + nb_cols1 * col_size, | ||
| ) * partial, | ||
| ) | ||
| @j f.reverse_storage[ix2] = val | ||
| end | ||
| continue | ||
| elseif op == :norm | ||
| # Node `k` is scalar, the jacobian w.r.t. the vectorized input | ||
| # child is a row vector whose entries are stored in `f.partials_storage` | ||
|
|
@@ -408,7 +486,7 @@ function _reverse_eval(f::_SubexpressionStorage) | |
| rev_parent, | ||
| rev_parent * partial, | ||
| ) | ||
| @j f.reverse_storage[ix] = val | ||
| @j f.reverse_storage[ix] = val | ||
| end | ||
| continue | ||
| end | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point; we should throw an error in this case, but let's do a separate PR for that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#15
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. In principle, though, the functions we define take scalar values by definition, just as in the current version in ReverseAD.