-
Notifications
You must be signed in to change notification settings - Fork 79
feat: stacked batchduplicated #2418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2418 +/- ##
==========================================
+ Coverage 67.50% 75.22% +7.72%
==========================================
Files 31 56 +25
Lines 12668 16939 +4271
==========================================
+ Hits 8552 12743 +4191
- Misses 4116 4196 +80 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
| uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" | ||
| authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"] | ||
| version = "0.8.9" | ||
| version = "0.8.10" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so one caveat here, can you grep through the code for places with BatchDuplicated and check that we add the requisite condition for StackedBatchDuplicated.
Even more ideally we would have some interface function is_batched or something that can be used by either
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated
@inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed
what are these for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Active{Float64} -> Active
| !!! warning | ||
| Currently this is mostly supported in Reactant.jl, but extensively not in Enzyme.jl. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it isn't supported by Enzyme, should it then be defined here? I.e. is there a benefit for this to be in EnzymeCore and could it just live in Reactant?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean honestly this would be nice to support from Enzyme proper as well from a UX perspecive (and could make things like jacobian easier since you give it the full matrix)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we'd probably need to do a similar shim layer like we do for mixedduplicated creating subpieces from the bigger one/etc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On Julia 1.11 this is possible due to the Memory Change.
function tuple_of_vectors(M::Matrix{T}, shape) where {T}
n, m = size(M)
return ntuple(m) do i
vec = Base.wrap(Array, memoryref(M.ref, (i - 1) * n + 1), (n,))
reshape(vec, shape)
end
end
function mul!(Out::AbstractMatrix, J::JacobianOperator, V::AbstractMatrix)
@assert size(Out, 2) == size(V, 2)
out = tuple_of_vectors(Out, size(J.res))
v = tuple_of_vectors(V, size(J.u))
N = length(out)
autodiff(
Forward,
maybe_duplicated(J.f, Val(N)), Const,
BatchDuplicated(J.res, out),
BatchDuplicated(J.u, v),
maybe_duplicated(J.p, Val(N))
)
return nothing
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah that would be nice
This is mostly for use in reactant