Skip to content

Commit 12ca833

Browse files
Activity tests (#1316)
* Add copyto override from anytraced rarray * assert size * more info * fix * Activity tests * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml * fix * bump * Update test/autodiff.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix * fix --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent f300c2a commit 12ca833

File tree

5 files changed

+111
-3
lines changed

5 files changed

+111
-3
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.113"
4+
version = "0.2.114"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -67,8 +67,8 @@ CEnum = "0.5"
6767
CUDA = "5.6"
6868
Downloads = "1.6"
6969
EnumX = "1"
70-
Enzyme = "0.13.35"
71-
EnzymeCore = "0.8.8"
70+
Enzyme = "0.13.46"
71+
EnzymeCore = "0.8.9"
7272
Functors = "0.5"
7373
GPUArraysCore = "0.2"
7474
GPUCompiler = "1.3"

src/Enzyme.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,58 @@ function Enzyme.onehot(x::TracedRArray{T,N}) where {T,N}
1515
end
1616
return Tuple(results)
1717
end
18+
19+
function Enzyme.EnzymeRules.inactive_noinl(::typeof(XLA.buffer_on_cpu), args...)
20+
return nothing
21+
end
22+
23+
function Enzyme.EnzymeRules.inactive_noinl(::typeof(XLA.addressable_devices), args...)
24+
return nothing
25+
end
26+
27+
function Enzyme.EnzymeRules.noalias(::typeof(Base.similar), a::ConcretePJRTArray, ::Type, args...)
28+
return nothing
29+
end
30+
31+
function Enzyme.EnzymeRules.noalias(::typeof(Base.similar), a::ConcreteIFRTArray, ::Type, args...)
32+
return nothing
33+
end
34+
35+
function Enzyme.EnzymeRules.augmented_primal(config, ofn::Const{typeof(Base.similar)}, ::Type{RT}, uval::Enzyme.Annotation{<:ConcretePJRTArray}, T::Enzyme.Const{<:Type}, args...) where {RT}
36+
primargs = ntuple(Val(length(args))) do i
37+
Base.@_inline_meta
38+
args[i].val
39+
end
40+
41+
primal = if EnzymeRules.needs_primal(config)
42+
ofn.val(uval.val, T.val, primargs...)
43+
else
44+
nothing
45+
end
46+
47+
shadow = if EnzymeRules.needs_shadow(config)
48+
if EnzymeRules.width(config) == 1
49+
ConcretePJRTArray(
50+
zeros(T.val, primargs...); client=XLA.client(uval.val), device=XLA.device(uval.val), uval.val.sharding
51+
)
52+
else
53+
ntuple(Val(EnzymeRules.width(config))) do i
54+
Base.@_inline_meta
55+
ConcretePJRTArray(
56+
zeros(T.val, primargs...); client=XLA.client(uval.val), device=XLA.device(uval.val), uval.val.sharding
57+
)
58+
end
59+
end
60+
else
61+
nothing
62+
end
63+
64+
return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing)
65+
end
66+
67+
function Enzyme.EnzymeRules.reverse(config, ofn::Const{typeof(Base.similar)}, ::Type{RT}, tape, uval::Enzyme.Annotation{<:ConcretePJRTArray}, T::Enzyme.Const{<:Type}, args::Vararg{Enzyme.Annotation, N}) where {RT, N}
68+
ntuple(Val(N+2)) do i
69+
Base.@_inline_meta
70+
nothing
71+
end
72+
end

src/TracedRArray.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,11 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T,
727727
dest.mlir_data = src.mlir_data
728728
return dest
729729
end
730+
731+
function Base.copyto!(dest::TracedRArray, src::AnyTracedRArray)
732+
return copyto!(dest, materialize_traced_array(src))
733+
end
734+
730735
function Base.copyto!(
731736
dest::Reactant.TracedRArray{T},
732737
dstart::Integer,

src/TracedUtils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ function ReactantCore.materialize_traced_array(x::SubArray)
4242
end
4343

4444
function ReactantCore.materialize_traced_array(x::Base.ReshapedArray)
45+
if Base.prod(size(parent(x))) != Base.prod(size(x))
46+
throw(
47+
AssertionError(
48+
"Invalid reshape array, original size $(size(parent(x))) not compatible with new size $(size(x))",
49+
),
50+
)
51+
end
4552
return Ops.reshape(materialize_traced_array(parent(x)), size(x)...)
4653
end
4754

test/autodiff.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,47 @@ square(x) = x * 2
44

55
fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
66

7+
@testset "Activity" begin
8+
@test Enzyme.guess_activity(
9+
Reactant.ConcretePJRTArray{
10+
Float32,2,1,Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding,Nothing}
11+
},
12+
Enzyme.Reverse,
13+
) <: Enzyme.Duplicated
14+
15+
@test Enzyme.guess_activity(Reactant.ConcretePJRTArray{Float32}, Enzyme.Reverse) <:
16+
Enzyme.Duplicated
17+
18+
@test Enzyme.guess_activity(Reactant.ConcreteIFRTArray{Float32, 2, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Enzyme.Reverse) <: Enzyme.Duplicated
19+
20+
@test Enzyme.guess_activity(
21+
Reactant.ConcretePJRTNumber{
22+
Float32,1,Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding,Nothing}
23+
},
24+
Enzyme.Reverse,
25+
) <: Enzyme.Duplicated
26+
27+
@test Enzyme.guess_activity(Reactant.ConcretePJRTNumber{Float32}, Enzyme.Reverse) <:
28+
Enzyme.Duplicated
29+
30+
@test Enzyme.guess_activity(Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Enzyme.Reverse) <: Enzyme.Duplicated
31+
32+
@test Enzyme.guess_activity(Reactant.ConcretePJRTNumber{Float32}, Enzyme.Reverse) <: Enzyme.Duplicated
33+
34+
@test Enzyme.guess_activity(Reactant.ConcreteIFRTNumber{Float32, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Enzyme.Reverse) <: Enzyme.Duplicated
35+
36+
@test Enzyme.guess_activity(Reactant.ConcreteIFRTNumber{Float32}, Enzyme.Reverse) <: Enzyme.Duplicated
37+
38+
39+
@test Enzyme.guess_activity(Reactant.TracedRArray{Float32, 2}, Enzyme.Reverse) <: Enzyme.Duplicated
40+
41+
@test Enzyme.guess_activity(Reactant.TracedRArray{Float32}, Enzyme.Reverse) <: Enzyme.Duplicated
42+
43+
44+
@test Enzyme.guess_activity(Reactant.TracedRNumber{Float32}, Enzyme.Reverse) <:
45+
Enzyme.Duplicated
46+
end
47+
748
@testset "Basic Forward Mode" begin
849
ores1 = fwd(Forward, Duplicated, ones(3, 2), 3.1 * ones(3, 2))
950
@test typeof(ores1) == NamedTuple{(Symbol("1"),),Tuple{Array{Float64,2}}}

0 commit comments

Comments
 (0)