Skip to content

Commit 592e74e

Browse files
Format code (#1324)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 632ef06 commit 592e74e

File tree

3 files changed

+74
-28
lines changed

3 files changed

+74
-28
lines changed

src/Enzyme.jl

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,26 @@ function Enzyme.EnzymeRules.inactive_noinl(::typeof(XLA.addressable_devices), ar
2424
return nothing
2525
end
2626

27-
function Enzyme.EnzymeRules.noalias(::typeof(Base.similar), a::ConcretePJRTArray, ::Type, args...)
27+
function Enzyme.EnzymeRules.noalias(
28+
::typeof(Base.similar), a::ConcretePJRTArray, ::Type, args...
29+
)
2830
return nothing
2931
end
3032

31-
function Enzyme.EnzymeRules.noalias(::typeof(Base.similar), a::ConcreteIFRTArray, ::Type, args...)
33+
function Enzyme.EnzymeRules.noalias(
34+
::typeof(Base.similar), a::ConcreteIFRTArray, ::Type, args...
35+
)
3236
return nothing
3337
end
3438

35-
function Enzyme.EnzymeRules.augmented_primal(config, ofn::Const{typeof(Base.similar)}, ::Type{RT}, uval::Enzyme.Annotation{<:ConcretePJRTArray}, T::Enzyme.Const{<:Type}, args...) where {RT}
39+
function Enzyme.EnzymeRules.augmented_primal(
40+
config,
41+
ofn::Const{typeof(Base.similar)},
42+
::Type{RT},
43+
uval::Enzyme.Annotation{<:ConcretePJRTArray},
44+
T::Enzyme.Const{<:Type},
45+
args...,
46+
) where {RT}
3647
primargs = ntuple(Val(length(args))) do i
3748
Base.@_inline_meta
3849
args[i].val
@@ -43,30 +54,48 @@ function Enzyme.EnzymeRules.augmented_primal(config, ofn::Const{typeof(Base.simi
4354
else
4455
nothing
4556
end
46-
57+
4758
shadow = if EnzymeRules.needs_shadow(config)
4859
if EnzymeRules.width(config) == 1
4960
ConcretePJRTArray(
50-
zeros(T.val, primargs...); client=XLA.client(uval.val), device=XLA.device(uval.val), uval.val.sharding
61+
zeros(T.val, primargs...);
62+
client=XLA.client(uval.val),
63+
device=XLA.device(uval.val),
64+
uval.val.sharding,
5165
)
5266
else
53-
ntuple(Val(EnzymeRules.width(config))) do i
54-
Base.@_inline_meta
55-
ConcretePJRTArray(
56-
zeros(T.val, primargs...); client=XLA.client(uval.val), device=XLA.device(uval.val), uval.val.sharding
57-
)
58-
end
67+
ntuple(Val(EnzymeRules.width(config))) do i
68+
Base.@_inline_meta
69+
ConcretePJRTArray(
70+
zeros(T.val, primargs...);
71+
client=XLA.client(uval.val),
72+
device=XLA.device(uval.val),
73+
uval.val.sharding,
74+
)
75+
end
5976
end
6077
else
6178
nothing
6279
end
6380

64-
return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing)
81+
return EnzymeRules.AugmentedReturn{
82+
EnzymeRules.primal_type(config, RT),EnzymeRules.shadow_type(config, RT),Nothing
83+
}(
84+
primal, shadow, nothing
85+
)
6586
end
6687

67-
function Enzyme.EnzymeRules.reverse(config, ofn::Const{typeof(Base.similar)}, ::Type{RT}, tape, uval::Enzyme.Annotation{<:ConcretePJRTArray}, T::Enzyme.Const{<:Type}, args::Vararg{Enzyme.Annotation, N}) where {RT, N}
68-
ntuple(Val(N+2)) do i
69-
Base.@_inline_meta
70-
nothing
88+
function Enzyme.EnzymeRules.reverse(
89+
config,
90+
ofn::Const{typeof(Base.similar)},
91+
::Type{RT},
92+
tape,
93+
uval::Enzyme.Annotation{<:ConcretePJRTArray},
94+
T::Enzyme.Const{<:Type},
95+
args::Vararg{Enzyme.Annotation,N},
96+
) where {RT,N}
97+
ntuple(Val(N + 2)) do i
98+
Base.@_inline_meta
99+
nothing
71100
end
72-
end
101+
end

test/autodiff.jl

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
1515
@test Enzyme.guess_activity(Reactant.ConcretePJRTArray{Float32}, Enzyme.Reverse) <:
1616
Enzyme.Duplicated
1717

18-
@test Enzyme.guess_activity(Reactant.ConcreteIFRTArray{Float32, 2, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Enzyme.Reverse) <: Enzyme.Duplicated
18+
@test Enzyme.guess_activity(
19+
Reactant.ConcreteIFRTArray{
20+
Float32,2,Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding,Nothing}
21+
},
22+
Enzyme.Reverse,
23+
) <: Enzyme.Duplicated
1924

2025
@test Enzyme.guess_activity(
2126
Reactant.ConcretePJRTNumber{
@@ -27,19 +32,31 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
2732
@test Enzyme.guess_activity(Reactant.ConcretePJRTNumber{Float32}, Enzyme.Reverse) <:
2833
Enzyme.Duplicated
2934

30-
@test Enzyme.guess_activity(Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Enzyme.Reverse) <: Enzyme.Duplicated
31-
32-
@test Enzyme.guess_activity(Reactant.ConcretePJRTNumber{Float32}, Enzyme.Reverse) <: Enzyme.Duplicated
33-
34-
@test Enzyme.guess_activity(Reactant.ConcreteIFRTNumber{Float32, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Enzyme.Reverse) <: Enzyme.Duplicated
35+
@test Enzyme.guess_activity(
36+
Reactant.ConcretePJRTNumber{
37+
Float32,1,Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding,Nothing}
38+
},
39+
Enzyme.Reverse,
40+
) <: Enzyme.Duplicated
3541

36-
@test Enzyme.guess_activity(Reactant.ConcreteIFRTNumber{Float32}, Enzyme.Reverse) <: Enzyme.Duplicated
42+
@test Enzyme.guess_activity(Reactant.ConcretePJRTNumber{Float32}, Enzyme.Reverse) <:
43+
Enzyme.Duplicated
3744

45+
@test Enzyme.guess_activity(
46+
Reactant.ConcreteIFRTNumber{
47+
Float32,Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding,Nothing}
48+
},
49+
Enzyme.Reverse,
50+
) <: Enzyme.Duplicated
3851

39-
@test Enzyme.guess_activity(Reactant.TracedRArray{Float32, 2}, Enzyme.Reverse) <: Enzyme.Duplicated
52+
@test Enzyme.guess_activity(Reactant.ConcreteIFRTNumber{Float32}, Enzyme.Reverse) <:
53+
Enzyme.Duplicated
4054

41-
@test Enzyme.guess_activity(Reactant.TracedRArray{Float32}, Enzyme.Reverse) <: Enzyme.Duplicated
55+
@test Enzyme.guess_activity(Reactant.TracedRArray{Float32,2}, Enzyme.Reverse) <:
56+
Enzyme.Duplicated
4257

58+
@test Enzyme.guess_activity(Reactant.TracedRArray{Float32}, Enzyme.Reverse) <:
59+
Enzyme.Duplicated
4360

4461
@test Enzyme.guess_activity(Reactant.TracedRNumber{Float32}, Enzyme.Reverse) <:
4562
Enzyme.Duplicated

test/control_flow.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -854,10 +854,10 @@ end
854854
end
855855

856856
function loop!(h_mat::AbstractMatrix, η_mat::AbstractMatrix, H_mat::AbstractMatrix)
857-
m,n = size(h_mat)
857+
m, n = size(h_mat)
858858
@inbounds @trace for i in 1:m
859859
@trace for j in 1:n
860-
@allowscalar h_mat[i,j] = η_mat[i,j] + H_mat[i,j]
860+
@allowscalar h_mat[i, j] = η_mat[i, j] + H_mat[i, j]
861861
end
862862
end
863863
end

0 commit comments

Comments
 (0)