Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.13.46"
version = "0.13.47"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down Expand Up @@ -38,7 +38,7 @@ EnzymeStaticArraysExt = "StaticArrays"
BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.9"
EnzymeCore = "0.8.10"
Enzyme_jll = "0.0.180"
GPUArraysCore = "0.1.6, 0.2"
GPUCompiler = "1.3"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.8.9"
version = "0.8.10"

[compat]
Adapt = "3, 4"
Expand Down
6 changes: 6 additions & 0 deletions lib/EnzymeCore/ext/AdaptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ end
function Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed)
return BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval))
end
function Adapt.adapt_structure(to, x::StackedBatchDuplicated)
return StackedBatchDuplicated(adapt(to, x.val), adapt(to, x.dval))
end
function Adapt.adapt_structure(to, x::StackedBatchDuplicatedNoNeed)
return StackedBatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval))
end
function Adapt.adapt_structure(to, x::MixedDuplicated)
return MixedDuplicated(adapt(to, x.val), adapt(to, x.dval))
end
Expand Down
51 changes: 49 additions & 2 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal
export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal
export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation
export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, StackedBatchDuplicated, StackedBatchDuplicatedNoNeed, Annotation
export MixedDuplicated, BatchMixedDuplicated
export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
Expand Down Expand Up @@ -145,6 +145,28 @@
end
end

"""
StackedBatchDuplicated(x::AbstractArray, ∂f_∂xs::AbstractArray)

Like [`BatchDuplicated`](@ref), except the shadows are stacked into a N + 1 dimensional array (last dimension is the batch dimension).

!!! warning

Currently this is mostly supported in Reactant.jl, but extensively not in Enzyme.jl.
"""
struct StackedBatchDuplicated{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T}
val::T
dval::T2

@inline function StackedBatchDuplicated(x::AbstractArray{T,N}, dx::AbstractArray{T,M}, check::Bool=true) where {T,M,N}
if check
@assert size(x) == size(dx)[1:end-1]
@assert N + 1 == M

Check warning on line 164 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L161-L164

Added lines #L161 - L164 were not covered by tests
end
return new{typeof(x),typeof(dx)}(x, dx)

Check warning on line 166 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L166

Added line #L166 was not covered by tests
end
end

struct BatchDuplicatedFunc{T,N,Func} <: Annotation{T}
val::T
end
Expand Down Expand Up @@ -172,14 +194,39 @@
new{T1, N}(x, dx)
end
end

"""
StackedBatchDuplicatedNoNeed(x::AbstractArray, ∂f_∂xs::AbstractArray)

Like [`BatchDuplicatedNoNeed`](@ref), except the shadows are stacked into a N + 1
dimensional array (last dimension is the batch dimension).

!!! warning

Currently this is mostly supported in Reactant.jl, but extensively not in Enzyme.jl.
"""
struct StackedBatchDuplicatedNoNeed{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T}
val::T
dval::T2

@inline function StackedBatchDuplicatedNoNeed(x::AbstractArray{T,N}, dx::AbstractArray{T,M}, check::Bool=true) where {T,M,N}
if check
@assert size(x) == size(dx)[1:end-1]
@assert N + 1 == M

Check warning on line 215 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L212-L215

Added lines #L212 - L215 were not covered by tests
end
return new{typeof(x),typeof(dx)}(x, dx)

Check warning on line 217 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L217

Added line #L217 was not covered by tests
end
end

@inline batch_size(::BatchDuplicated{T,N}) where {T,N} = N
@inline batch_size(d::StackedBatchDuplicated) = size(d.dval, ndims(d.dval))

Check warning on line 222 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L222

Added line #L222 was not covered by tests
@inline batch_size(::BatchDuplicatedFunc{T,N}) where {T,N} = N
@inline batch_size(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N
@inline batch_size(d::StackedBatchDuplicatedNoNeed) = size(d.dval, ndims(d.dval))

Check warning on line 225 in lib/EnzymeCore/src/EnzymeCore.jl

View check run for this annotation

Codecov / codecov/patch

lib/EnzymeCore/src/EnzymeCore.jl#L225

Added line #L225 was not covered by tests
@inline batch_size(::Type{BatchDuplicated{T,N}}) where {T,N} = N
@inline batch_size(::Type{BatchDuplicatedFunc{T,N}}) where {T,N} = N
@inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N


"""
MixedDuplicated(x, ∂f_∂x)

Expand Down
Loading