diff --git a/Project.toml b/Project.toml index cf95ef5e1c..76bc3b9410 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.46" +version = "0.13.47" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -38,7 +38,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.9" +EnzymeCore = "0.8.10" Enzyme_jll = "0.0.180" GPUArraysCore = "0.1.6, 0.2" GPUCompiler = "1.3" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index f1521a0a61..3504932559 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.9" +version = "0.8.10" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/ext/AdaptExt.jl b/lib/EnzymeCore/ext/AdaptExt.jl index 4d62a20675..252d02fde0 100644 --- a/lib/EnzymeCore/ext/AdaptExt.jl +++ b/lib/EnzymeCore/ext/AdaptExt.jl @@ -15,6 +15,12 @@ end function Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed) return BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) end +function Adapt.adapt_structure(to, x::StackedBatchDuplicated) + return StackedBatchDuplicated(adapt(to, x.val), adapt(to, x.dval)) +end +function Adapt.adapt_structure(to, x::StackedBatchDuplicatedNoNeed) + return StackedBatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) +end function Adapt.adapt_structure(to, x::MixedDuplicated) return MixedDuplicated(adapt(to, x.val), adapt(to, x.dval)) end diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index efbbcb42e1..cfeffb5fa6 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -2,7 +2,7 @@ module EnzymeCore export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, 
ReverseSplitNoPrimal, ReverseSplitWithPrimal export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal -export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation +export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, StackedBatchDuplicated, StackedBatchDuplicatedNoNeed, Annotation export MixedDuplicated, BatchMixedDuplicated export DefaultABI, FFIABI, InlineABI, NonGenABI export BatchDuplicatedFunc @@ -145,6 +145,28 @@ struct BatchDuplicated{T,N} <: Annotation{T} end end +""" + StackedBatchDuplicated(x::AbstractArray, ∂f_∂xs::AbstractArray) + +Like [`BatchDuplicated`](@ref), except the shadows are stacked into an N + 1 dimensional array (last dimension is the batch dimension). + +!!! warning + + Currently this is mostly supported in Reactant.jl, but not extensively in Enzyme.jl. +""" +struct StackedBatchDuplicated{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T} + val::T + dval::T2 + + @inline function StackedBatchDuplicated(x::AbstractArray{T,N}, dx::AbstractArray{T,M}, check::Bool=true) where {T,M,N} + if check + @assert size(x) == size(dx)[1:end-1] + @assert N + 1 == M + end + return new{typeof(x),typeof(dx)}(x, dx) + end +end + struct BatchDuplicatedFunc{T,N,Func} <: Annotation{T} val::T end @@ -172,14 +194,39 @@ struct BatchDuplicatedNoNeed{T,N} <: Annotation{T} new{T1, N}(x, dx) end end + +""" + StackedBatchDuplicatedNoNeed(x::AbstractArray, ∂f_∂xs::AbstractArray) + +Like [`BatchDuplicatedNoNeed`](@ref), except the shadows are stacked into an N + 1 +dimensional array (last dimension is the batch dimension). + +!!! warning + + Currently this is mostly supported in Reactant.jl, but not extensively in Enzyme.jl. 
+""" +struct StackedBatchDuplicatedNoNeed{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T} + val::T + dval::T2 + + @inline function StackedBatchDuplicatedNoNeed(x::AbstractArray{T,N}, dx::AbstractArray{T,M}, check::Bool=true) where {T,M,N} + if check + @assert size(x) == size(dx)[1:end-1] + @assert N + 1 == M + end + return new{typeof(x),typeof(dx)}(x, dx) + end +end + @inline batch_size(::BatchDuplicated{T,N}) where {T,N} = N +@inline batch_size(d::StackedBatchDuplicated) = size(d.dval, ndims(d.dval)) @inline batch_size(::BatchDuplicatedFunc{T,N}) where {T,N} = N @inline batch_size(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N +@inline batch_size(d::StackedBatchDuplicatedNoNeed) = size(d.dval, ndims(d.dval)) @inline batch_size(::Type{BatchDuplicated{T,N}}) where {T,N} = N @inline batch_size(::Type{BatchDuplicatedFunc{T,N}}) where {T,N} = N @inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N - """ MixedDuplicated(x, ∂f_∂x)