@@ -3,120 +3,82 @@ using Reactant
3
3
4
4
const RunningOnTPU = contains (string (Reactant. devices ()[1 ]), " TPU" )
5
5
6
- @testset " conj" begin
7
- @testset " $(typeof (x)) " for x in (1.0 , 1.0 + 2.0im )
8
- x_concrete = Reactant. to_rarray (x)
9
- @test only (@jit (conj (x_concrete))) == conj (x) broken = RunningOnTPU
10
- end
11
-
12
- @testset " $(typeof (x)) " for x in (
13
- fill (1.0 + 2.0im ),
14
- fill (1.0 ),
15
- [1.0 + 2.0im ; 3.0 + 4.0im ],
16
- [1.0 ; 3.0 ],
17
- [1.0 + 2.0im 3.0 + 4.0im ],
18
- [1.0 2.0 ],
19
- [1.0 + 2.0im 3.0 + 4.0im ; 5.0 + 6.0im 7.0 + 8.0im ],
20
- [1.0 3.0 ; 5.0 7.0 ],
21
- )
22
- x_concrete = Reactant. to_rarray (x)
23
- @test @jit (conj (x_concrete)) == conj (x) broken = RunningOnTPU
24
- end
6
+ @testset " Complex runtime: $CT " for CT in (ComplexF32, ComplexF64)
7
+ @test begin
8
+ a = Reactant. to_rarray (ones (CT, 2 ))
9
+ b = Reactant. to_rarray (ones (CT, 2 ))
10
+ c = Reactant. compile (+ , (a, b))(a, b)
11
+ c == ones (CT, 2 ) + ones (CT, 2 )
12
+ end skip = CT == ComplexF64 && RunningOnTPU
25
13
end
26
14
27
- @testset " conj!" begin
28
- @testset " $(typeof (x)) " for x in (
29
- fill (1.0 + 2.0im ),
30
- fill (1.0 ),
31
- [1.0 + 2.0im ; 3.0 + 4.0im ],
32
- [1.0 ; 3.0 ],
33
- [1.0 + 2.0im 3.0 + 4.0im ],
34
- [1.0 2.0 ],
35
- [1.0 + 2.0im 3.0 + 4.0im ; 5.0 + 6.0im 7.0 + 8.0im ],
36
- [1.0 3.0 ; 5.0 7.0 ],
37
- )
38
- x_concrete = Reactant. to_rarray (x)
39
- @test @jit (conj! (x_concrete)) == conj (x) broken = RunningOnTPU
40
- @test x_concrete == conj (x) broken = RunningOnTPU
41
- end
42
- end
15
+ const SCALAR_LIST = (1.0 , 1.0 + 2.0im )
43
16
44
- @testset " real" begin
45
- @testset " $(typeof (x)) " for x in (1.0 , 1.0 + 2.0im )
46
- x_concrete = Reactant. to_rarray (x)
47
- @test only (@jit (real (x_concrete))) == real (x) broken = RunningOnTPU
48
- end
49
-
50
- @testset " $(typeof (x)) " for x in (
51
- fill (1.0 + 2.0im ),
52
- fill (1.0 ),
53
- [1.0 + 2.0im ; 3.0 + 4.0im ],
54
- [1.0 ; 3.0 ],
55
- [1.0 + 2.0im 3.0 + 4.0im ],
56
- [1.0 2.0 ],
57
- [1.0 + 2.0im 3.0 + 4.0im ; 5.0 + 6.0im 7.0 + 8.0im ],
58
- [1.0 3.0 ; 5.0 7.0 ],
59
- )
60
- x_concrete = Reactant. to_rarray (x)
61
- @test @jit (real (x_concrete)) == real (x) broken = RunningOnTPU
62
- end
63
- end
17
+ const ARRAY_LIST = (
18
+ fill (1.0 + 2.0im ),
19
+ fill (1.0 ),
20
+ [1.0 + 2.0im ; 3.0 + 4.0im ],
21
+ [1.0 ; 3.0 ],
22
+ [1.0 + 2.0im 3.0 + 4.0im ],
23
+ [1.0 2.0 ],
24
+ [1.0 + 2.0im 3.0 + 4.0im ; 5.0 + 6.0im 7.0 + 8.0im ],
25
+ [1.0 3.0 ; 5.0 7.0 ],
26
+ )
64
27
65
- @testset " imag" begin
66
- @testset " $(typeof (x)) " for x in (1.0 , 1.0 + 2.0im )
67
- x_concrete = Reactant. to_rarray (x)
68
- @test only (@jit (imag (x_concrete))) == imag (x)
28
+ @testset " $(string (fn)) " for fn in (conj, conj!, real, imag)
29
+ if ! endswith (string (fn), " !" )
30
+ @testset " $(typeof (x)) " for x in SCALAR_LIST
31
+ @test begin
32
+ x_concrete = Reactant. to_rarray (x)
33
+ only (@jit (fn (x_concrete))) == fn (x)
34
+ end skip = RunningOnTPU && eltype (x) == ComplexF64
35
+ end
69
36
end
70
37
71
- @testset " $(typeof (x)) " for x in (
72
- fill (1.0 + 2.0im ),
73
- fill (1.0 ),
74
- [1.0 + 2.0im ; 3.0 + 4.0im ],
75
- [1.0 ; 3.0 ],
76
- [1.0 + 2.0im 3.0 + 4.0im ],
77
- [1.0 2.0 ],
78
- [1.0 + 2.0im 3.0 + 4.0im ; 5.0 + 6.0im 7.0 + 8.0im ],
79
- [1.0 3.0 ; 5.0 7.0 ],
80
- )
81
- x_concrete = Reactant. to_rarray (x)
82
- @test @jit (imag (x_concrete)) == imag (x) broken = RunningOnTPU
38
+ @testset " $(typeof (x)) " for x in ARRAY_LIST
39
+ @test begin
40
+ x_concrete = Reactant. to_rarray (x)
41
+ @jit (fn (x_concrete)) == fn (x)
42
+ end skip = RunningOnTPU && eltype (x) == ComplexF64
83
43
end
84
44
end
85
45
86
46
@testset " abs: $T " for T in (Float32, ComplexF32)
87
47
x = randn (T, 10 )
88
48
x_concrete = Reactant. to_rarray (x)
89
- @test @jit (abs .(x_concrete)) ≈ abs .(x) broken = RunningOnTPU
49
+ @test @jit (abs .(x_concrete)) ≈ abs .(x)
90
50
end
91
51
92
52
@testset " promote_to Complex" begin
93
- x = 1.0 + 2.0im
53
+ x = ComplexF32 ( 1.0 + 2.0im )
94
54
y = ConcreteRNumber (x)
95
55
96
56
f = Reactant. compile ((y,)) do z
97
- z + Reactant. TracedUtils. promote_to (Reactant. TracedRNumber{ComplexF64}, 1.0 - 3.0im )
57
+ z + Reactant. TracedUtils. promote_to (
58
+ Reactant. TracedRNumber{ComplexF32}, ComplexF32 (1.0 - 3.0im )
59
+ )
98
60
end
99
61
100
- @test isapprox (f (y), 2.0 - 1.0im ) broken = RunningOnTPU
62
+ @test isapprox (f (y), ComplexF32 ( 2.0 - 1.0im ))
101
63
end
102
64
103
65
@testset " complex reduction" begin
104
66
x = randn (ComplexF32, 10 , 10 )
105
67
x_ra = Reactant. to_rarray (x)
106
- @test @jit (sum (abs2, x_ra)) ≈ sum (abs2, x) broken = RunningOnTPU
68
+ @test @jit (sum (abs2, x_ra)) ≈ sum (abs2, x)
107
69
end
108
70
109
71
@testset " create complex numbers" begin
110
72
x = randn (ComplexF32)
111
73
x_ra = Reactant. to_rarray (x; track_numbers= true )
112
- @test @jit (Complex (x_ra)) == x_ra broken = RunningOnTPU
74
+ @test @jit (Complex (x_ra)) == x_ra
113
75
114
76
x = randn (Float32)
115
77
y = randn (Float64)
116
78
x_ra = Reactant. to_rarray (x; track_numbers= true )
117
79
y_ra = Reactant. to_rarray (y; track_numbers= true )
118
- @test @jit (Complex (x_ra, y_ra)) == Complex (x, y) broken = RunningOnTPU
119
- @test @jit (Complex (x_ra, y)) == Complex (x, y) broken = RunningOnTPU
120
- @test @jit (Complex (x, y_ra)) == Complex (x, y) broken = RunningOnTPU
121
- @test @jit (Complex (x_ra)) == Complex (x) == @jit (Complex (x_ra, 0 )) broken = RunningOnTPU
80
+ @test @jit (Complex (x_ra, y_ra)) == Complex (x, y) skip = RunningOnTPU
81
+ @test @jit (Complex (x_ra, y)) == Complex (x, y) skip = RunningOnTPU
82
+ @test @jit (Complex (x, y_ra)) == Complex (x, y) skip = RunningOnTPU
83
+ @test @jit (Complex (x_ra)) == Complex (x) == @jit (Complex (x_ra, 0 ))
122
84
end
0 commit comments