Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.8.9"
version = "0.8.10"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so one caveat here, can you grep through the code for places with BatchDuplicated and check that we add the requisite condition for StackedBatchDuplicated.

Even more ideally we would have some interface function is_batched or something that can be used by either

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


@inline remove_innerty(::Type{<:BatchDuplicated}) = Duplicated
@inline remove_innerty(::Type{<:BatchDuplicatedNoNeed}) = DuplicatedNoNeed

what are these for?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Active{Float64} -> Active


[compat]
Adapt = "3, 4"
Expand Down
43 changes: 41 additions & 2 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal
export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal
export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation
export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, StackedBatchDuplicated, StackedBatchDuplicatedNoNeed, Annotation
export MixedDuplicated, BatchMixedDuplicated
export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
Expand Down Expand Up @@ -145,6 +145,24 @@
end
end

"""
StackedBatchDuplicated(x::AbstractArray, ∂f_∂xs::AbstractArray)

Like [`BatchDuplicated`](@ref), except the shadows are stacked into a N + 1 dimensional array (last dimension is the batch dimension).
"""
struct StackedBatchDuplicated{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T}
val::T
dval::T2

@inline function StackedBatchDuplicated(x::AbstractArray{T,N}, dx::AbstractArray{T,M}, check::Bool=true) where {T,M,N}
if check
@assert size(x) == size(dx)[1:end-1]
@assert N + 1 == M

Check warning on line 160 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L157-L160

Added lines #L157 - L160 were not covered by tests
end
return new{typeof(x),typeof(dx)}(x, dx)

Check warning on line 162 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L162

Added line #L162 was not covered by tests
end
end

struct BatchDuplicatedFunc{T,N,Func} <: Annotation{T}
val::T
end
Expand Down Expand Up @@ -172,14 +190,35 @@
new{T1, N}(x, dx)
end
end

"""
StackedBatchDuplicatedNoNeed(x::AbstractArray, ∂f_∂xs::AbstractArray)

Like [`BatchDuplicatedNoNeed`](@ref), except the shadows are stacked into a N + 1
dimensional array (last dimension is the batch dimension).
"""
struct StackedBatchDuplicatedNoNeed{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T}
val::T
dval::T2

@inline function StackedBatchDuplicatedNoNeed(x::AbstractArray{T,N}, dx::AbstractArray{T,M}, check::Bool=true) where {T,M,N}
if check
@assert size(x) == size(dx)[1:end-1]
@assert N + 1 == M

Check warning on line 207 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L204-L207

Added lines #L204 - L207 were not covered by tests
end
return new{typeof(x),typeof(dx)}(x, dx)

Check warning on line 209 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L209

Added line #L209 was not covered by tests
end
end

@inline batch_size(::BatchDuplicated{T,N}) where {T,N} = N
@inline batch_size(d::StackedBatchDuplicated) = size(d.dval, ndims(d.dval))

Check warning on line 214 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L214

Added line #L214 was not covered by tests
@inline batch_size(::BatchDuplicatedFunc{T,N}) where {T,N} = N
@inline batch_size(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N
@inline batch_size(d::StackedBatchDuplicatedNoNeed) = size(d.dval, ndims(d.dval))

Check warning on line 217 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L217

Added line #L217 was not covered by tests
@inline batch_size(::Type{BatchDuplicated{T,N}}) where {T,N} = N
@inline batch_size(::Type{BatchDuplicatedFunc{T,N}}) where {T,N} = N
@inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N


"""
MixedDuplicated(x, ∂f_∂x)

Expand Down
Loading