From 5cfa9a40912f3697a56ab354c6a06ec2d96437b4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 May 2025 15:19:35 -0500 Subject: [PATCH 1/3] feat: stacked batchduplicated --- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 42 ++++++++++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index f1521a0a61..3504932559 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.9" +version = "0.8.10" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index efbbcb42e1..97376e131c 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -2,7 +2,7 @@ module EnzymeCore export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal -export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation +export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, StackedBatchDuplicated, StackedBatchDuplicatedNoNeed, Annotation export MixedDuplicated, BatchMixedDuplicated export DefaultABI, FFIABI, InlineABI, NonGenABI export BatchDuplicatedFunc @@ -145,6 +145,24 @@ struct BatchDuplicated{T,N} <: Annotation{T} end end +""" + StackedBatchDuplicated(x::AbstractArray, ∂f_∂xs::AbstractArray) + +Like [`BatchDuplicated`](@ref), except the shadows are stacked into a N + 1 dimensional array. +""" +struct StackedBatchDuplicated{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T} + val::T + dval::T2 + + @inline function StackedBatchDuplicated(x::AbstractArray{T,N}, dx::AbstractArray{T,M}, check::Bool=true) where {T,M,N} + if check + @assert size(x) == size(dx)[2:end] + @assert N + 1 == M + end + return new{typeof(x),typeof(dx)}(x, dx) + end +end + struct BatchDuplicatedFunc{T,N,Func} <: Annotation{T} val::T end @@ -172,14 +190,34 @@ struct BatchDuplicatedNoNeed{T,N} <: Annotation{T} new{T1, N}(x, dx) end end + +""" + StackedBatchDuplicatedNoNeed(x::AbstractArray, ∂f_∂xs::AbstractArray) + +Like [`BatchDuplicatedNoNeed`](@ref), except the shadows are stacked into a N + 1 +dimensional array. +""" +struct StackedBatchDuplicatedNoNeed{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T} + val::T + dval::T2 + + @inline function StackedBatchDuplicatedNoNeed(x::AbstractArray{T,N}, dx::AbstractArray{T,M}, check::Bool=true) where {T,M,N} + if check + @assert size(x) == size(dx)[2:end] + @assert N + 1 == M + end + return new{typeof(x),typeof(dx)}(x, dx) + end +end + @inline batch_size(::BatchDuplicated{T,N}) where {T,N} = N +@inline batch_size(d::StackedBatchDuplicated) = size(d.dval, 1) @inline batch_size(::BatchDuplicatedFunc{T,N}) where {T,N} = N @inline batch_size(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N @inline batch_size(::Type{BatchDuplicated{T,N}}) where {T,N} = N @inline batch_size(::Type{BatchDuplicatedFunc{T,N}}) where {T,N} = N @inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N - """ MixedDuplicated(x, ∂f_∂x) From 435b1bdcedb7bc233c73269d927e6aae0e8930d6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 May 2025 15:21:04 -0500 Subject: [PATCH 2/3] fix: batch dim ordering --- lib/EnzymeCore/src/EnzymeCore.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 97376e131c..75eb1c9a29 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -148,7 +148,7 @@ end """ StackedBatchDuplicated(x::AbstractArray, ∂f_∂xs::AbstractArray) -Like [`BatchDuplicated`](@ref), except the shadows are stacked into a N + 1 dimensional array. +Like [`BatchDuplicated`](@ref), except the shadows are stacked into a N + 1 dimensional array (last dimension is the batch dimension). """ struct StackedBatchDuplicated{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T} val::T @@ -156,7 +156,7 @@ struct StackedBatchDuplicated{T<:AbstractArray,T2<:AbstractArray} <: Annotation{ @inline function StackedBatchDuplicated(x::AbstractArray{T,N}, dx::AbstractArray{T,M}, check::Bool=true) where {T,M,N} if check - @assert size(x) == size(dx)[2:end] + @assert size(x) == size(dx)[1:end-1] @assert N + 1 == M end return new{typeof(x),typeof(dx)}(x, dx) @@ -195,7 +195,7 @@ end StackedBatchDuplicatedNoNeed(x::AbstractArray, ∂f_∂xs::AbstractArray) Like [`BatchDuplicatedNoNeed`](@ref), except the shadows are stacked into a N + 1 -dimensional array. +dimensional array (last dimension is the batch dimension). """ struct StackedBatchDuplicatedNoNeed{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T} val::T @@ -203,7 +203,7 @@ struct StackedBatchDuplicatedNoNeed{T<:AbstractArray,T2<:AbstractArray} <: Annot @inline function StackedBatchDuplicatedNoNeed(x::AbstractArray{T,N}, dx::AbstractArray{T,M}, check::Bool=true) where {T,M,N} if check - @assert size(x) == size(dx)[2:end] + @assert size(x) == size(dx)[1:end-1] @assert N + 1 == M end return new{typeof(x),typeof(dx)}(x, dx) @@ -211,9 +211,10 @@ struct StackedBatchDuplicatedNoNeed{T<:AbstractArray,T2<:AbstractArray} <: Annot end @inline batch_size(::BatchDuplicated{T,N}) where {T,N} = N -@inline batch_size(d::StackedBatchDuplicated) = size(d.dval, 1) +@inline batch_size(d::StackedBatchDuplicated) = size(d.dval, ndims(d.dval)) @inline batch_size(::BatchDuplicatedFunc{T,N}) where {T,N} = N @inline batch_size(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N +@inline batch_size(d::StackedBatchDuplicatedNoNeed) = size(d.dval, ndims(d.dval)) @inline batch_size(::Type{BatchDuplicated{T,N}}) where {T,N} = N @inline batch_size(::Type{BatchDuplicatedFunc{T,N}}) where {T,N} = N @inline batch_size(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N From 51453ef7501f059457c9740223790fe5fe54baac Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 26 May 2025 15:38:20 -0500 Subject: [PATCH 3/3] feat: adapt + some warnings --- Project.toml | 4 ++-- lib/EnzymeCore/ext/AdaptExt.jl | 6 ++++++ lib/EnzymeCore/src/EnzymeCore.jl | 8 ++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index cf95ef5e1c..76bc3b9410 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.46" +version = "0.13.47" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -38,7 +38,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.9" +EnzymeCore = "0.8.10" Enzyme_jll = "0.0.180" GPUArraysCore = "0.1.6, 0.2" GPUCompiler = "1.3" diff --git a/lib/EnzymeCore/ext/AdaptExt.jl b/lib/EnzymeCore/ext/AdaptExt.jl index 4d62a20675..252d02fde0 100644 --- a/lib/EnzymeCore/ext/AdaptExt.jl +++ b/lib/EnzymeCore/ext/AdaptExt.jl @@ -15,6 +15,12 @@ end function Adapt.adapt_structure(to, x::BatchDuplicatedNoNeed) return BatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) end +function Adapt.adapt_structure(to, x::StackedBatchDuplicated) + return StackedBatchDuplicated(adapt(to, x.val), adapt(to, x.dval)) +end +function Adapt.adapt_structure(to, x::StackedBatchDuplicatedNoNeed) + return StackedBatchDuplicatedNoNeed(adapt(to, x.val), adapt(to, x.dval)) +end function Adapt.adapt_structure(to, x::MixedDuplicated) return MixedDuplicated(adapt(to, x.val), adapt(to, x.dval)) end diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 75eb1c9a29..cfeffb5fa6 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -149,6 +149,10 @@ end StackedBatchDuplicated(x::AbstractArray, ∂f_∂xs::AbstractArray) Like [`BatchDuplicated`](@ref), except the shadows are stacked into a N + 1 dimensional array (last dimension is the batch dimension). + +!!! warning + + Currently this is mostly supported in Reactant.jl, but extensively not in Enzyme.jl. """ struct StackedBatchDuplicated{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T} val::T @@ -196,6 +200,10 @@ end Like [`BatchDuplicatedNoNeed`](@ref), except the shadows are stacked into a N + 1 dimensional array (last dimension is the batch dimension). + +!!! warning + + Currently this is mostly supported in Reactant.jl, but extensively not in Enzyme.jl. """ struct StackedBatchDuplicatedNoNeed{T<:AbstractArray,T2<:AbstractArray} <: Annotation{T} val::T