1
- using Reactant
2
- using Test
3
- using Enzyme
4
- using Statistics
5
- using Random
1
+ using Reactant, Test, Enzyme, Statistics, Random, InteractiveUtils
6
2
Random. seed! (123 )
7
3
8
- fastmax (x :: AbstractArray{T} ) where {T} = reduce (max, x; dims = 1 , init = float (T)( - Inf ) )
4
+ const RunningOnTPU = contains ( string (Reactant . devices ()[ 1 ]), " TPU " )
9
5
10
- using InteractiveUtils
6
+ fastmax (x :: AbstractArray{T} ) where {T} = reduce (max, x; dims = 1 , init = float (T)( - Inf ))
11
7
12
8
@testset " 2D sum" begin
13
9
x = rand (2 , 10 )
@@ -420,12 +416,12 @@ end
420
416
421
417
@testset " Complex runtime: $CT " for CT in (ComplexF32, ComplexF64)
422
418
# complex f64 not supported on tpu
423
- if CT == ComplexF32 || ! contains ( string ( Reactant. devices ()[ 1 ]), " TPU " )
424
- a = Reactant. to_rarray (ones (CT, 2 ))
425
- b = Reactant . to_rarray ( ones (CT, 2 ))
419
+ a = Reactant. to_rarray ( ones (CT, 2 ) )
420
+ b = Reactant. to_rarray (ones (CT, 2 ))
421
+ @test begin
426
422
c = Reactant. compile (+ , (a, b))(a, b)
427
- @test c == ones (CT, 2 ) + ones (CT, 2 )
428
- end
423
+ c == ones (CT, 2 ) + ones (CT, 2 )
424
+ end broken = CT != ComplexF32 && RunningOnTPU
429
425
end
430
426
431
427
@testset " Scalars" begin
@@ -784,20 +780,20 @@ end
784
780
x = Reactant. to_rarray ([1.0 , NaN , Inf , - Inf , NaN ])
785
781
@test @jit (isfinite .(x)) == [true , false , false , false , false ]
786
782
787
- if ! contains ( string (Reactant . devices ()[ 1 ]), " TPU " )
783
+ @test begin
788
784
x = Reactant. to_rarray ([1.0 , NaN , Inf , - Inf , NaN ] .* im)
789
- @test @ jit (isfinite .(x)) == [true , false , false , false , false ]
790
- end
785
+ @jit (isfinite .(x)) == [true , false , false , false , false ]
786
+ end broken = RunningOnTPU
791
787
end
792
788
793
789
@testset " isnan" begin
794
790
x = Reactant. to_rarray ([1.0 , NaN , Inf , - Inf , NaN ])
795
791
@test @jit (isnan .(x)) == [false , true , false , false , true ]
796
792
797
- if ! contains ( string (Reactant . devices ()[ 1 ]), " TPU " )
793
+ @test begin
798
794
x = Reactant. to_rarray ([1.0 , NaN , Inf , - Inf , NaN ] .* im)
799
- @test @ jit (isnan .(x)) == [false , true , false , false , true ]
800
- end
795
+ @jit (isnan .(x)) == [false , true , false , false , true ]
796
+ end broken = RunningOnTPU
801
797
end
802
798
803
799
@testset " isnan/isfinite" begin
@@ -820,11 +816,10 @@ end
820
816
b = [6.6 , - 2.2 , - 8.8 , 4.4 , - 10.1 ]
821
817
822
818
expected_mod = mod .(a, b)
823
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
824
- @test @jit (mod .(Reactant. to_rarray (a), Reactant. to_rarray (b))) ≈ expected_mod
825
- @test @jit (mod .(a, Reactant. to_rarray (b))) ≈ expected_mod
826
- @test @jit (mod .(Reactant. to_rarray (a), b)) ≈ expected_mod
827
- end
819
+ @test @jit (mod .(Reactant. to_rarray (a), Reactant. to_rarray (b))) ≈ expected_mod broken =
820
+ RunningOnTPU
821
+ @test @jit (mod .(a, Reactant. to_rarray (b))) ≈ expected_mod broken = RunningOnTPU
822
+ @test @jit (mod .(Reactant. to_rarray (a), b)) ≈ expected_mod broken = RunningOnTPU
828
823
829
824
expected_rem = rem .(a, b)
830
825
@test @jit (rem .(Reactant. to_rarray (a), Reactant. to_rarray (b))) ≈ expected_rem
@@ -838,22 +833,17 @@ end
838
833
end
839
834
end
840
835
841
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
842
- @testset " signbit" begin
843
- for x in (- 4 , - 3.14 , - 0.0f0 , 0.0 , 0 , 5 , 6.28f0 )
844
- @test @jit (signbit (ConcreteRNumber (x))) == signbit (x)
845
- end
836
+ @testset " signbit" begin
837
+ for x in (- 4 , - 3.14 , - 0.0f0 , 0.0 , 0 , 5 , 6.28f0 )
838
+ @test @jit (signbit (ConcreteRNumber (x))) == signbit (x) broken = RunningOnTPU
846
839
end
847
840
end
848
841
849
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
850
- @testset " copysign" begin
851
- for a in (- 3.14 , - 2 , 0.0 , 2.71 , 42 ), b in (- 7 , - 0.57 , - 0.0 , 1 , 3.14 )
852
- # Make sure also the return type is correct
853
- @test Reactant. to_number (
854
- @jit (copysign (ConcreteRNumber (a), ConcreteRNumber (b)))
855
- ) === copysign (a, b)
856
- end
842
+ @testset " copysign" begin
843
+ for a in (- 3.14 , - 2 , 0.0 , 2.71 , 42 ), b in (- 7 , - 0.57 , - 0.0 , 1 , 3.14 )
844
+ # Make sure also the return type is correct
845
+ @test Reactant. to_number (@jit (copysign (ConcreteRNumber (a), ConcreteRNumber (b)))) ===
846
+ copysign (a, b) broken = RunningOnTPU
857
847
end
858
848
end
859
849
@@ -949,13 +939,11 @@ end
949
939
ra[:a ] ≈ (2.7 * 2 ) * ones (4 )
950
940
end
951
941
952
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
953
- @testset " @code_xla" begin
954
- x_ra = Reactant. to_rarray (ones (4 ))
955
- hlo = repr (@code_xla (sin .(x_ra)))
956
- @test contains (hlo, " HloModule" )
957
- @test contains (hlo, " sine" )
958
- end
942
+ @testset " @code_xla" begin
943
+ x_ra = Reactant. to_rarray (ones (4 ))
944
+ hlo = repr (@code_xla (sin .(x_ra)))
945
+ @test contains (hlo, " HloModule" ) broken = RunningOnTPU
946
+ @test contains (hlo, " sine" ) broken = RunningOnTPU
959
947
end
960
948
961
949
@testset " Raise keyword" begin
@@ -999,14 +987,11 @@ end
999
987
@test Array (x) ≈ Array (y) ./ 2
1000
988
end
1001
989
1002
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
1003
- @testset " Hlo Cost Analysis" begin
1004
- x_ra = Reactant. to_rarray (rand (4 , 4 ))
1005
- mul_comp = @compile x_ra * x_ra
1006
- cost = Reactant. XLA. cost_analysis (mul_comp)
1007
-
1008
- @test cost isa Reactant. XLA. HloCostAnalysisProperties
1009
- end
990
+ @test " HLO Cost Analysis" begin
991
+ x_ra = Reactant. to_rarray (rand (4 , 4 ))
992
+ mul_comp = @compile x_ra * x_ra
993
+ @test Reactant. XLA. cost_analysis (mul_comp) isa Reactant. XLA. HloCostAnalysisProperties broken =
994
+ RunningOnTPU
1010
995
end
1011
996
1012
997
function fractional_idx (times, t)
@@ -1140,32 +1125,30 @@ end
1140
1125
end
1141
1126
end
1142
1127
1143
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
1144
- @testset " Dump MLIR modules" begin
1145
- always_old = Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[]
1146
- dir_old = Reactant. MLIR. IR. DUMP_MLIR_DIR[]
1128
+ @testset " Dump MLIR modules" begin
1129
+ always_old = Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[]
1130
+ dir_old = Reactant. MLIR. IR. DUMP_MLIR_DIR[]
1147
1131
1148
- mktempdir () do dir
1149
- Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = true
1150
- Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir
1151
- @compile sin .(Reactant. to_rarray (Float32[1.0 ]))
1152
- for mod in readdir (dir; join= true )
1153
- @test contains (read (mod, String), " hlo.sine" )
1154
- end
1155
- end
1156
-
1157
- mktempdir () do dir
1158
- Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = false
1159
- Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir
1160
- @compile exp .(Reactant. to_rarray (Float32[1.0 ]))
1161
- # Make sure we don't save anything to file when compilation is
1162
- # successful and `DUMP_MLIR_ALWAYS=false`.
1163
- @test isempty (readdir (dir; join= true ))
1132
+ mktempdir () do dir
1133
+ Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = true
1134
+ Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir
1135
+ @compile sin .(Reactant. to_rarray (Float32[1.0 ]))
1136
+ for mod in readdir (dir; join= true )
1137
+ @test contains (read (mod, String), " hlo.sine" ) broken = RunningOnTPU
1164
1138
end
1139
+ end
1165
1140
1166
- Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = always_old
1167
- Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir_old
1141
+ mktempdir () do dir
1142
+ Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = false
1143
+ Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir
1144
+ @compile exp .(Reactant. to_rarray (Float32[1.0 ]))
1145
+ # Make sure we don't save anything to file when compilation is
1146
+ # successful and `DUMP_MLIR_ALWAYS=false`.
1147
+ @test isempty (readdir (dir; join= true )) broken = RunningOnTPU
1168
1148
end
1149
+
1150
+ Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = always_old
1151
+ Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir_old
1169
1152
end
1170
1153
1171
1154
@testset " Allocator Stats" begin
0 commit comments