Skip to content

Commit c592ca9

Browse files
Format Julia code (#1507)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 7bbd7ae commit c592ca9

File tree

6 files changed

+152
-151
lines changed

6 files changed

+152
-151
lines changed

ext/ReactantOffsetArraysExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ function Base.getindex(
4545
end
4646

4747
function Base.getindex(
48-
x::OffsetVector{Reactant.TracedRNumber{T}, Reactant.TracedRArray{T, 1}}, indices::Base.OneTo{Int}
48+
x::OffsetVector{Reactant.TracedRNumber{T},Reactant.TracedRArray{T,1}},
49+
indices::Base.OneTo{Int},
4950
) where {T}
5051
offset_indices = indices .- x.offsets[1]
5152
return getindex(parent(x), offset_indices)

src/Compiler.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2170,8 +2170,10 @@ function compile_mlir!(
21702170
)
21712171
end
21722172
end
2173-
2174-
func_op = MLIR.API.mlirSymbolTableLookup(MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname)
2173+
2174+
func_op = MLIR.API.mlirSymbolTableLookup(
2175+
MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname
2176+
)
21752177
@assert func_op.ptr !== C_NULL
21762178
func_op = MLIR.IR.Operation(func_op, false)
21772179
fnbody = MLIR.IR.first_block(MLIR.IR.region(func_op, 1))::MLIR.IR.Block

test/autodiff.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,20 +196,20 @@ end
196196
end
197197

198198
if !contains(string(Reactant.devices()[1]), "TPU")
199-
@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin
200-
a = ones(ComplexF64, 2, 2)
201-
b = 2.0 * ones(ComplexF64, 2, 2)
202-
a_re = Reactant.to_rarray(a)
203-
b_re = Reactant.to_rarray(b)
204-
df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y)
205-
@test begin
206-
res = @jit df(a_re, b_re) # before, this segfaulted
207-
(res.val 4ones(2, 2)) &&
208-
(res.derivs[1] 4ones(2, 2)) &&
209-
(res.derivs[2] 2ones(2, 2))
199+
@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin
200+
a = ones(ComplexF64, 2, 2)
201+
b = 2.0 * ones(ComplexF64, 2, 2)
202+
a_re = Reactant.to_rarray(a)
203+
b_re = Reactant.to_rarray(b)
204+
df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y)
205+
@test begin
206+
res = @jit df(a_re, b_re) # before, this segfaulted
207+
(res.val 4ones(2, 2)) &&
208+
(res.derivs[1] 4ones(2, 2)) &&
209+
(res.derivs[2] 2ones(2, 2))
210+
end
210211
end
211212
end
212-
end
213213

214214
@testset "onehot" begin
215215
x = Reactant.to_rarray(rand(3, 4))

test/basic.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -989,13 +989,13 @@ end
989989
end
990990

991991
if !contains(string(Reactant.devices()[1]), "TPU")
992-
@testset "Hlo Cost Analysis" begin
993-
x_ra = Reactant.to_rarray(rand(4, 4))
994-
mul_comp = @compile x_ra * x_ra
995-
cost = Reactant.XLA.cost_analysis(mul_comp)
992+
@testset "Hlo Cost Analysis" begin
993+
x_ra = Reactant.to_rarray(rand(4, 4))
994+
mul_comp = @compile x_ra * x_ra
995+
cost = Reactant.XLA.cost_analysis(mul_comp)
996996

997-
@test cost isa Reactant.XLA.HloCostAnalysisProperties
998-
end
997+
@test cost isa Reactant.XLA.HloCostAnalysisProperties
998+
end
999999
end
10001000

10011001
function fractional_idx(times, t)

test/complex.jl

Lines changed: 101 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -2,123 +2,121 @@ using Test
22
using Reactant
33

44
if !contains(string(Reactant.devices()[1]), "TPU")
5-
6-
@testset "conj" begin
7-
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
8-
x_concrete = Reactant.to_rarray(x)
9-
@test only(@jit(conj(x_concrete))) == conj(x)
10-
end
11-
12-
@testset "$(typeof(x))" for x in (
13-
fill(1.0 + 2.0im),
14-
fill(1.0),
15-
[1.0 + 2.0im; 3.0 + 4.0im],
16-
[1.0; 3.0],
17-
[1.0 + 2.0im 3.0 + 4.0im],
18-
[1.0 2.0],
19-
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
20-
[1.0 3.0; 5.0 7.0],
21-
)
22-
x_concrete = Reactant.to_rarray(x)
23-
@test @jit(conj(x_concrete)) == conj(x)
24-
end
25-
end
26-
27-
@testset "conj!" begin
28-
@testset "$(typeof(x))" for x in (
29-
fill(1.0 + 2.0im),
30-
fill(1.0),
31-
[1.0 + 2.0im; 3.0 + 4.0im],
32-
[1.0; 3.0],
33-
[1.0 + 2.0im 3.0 + 4.0im],
34-
[1.0 2.0],
35-
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
36-
[1.0 3.0; 5.0 7.0],
37-
)
38-
x_concrete = Reactant.to_rarray(x)
39-
@test @jit(conj!(x_concrete)) == conj(x)
40-
@test x_concrete == conj(x)
5+
@testset "conj" begin
6+
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
7+
x_concrete = Reactant.to_rarray(x)
8+
@test only(@jit(conj(x_concrete))) == conj(x)
9+
end
10+
11+
@testset "$(typeof(x))" for x in (
12+
fill(1.0 + 2.0im),
13+
fill(1.0),
14+
[1.0 + 2.0im; 3.0 + 4.0im],
15+
[1.0; 3.0],
16+
[1.0 + 2.0im 3.0 + 4.0im],
17+
[1.0 2.0],
18+
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
19+
[1.0 3.0; 5.0 7.0],
20+
)
21+
x_concrete = Reactant.to_rarray(x)
22+
@test @jit(conj(x_concrete)) == conj(x)
23+
end
4124
end
42-
end
4325

44-
@testset "real" begin
45-
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
46-
x_concrete = Reactant.to_rarray(x)
47-
@test only(@jit(real(x_concrete))) == real(x)
26+
@testset "conj!" begin
27+
@testset "$(typeof(x))" for x in (
28+
fill(1.0 + 2.0im),
29+
fill(1.0),
30+
[1.0 + 2.0im; 3.0 + 4.0im],
31+
[1.0; 3.0],
32+
[1.0 + 2.0im 3.0 + 4.0im],
33+
[1.0 2.0],
34+
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
35+
[1.0 3.0; 5.0 7.0],
36+
)
37+
x_concrete = Reactant.to_rarray(x)
38+
@test @jit(conj!(x_concrete)) == conj(x)
39+
@test x_concrete == conj(x)
40+
end
4841
end
4942

50-
@testset "$(typeof(x))" for x in (
51-
fill(1.0 + 2.0im),
52-
fill(1.0),
53-
[1.0 + 2.0im; 3.0 + 4.0im],
54-
[1.0; 3.0],
55-
[1.0 + 2.0im 3.0 + 4.0im],
56-
[1.0 2.0],
57-
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
58-
[1.0 3.0; 5.0 7.0],
59-
)
60-
x_concrete = Reactant.to_rarray(x)
61-
@test @jit(real(x_concrete)) == real(x)
43+
@testset "real" begin
44+
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
45+
x_concrete = Reactant.to_rarray(x)
46+
@test only(@jit(real(x_concrete))) == real(x)
47+
end
48+
49+
@testset "$(typeof(x))" for x in (
50+
fill(1.0 + 2.0im),
51+
fill(1.0),
52+
[1.0 + 2.0im; 3.0 + 4.0im],
53+
[1.0; 3.0],
54+
[1.0 + 2.0im 3.0 + 4.0im],
55+
[1.0 2.0],
56+
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
57+
[1.0 3.0; 5.0 7.0],
58+
)
59+
x_concrete = Reactant.to_rarray(x)
60+
@test @jit(real(x_concrete)) == real(x)
61+
end
6262
end
63-
end
6463

65-
@testset "imag" begin
66-
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
67-
x_concrete = Reactant.to_rarray(x)
68-
@test only(@jit(imag(x_concrete))) == imag(x)
64+
@testset "imag" begin
65+
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
66+
x_concrete = Reactant.to_rarray(x)
67+
@test only(@jit(imag(x_concrete))) == imag(x)
68+
end
69+
70+
@testset "$(typeof(x))" for x in (
71+
fill(1.0 + 2.0im),
72+
fill(1.0),
73+
[1.0 + 2.0im; 3.0 + 4.0im],
74+
[1.0; 3.0],
75+
[1.0 + 2.0im 3.0 + 4.0im],
76+
[1.0 2.0],
77+
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
78+
[1.0 3.0; 5.0 7.0],
79+
)
80+
x_concrete = Reactant.to_rarray(x)
81+
@test @jit(imag(x_concrete)) == imag(x)
82+
end
6983
end
7084

71-
@testset "$(typeof(x))" for x in (
72-
fill(1.0 + 2.0im),
73-
fill(1.0),
74-
[1.0 + 2.0im; 3.0 + 4.0im],
75-
[1.0; 3.0],
76-
[1.0 + 2.0im 3.0 + 4.0im],
77-
[1.0 2.0],
78-
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
79-
[1.0 3.0; 5.0 7.0],
80-
)
85+
@testset "abs: $T" for T in (Float32, ComplexF32)
86+
x = randn(T, 10)
8187
x_concrete = Reactant.to_rarray(x)
82-
@test @jit(imag(x_concrete)) == imag(x)
88+
@test @jit(abs.(x_concrete)) abs.(x)
8389
end
84-
end
8590

86-
@testset "abs: $T" for T in (Float32, ComplexF32)
87-
x = randn(T, 10)
88-
x_concrete = Reactant.to_rarray(x)
89-
@test @jit(abs.(x_concrete)) abs.(x)
90-
end
91+
@testset "promote_to Complex" begin
92+
x = 1.0 + 2.0im
93+
y = ConcreteRNumber(x)
9194

92-
@testset "promote_to Complex" begin
93-
x = 1.0 + 2.0im
94-
y = ConcreteRNumber(x)
95+
f = Reactant.compile((y,)) do z
96+
z + Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{ComplexF64}, 1.0 - 3.0im)
97+
end
9598

96-
f = Reactant.compile((y,)) do z
97-
z + Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{ComplexF64}, 1.0 - 3.0im)
99+
@test isapprox(f(y), 2.0 - 1.0im)
98100
end
99101

100-
@test isapprox(f(y), 2.0 - 1.0im)
101-
end
102-
103-
@testset "complex reduction" begin
104-
x = randn(ComplexF32, 10, 10)
105-
x_ra = Reactant.to_rarray(x)
106-
@test @jit(sum(abs2, x_ra)) sum(abs2, x)
107-
end
108-
109-
@testset "create complex numbers" begin
110-
x = randn(ComplexF32)
111-
x_ra = Reactant.to_rarray(x; track_numbers=true)
112-
@test @jit(Complex(x_ra)) == x_ra
113-
114-
x = randn(Float32)
115-
y = randn(Float64)
116-
x_ra = Reactant.to_rarray(x; track_numbers=true)
117-
y_ra = Reactant.to_rarray(y; track_numbers=true)
118-
@test @jit(Complex(x_ra, y_ra)) == Complex(x, y)
119-
@test @jit(Complex(x_ra, y)) == Complex(x, y)
120-
@test @jit(Complex(x, y_ra)) == Complex(x, y)
121-
@test @jit(Complex(x_ra)) == Complex(x) == @jit(Complex(x_ra, 0))
122-
end
102+
@testset "complex reduction" begin
103+
x = randn(ComplexF32, 10, 10)
104+
x_ra = Reactant.to_rarray(x)
105+
@test @jit(sum(abs2, x_ra)) sum(abs2, x)
106+
end
123107

108+
@testset "create complex numbers" begin
109+
x = randn(ComplexF32)
110+
x_ra = Reactant.to_rarray(x; track_numbers=true)
111+
@test @jit(Complex(x_ra)) == x_ra
112+
113+
x = randn(Float32)
114+
y = randn(Float64)
115+
x_ra = Reactant.to_rarray(x; track_numbers=true)
116+
y_ra = Reactant.to_rarray(y; track_numbers=true)
117+
@test @jit(Complex(x_ra, y_ra)) == Complex(x, y)
118+
@test @jit(Complex(x_ra, y)) == Complex(x, y)
119+
@test @jit(Complex(x, y_ra)) == Complex(x, y)
120+
@test @jit(Complex(x_ra)) == Complex(x) == @jit(Complex(x_ra, 0))
121+
end
124122
end

test/ops.jl

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ using SpecialFunctions: SpecialFunctions
1111
@test [1.0, 1.0] @jit Ops.abs(x)
1212

1313
if !contains(string(Reactant.devices()[1]), "TPU")
14-
x = Reactant.to_rarray([
15-
3.0+4im -3.0+4im
16-
3.0-4im -3.0-4im
17-
])
18-
@test [
19-
5.0 5.0
20-
5.0 5.0
21-
] @jit Ops.abs(x)
14+
x = Reactant.to_rarray([
15+
3.0+4im -3.0+4im
16+
3.0-4im -3.0-4im
17+
])
18+
@test [
19+
5.0 5.0
20+
5.0 5.0
21+
] @jit Ops.abs(x)
2222
end
2323
end
2424

@@ -98,17 +98,17 @@ end
9898
@test transpose(cholesky(Array(x)).U) @jit g2(x)
9999

100100
if !contains(string(Reactant.devices()[1]), "TPU")
101-
x = Reactant.to_rarray(
102-
[
103-
10.0+0.0im 2.0-3.0im 3.0-4.0im
104-
2.0+3.0im 5.0+0.0im 3.0-2.0im
105-
3.0+4.0im 3.0+2.0im 9.0+0.0im
106-
],
107-
)
101+
x = Reactant.to_rarray(
102+
[
103+
10.0+0.0im 2.0-3.0im 3.0-4.0im
104+
2.0+3.0im 5.0+0.0im 3.0-2.0im
105+
3.0+4.0im 3.0+2.0im 9.0+0.0im
106+
],
107+
)
108108

109-
@test cholesky(Array(x)).U @jit g1(x)
110-
@test adjoint(cholesky(Array(x)).U) @jit g2(x)
111-
end
109+
@test cholesky(Array(x)).U @jit g1(x)
110+
@test adjoint(cholesky(Array(x)).U) @jit g2(x)
111+
end
112112
end
113113

114114
@testset "clamp" begin
@@ -145,15 +145,15 @@ end
145145
end
146146

147147
if !contains(string(Reactant.devices()[1]), "TPU")
148-
@testset "complex" begin
149-
x = Reactant.to_rarray(1.1; track_numbers=true)
150-
y = Reactant.to_rarray(2.2; track_numbers=true)
151-
@test 1.1 + 2.2im @jit Ops.complex(x, y)
152-
153-
x = Reactant.to_rarray([1.1, 2.2, 3.3, 4.4])
154-
y = Reactant.to_rarray([5.5, 6.6, -7.7, -8.8])
155-
@test [1.1 + 5.5im, 2.2 + 6.6im, 3.3 - 7.7im, 4.4 - 8.8im] @jit Ops.complex(x, y)
156-
end
148+
@testset "complex" begin
149+
x = Reactant.to_rarray(1.1; track_numbers=true)
150+
y = Reactant.to_rarray(2.2; track_numbers=true)
151+
@test 1.1 + 2.2im @jit Ops.complex(x, y)
152+
153+
x = Reactant.to_rarray([1.1, 2.2, 3.3, 4.4])
154+
y = Reactant.to_rarray([5.5, 6.6, -7.7, -8.8])
155+
@test [1.1 + 5.5im, 2.2 + 6.6im, 3.3 - 7.7im, 4.4 - 8.8im] @jit Ops.complex(x, y)
156+
end
157157
end
158158

159159
@testset "constant" begin

0 commit comments

Comments
 (0)