Skip to content

Commit 99069e5

Browse files
committed
test: more tests work
1 parent 356ae0b commit 99069e5

File tree

5 files changed

+55
-102
lines changed

5 files changed

+55
-102
lines changed

test/autodiff.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,17 +196,17 @@ end
196196
end
197197

198198
@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin
199-
a = ones(ComplexF64, 2, 2)
200-
b = 2.0 * ones(ComplexF64, 2, 2)
201-
a_re = Reactant.to_rarray(a)
202-
b_re = Reactant.to_rarray(b)
203199
df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y)
204200
@test begin
201+
a = ones(ComplexF64, 2, 2)
202+
b = 2.0 * ones(ComplexF64, 2, 2)
203+
a_re = Reactant.to_rarray(a)
204+
b_re = Reactant.to_rarray(b)
205205
res = @jit df(a_re, b_re) # before, this segfaulted
206206
(res.val 4ones(2, 2)) &&
207207
(res.derivs[1] 4ones(2, 2)) &&
208208
(res.derivs[2] 2ones(2, 2))
209-
end broken = contains(string(Reactant.devices()[1]), "TPU")
209+
end skip = contains(string(Reactant.devices()[1]), "TPU")
210210
end
211211

212212
@testset "onehot" begin

test/basic.jl

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -414,16 +414,6 @@ end
414414
@test eltype(f(y)) == eltype(x)
415415
end
416416

417-
@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64)
418-
# complex f64 not supported on tpu
419-
@test begin
420-
a = Reactant.to_rarray(ones(CT, 2))
421-
b = Reactant.to_rarray(ones(CT, 2))
422-
c = Reactant.compile(+, (a, b))(a, b)
423-
c == ones(CT, 2) + ones(CT, 2)
424-
end broken = CT != ComplexF32 && RunningOnTPU
425-
end
426-
427417
@testset "Scalars" begin
428418
@testset "Only Scalars" begin
429419
x = (3, 3.14)
@@ -783,7 +773,7 @@ end
783773
@test begin
784774
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
785775
@jit(isfinite.(x)) == [true, false, false, false, false]
786-
end broken = RunningOnTPU
776+
end skip = RunningOnTPU
787777
end
788778

789779
@testset "isnan" begin
@@ -793,7 +783,7 @@ end
793783
@test begin
794784
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
795785
@jit(isnan.(x)) == [false, true, false, false, true]
796-
end broken = RunningOnTPU
786+
end skip = RunningOnTPU
797787
end
798788

799789
@testset "isnan/isfinite" begin

test/complex.jl

Lines changed: 43 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -3,120 +3,82 @@ using Reactant
33

44
const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU")
55

6-
@testset "conj" begin
7-
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
8-
x_concrete = Reactant.to_rarray(x)
9-
@test only(@jit(conj(x_concrete))) == conj(x) broken = RunningOnTPU
10-
end
11-
12-
@testset "$(typeof(x))" for x in (
13-
fill(1.0 + 2.0im),
14-
fill(1.0),
15-
[1.0 + 2.0im; 3.0 + 4.0im],
16-
[1.0; 3.0],
17-
[1.0 + 2.0im 3.0 + 4.0im],
18-
[1.0 2.0],
19-
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
20-
[1.0 3.0; 5.0 7.0],
21-
)
22-
x_concrete = Reactant.to_rarray(x)
23-
@test @jit(conj(x_concrete)) == conj(x) broken = RunningOnTPU
24-
end
6+
@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64)
7+
@test begin
8+
a = Reactant.to_rarray(ones(CT, 2))
9+
b = Reactant.to_rarray(ones(CT, 2))
10+
c = Reactant.compile(+, (a, b))(a, b)
11+
c == ones(CT, 2) + ones(CT, 2)
12+
end skip = CT == ComplexF64 && RunningOnTPU
2513
end
2614

27-
@testset "conj!" begin
28-
@testset "$(typeof(x))" for x in (
29-
fill(1.0 + 2.0im),
30-
fill(1.0),
31-
[1.0 + 2.0im; 3.0 + 4.0im],
32-
[1.0; 3.0],
33-
[1.0 + 2.0im 3.0 + 4.0im],
34-
[1.0 2.0],
35-
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
36-
[1.0 3.0; 5.0 7.0],
37-
)
38-
x_concrete = Reactant.to_rarray(x)
39-
@test @jit(conj!(x_concrete)) == conj(x) broken = RunningOnTPU
40-
@test x_concrete == conj(x) broken = RunningOnTPU
41-
end
42-
end
15+
const SCALAR_LIST = (1.0, 1.0 + 2.0im)
4316

44-
@testset "real" begin
45-
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
46-
x_concrete = Reactant.to_rarray(x)
47-
@test only(@jit(real(x_concrete))) == real(x) broken = RunningOnTPU
48-
end
49-
50-
@testset "$(typeof(x))" for x in (
51-
fill(1.0 + 2.0im),
52-
fill(1.0),
53-
[1.0 + 2.0im; 3.0 + 4.0im],
54-
[1.0; 3.0],
55-
[1.0 + 2.0im 3.0 + 4.0im],
56-
[1.0 2.0],
57-
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
58-
[1.0 3.0; 5.0 7.0],
59-
)
60-
x_concrete = Reactant.to_rarray(x)
61-
@test @jit(real(x_concrete)) == real(x) broken = RunningOnTPU
62-
end
63-
end
17+
const ARRAY_LIST = (
18+
fill(1.0 + 2.0im),
19+
fill(1.0),
20+
[1.0 + 2.0im; 3.0 + 4.0im],
21+
[1.0; 3.0],
22+
[1.0 + 2.0im 3.0 + 4.0im],
23+
[1.0 2.0],
24+
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
25+
[1.0 3.0; 5.0 7.0],
26+
)
6427

65-
@testset "imag" begin
66-
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
67-
x_concrete = Reactant.to_rarray(x)
68-
@test only(@jit(imag(x_concrete))) == imag(x)
28+
@testset "$(string(fn))" for fn in (conj, conj!, real, imag)
29+
if !endswith(string(fn), "!")
30+
@testset "$(typeof(x))" for x in SCALAR_LIST
31+
@test begin
32+
x_concrete = Reactant.to_rarray(x)
33+
only(@jit(fn(x_concrete))) == fn(x)
34+
end skip = RunningOnTPU && eltype(x) == ComplexF64
35+
end
6936
end
7037

71-
@testset "$(typeof(x))" for x in (
72-
fill(1.0 + 2.0im),
73-
fill(1.0),
74-
[1.0 + 2.0im; 3.0 + 4.0im],
75-
[1.0; 3.0],
76-
[1.0 + 2.0im 3.0 + 4.0im],
77-
[1.0 2.0],
78-
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
79-
[1.0 3.0; 5.0 7.0],
80-
)
81-
x_concrete = Reactant.to_rarray(x)
82-
@test @jit(imag(x_concrete)) == imag(x) broken = RunningOnTPU
38+
@testset "$(typeof(x))" for x in ARRAY_LIST
39+
@test begin
40+
x_concrete = Reactant.to_rarray(x)
41+
@jit(fn(x_concrete)) == fn(x)
42+
end skip = RunningOnTPU && eltype(x) == ComplexF64
8343
end
8444
end
8545

8646
@testset "abs: $T" for T in (Float32, ComplexF32)
8747
x = randn(T, 10)
8848
x_concrete = Reactant.to_rarray(x)
89-
@test @jit(abs.(x_concrete)) abs.(x) broken = RunningOnTPU
49+
@test @jit(abs.(x_concrete)) abs.(x)
9050
end
9151

9252
@testset "promote_to Complex" begin
93-
x = 1.0 + 2.0im
53+
x = ComplexF32(1.0 + 2.0im)
9454
y = ConcreteRNumber(x)
9555

9656
f = Reactant.compile((y,)) do z
97-
z + Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{ComplexF64}, 1.0 - 3.0im)
57+
z + Reactant.TracedUtils.promote_to(
58+
Reactant.TracedRNumber{ComplexF32}, ComplexF32(1.0 - 3.0im)
59+
)
9860
end
9961

100-
@test isapprox(f(y), 2.0 - 1.0im) broken = RunningOnTPU
62+
@test isapprox(f(y), ComplexF32(2.0 - 1.0im))
10163
end
10264

10365
@testset "complex reduction" begin
10466
x = randn(ComplexF32, 10, 10)
10567
x_ra = Reactant.to_rarray(x)
106-
@test @jit(sum(abs2, x_ra)) sum(abs2, x) broken = RunningOnTPU
68+
@test @jit(sum(abs2, x_ra)) sum(abs2, x)
10769
end
10870

10971
@testset "create complex numbers" begin
11072
x = randn(ComplexF32)
11173
x_ra = Reactant.to_rarray(x; track_numbers=true)
112-
@test @jit(Complex(x_ra)) == x_ra broken = RunningOnTPU
74+
@test @jit(Complex(x_ra)) == x_ra
11375

11476
x = randn(Float32)
11577
y = randn(Float64)
11678
x_ra = Reactant.to_rarray(x; track_numbers=true)
11779
y_ra = Reactant.to_rarray(y; track_numbers=true)
118-
@test @jit(Complex(x_ra, y_ra)) == Complex(x, y) broken = RunningOnTPU
119-
@test @jit(Complex(x_ra, y)) == Complex(x, y) broken = RunningOnTPU
120-
@test @jit(Complex(x, y_ra)) == Complex(x, y) broken = RunningOnTPU
121-
@test @jit(Complex(x_ra)) == Complex(x) == @jit(Complex(x_ra, 0)) broken = RunningOnTPU
80+
@test @jit(Complex(x_ra, y_ra)) == Complex(x, y) skip = RunningOnTPU
81+
@test @jit(Complex(x_ra, y)) == Complex(x, y) skip = RunningOnTPU
82+
@test @jit(Complex(x, y_ra)) == Complex(x, y) skip = RunningOnTPU
83+
@test @jit(Complex(x_ra)) == Complex(x) == @jit(Complex(x_ra, 0))
12284
end

test/optimize_comm.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ if length(addressable_devices) ≥ 8
3737
@test !contains(hlo, "all-gather")
3838
@test contains(hlo, "collective-permute")
3939

40+
rotate(x)
4041
@jit shardy_passes = :to_mhlo_shardings rotate(rx)
4142
@test all(x .== convert(Array, rx))
4243
end

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
1616
# @info "Layout tests finished"
1717
# @safetestset "Tracing" include("tracing.jl")
1818
# @info "Tracing tests finished"
19-
# @safetestset "Basic" include("basic.jl") # TODO: needs fixing -- stalling currently
20-
# @info "Basic tests finished"
21-
@safetestset "Constructor" include("constructor.jl")
22-
@info "Constructor tests finished"
19+
@safetestset "Basic" include("basic.jl")
20+
@info "Basic tests finished"
21+
# @safetestset "Constructor" include("constructor.jl")
22+
# @info "Constructor tests finished"
2323
@safetestset "Autodiff" include("autodiff.jl")
2424
@info "Autodiff tests finished"
2525
@safetestset "Complex" include("complex.jl")

0 commit comments

Comments
 (0)