1
- using Reactant
2
- using Test
3
- using Enzyme
4
- using Statistics
5
- using Random
1
+ using Reactant, Test, Enzyme, Statistics, Random, InteractiveUtils
6
2
Random. seed! (123 )
7
3
8
- fastmax (x :: AbstractArray{T} ) where {T} = reduce (max, x; dims = 1 , init = float (T)( - Inf ) )
4
+ const RunningOnTPU = contains ( string (Reactant . devices ()[ 1 ]), " TPU " )
9
5
10
- using InteractiveUtils
6
+ fastmax (x :: AbstractArray{T} ) where {T} = reduce (max, x; dims = 1 , init = float (T)( - Inf ))
11
7
12
8
@testset " 2D sum" begin
13
9
x = rand (2 , 10 )
418
414
@test eltype (f (y)) == eltype (x)
419
415
end
420
416
421
- @testset " Complex runtime: $CT " for CT in (ComplexF32, ComplexF64)
422
- # complex f64 not supported on tpu
423
- if CT == ComplexF32 || ! contains (string (Reactant. devices ()[1 ]), " TPU" )
424
- a = Reactant. to_rarray (ones (CT, 2 ))
425
- b = Reactant. to_rarray (ones (CT, 2 ))
426
- c = Reactant. compile (+ , (a, b))(a, b)
427
- @test c == ones (CT, 2 ) + ones (CT, 2 )
428
- end
429
- end
430
-
431
417
@testset " Scalars" begin
432
418
@testset " Only Scalars" begin
433
419
x = (3 , 3.14 )
@@ -784,20 +770,20 @@ end
784
770
x = Reactant. to_rarray ([1.0 , NaN , Inf , - Inf , NaN ])
785
771
@test @jit (isfinite .(x)) == [true , false , false , false , false ]
786
772
787
- if ! contains ( string (Reactant . devices ()[ 1 ]), " TPU " )
773
+ @test begin
788
774
x = Reactant. to_rarray ([1.0 , NaN , Inf , - Inf , NaN ] .* im)
789
- @test @ jit (isfinite .(x)) == [true , false , false , false , false ]
790
- end
775
+ @jit (isfinite .(x)) == [true , false , false , false , false ]
776
+ end skip = RunningOnTPU
791
777
end
792
778
793
779
@testset " isnan" begin
794
780
x = Reactant. to_rarray ([1.0 , NaN , Inf , - Inf , NaN ])
795
781
@test @jit (isnan .(x)) == [false , true , false , false , true ]
796
782
797
- if ! contains ( string (Reactant . devices ()[ 1 ]), " TPU " )
783
+ @test begin
798
784
x = Reactant. to_rarray ([1.0 , NaN , Inf , - Inf , NaN ] .* im)
799
- @test @ jit (isnan .(x)) == [false , true , false , false , true ]
800
- end
785
+ @jit (isnan .(x)) == [false , true , false , false , true ]
786
+ end skip = RunningOnTPU
801
787
end
802
788
803
789
@testset " isnan/isfinite" begin
@@ -820,11 +806,10 @@ end
820
806
b = [6.6 , - 2.2 , - 8.8 , 4.4 , - 10.1 ]
821
807
822
808
expected_mod = mod .(a, b)
823
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
824
- @test @jit (mod .(Reactant. to_rarray (a), Reactant. to_rarray (b))) ≈ expected_mod
825
- @test @jit (mod .(a, Reactant. to_rarray (b))) ≈ expected_mod
826
- @test @jit (mod .(Reactant. to_rarray (a), b)) ≈ expected_mod
827
- end
809
+ @test @jit (mod .(Reactant. to_rarray (a), Reactant. to_rarray (b))) ≈ expected_mod broken =
810
+ RunningOnTPU
811
+ @test @jit (mod .(a, Reactant. to_rarray (b))) ≈ expected_mod broken = RunningOnTPU
812
+ @test @jit (mod .(Reactant. to_rarray (a), b)) ≈ expected_mod broken = RunningOnTPU
828
813
829
814
expected_rem = rem .(a, b)
830
815
@test @jit (rem .(Reactant. to_rarray (a), Reactant. to_rarray (b))) ≈ expected_rem
@@ -838,22 +823,19 @@ end
838
823
end
839
824
end
840
825
841
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
842
- @testset " signbit" begin
843
- for x in (- 4 , - 3.14 , - 0.0f0 , 0.0 , 0 , 5 , 6.28f0 )
844
- @test @jit (signbit (ConcreteRNumber (x))) == signbit (x)
845
- end
826
+ @testset " signbit" begin
827
+ @testset " $(typeof (x)) " for x in (- 4 , - 3.14 , - 0.0f0 , 0.0 , 0 , 5 , 6.28f0 )
828
+ @test @jit (signbit (ConcreteRNumber (x))) == signbit (x) broken =
829
+ RunningOnTPU && eltype (x) == Float64
846
830
end
847
831
end
848
832
849
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
850
- @testset " copysign" begin
851
- for a in (- 3.14 , - 2 , 0.0 , 2.71 , 42 ), b in (- 7 , - 0.57 , - 0.0 , 1 , 3.14 )
852
- # Make sure also the return type is correct
853
- @test Reactant. to_number (
854
- @jit (copysign (ConcreteRNumber (a), ConcreteRNumber (b)))
855
- ) === copysign (a, b)
856
- end
833
+ @testset " copysign" begin
834
+ @testset " $(typeof (a)) $(typeof (b)) " for a in (- 3.14 , - 2 , 0.0 , 2.71 , 42 ),
835
+ b in (- 7 , - 0.57 , - 0.0 , 1 , 3.14 )
836
+ # Make sure also the return type is correct
837
+ @test Reactant. to_number (@jit (copysign (ConcreteRNumber (a), ConcreteRNumber (b)))) ≈
838
+ copysign (a, b) broken = RunningOnTPU && eltype (b) == Float64
857
839
end
858
840
end
859
841
@@ -949,13 +931,11 @@ end
949
931
ra[:a ] ≈ (2.7 * 2 ) * ones (4 )
950
932
end
951
933
952
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
953
- @testset " @code_xla" begin
954
- x_ra = Reactant. to_rarray (ones (4 ))
955
- hlo = repr (@code_xla (sin .(x_ra)))
956
- @test contains (hlo, " HloModule" )
957
- @test contains (hlo, " sine" )
958
- end
934
+ @testset " @code_xla" begin
935
+ x_ra = Reactant. to_rarray (ones (Float32, 4 ))
936
+ hlo = repr (@code_xla (sin .(x_ra)))
937
+ @test contains (hlo, " HloModule" )
938
+ @test contains (hlo, " sine" )
959
939
end
960
940
961
941
@testset " Raise keyword" begin
@@ -999,14 +979,12 @@ end
999
979
@test Array (x) ≈ Array (y) ./ 2
1000
980
end
1001
981
1002
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
1003
- @testset " Hlo Cost Analysis" begin
1004
- x_ra = Reactant. to_rarray (rand (4 , 4 ))
1005
- mul_comp = @compile x_ra * x_ra
1006
- cost = Reactant. XLA. cost_analysis (mul_comp)
1007
-
1008
- @test cost isa Reactant. XLA. HloCostAnalysisProperties
1009
- end
982
+ @testset " HLO Cost Analysis" begin
983
+ x_ra = Reactant. to_rarray (rand (4 , 4 ))
984
+ mul_comp = @compile x_ra * x_ra
985
+ @test begin
986
+ Reactant. XLA. cost_analysis (mul_comp) isa Reactant. XLA. HloCostAnalysisProperties
987
+ end broken = RunningOnTPU
1010
988
end
1011
989
1012
990
function fractional_idx (times, t)
@@ -1140,32 +1118,30 @@ end
1140
1118
end
1141
1119
end
1142
1120
1143
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
1144
- @testset " Dump MLIR modules" begin
1145
- always_old = Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[]
1146
- dir_old = Reactant. MLIR. IR. DUMP_MLIR_DIR[]
1121
+ @testset " Dump MLIR modules" begin
1122
+ always_old = Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[]
1123
+ dir_old = Reactant. MLIR. IR. DUMP_MLIR_DIR[]
1147
1124
1148
- mktempdir () do dir
1149
- Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = true
1150
- Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir
1151
- @compile sin .(Reactant. to_rarray (Float32[1.0 ]))
1152
- for mod in readdir (dir; join= true )
1153
- @test contains (read (mod, String), " hlo.sine" )
1154
- end
1155
- end
1156
-
1157
- mktempdir () do dir
1158
- Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = false
1159
- Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir
1160
- @compile exp .(Reactant. to_rarray (Float32[1.0 ]))
1161
- # Make sure we don't save anything to file when compilation is
1162
- # successful and `DUMP_MLIR_ALWAYS=false`.
1163
- @test isempty (readdir (dir; join= true ))
1125
+ mktempdir () do dir
1126
+ Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = true
1127
+ Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir
1128
+ @compile sin .(Reactant. to_rarray (Float32[1.0 ]))
1129
+ for mod in readdir (dir; join= true )
1130
+ @test contains (read (mod, String), " hlo.sine" )
1164
1131
end
1132
+ end
1165
1133
1166
- Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = always_old
1167
- Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir_old
1134
+ mktempdir () do dir
1135
+ Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = false
1136
+ Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir
1137
+ @compile exp .(Reactant. to_rarray (Float32[1.0 ]))
1138
+ # Make sure we don't save anything to file when compilation is
1139
+ # successful and `DUMP_MLIR_ALWAYS=false`.
1140
+ @test isempty (readdir (dir; join= true ))
1168
1141
end
1142
+
1143
+ Reactant. MLIR. IR. DUMP_MLIR_ALWAYS[] = always_old
1144
+ Reactant. MLIR. IR. DUMP_MLIR_DIR[] = dir_old
1169
1145
end
1170
1146
1171
1147
@testset " Allocator Stats" begin
@@ -1291,40 +1267,38 @@ accum_fn(x, y) = abs2(x) + abs2(y)
1291
1267
end ≈ cumprod (b; dims= 3 )
1292
1268
end
1293
1269
1294
- if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
1295
- @testset " accumulate" begin
1296
- @test @jit (accumulate (accum_fn, a_ra; init= 0.0f0 )) ≈
1297
- accumulate (accum_fn, a; init= 0.0f0 )
1298
-
1299
- @test @jit (accumulate (accum_fn, b_ra; init= 0.0f0 , dims= 1 )) ≈
1300
- accumulate (accum_fn, b; dims= 1 , init= 0.0f0 )
1301
- @test @jit (accumulate (accum_fn, b_ra; init= 0.0f0 , dims= 2 )) ≈
1302
- accumulate (accum_fn, b; dims= 2 , init= 0.0f0 )
1303
- @test @jit (accumulate (accum_fn, b_ra; init= 0.0f0 , dims= 3 )) ≈
1304
- accumulate (accum_fn, b; dims= 3 , init= 0.0f0 )
1305
-
1306
- @test begin
1307
- z = similar (a_ra)
1308
- @jit (accumulate! (accum_fn, z, a_ra; init= 0.0f0 ))
1309
- z
1310
- end ≈ accumulate (accum_fn, a; init= 0.0f0 )
1311
-
1312
- @test begin
1313
- z = similar (b_ra)
1314
- @jit (accumulate! (accum_fn, z, b_ra; init= 0.0f0 , dims= 1 ))
1315
- z
1316
- end ≈ accumulate (accum_fn, b; dims= 1 , init= 0.0f0 )
1317
- @test begin
1318
- z = similar (b_ra)
1319
- @jit (accumulate! (accum_fn, z, b_ra; init= 0.0f0 , dims= 2 ))
1320
- z
1321
- end ≈ accumulate (accum_fn, b; dims= 2 , init= 0.0f0 )
1322
- @test begin
1323
- z = similar (b_ra)
1324
- @jit (accumulate! (accum_fn, z, b_ra; init= 0.0f0 , dims= 3 ))
1325
- z
1326
- end ≈ accumulate (accum_fn, b; dims= 3 , init= 0.0f0 )
1327
- end
1270
+ @testset " accumulate" begin
1271
+ @test @jit (accumulate (accum_fn, a_ra; init= 0.0f0 )) ≈
1272
+ accumulate (accum_fn, a; init= 0.0f0 ) broken = RunningOnTPU
1273
+
1274
+ @test @jit (accumulate (accum_fn, b_ra; init= 0.0f0 , dims= 1 )) ≈
1275
+ accumulate (accum_fn, b; dims= 1 , init= 0.0f0 ) broken = RunningOnTPU
1276
+ @test @jit (accumulate (accum_fn, b_ra; init= 0.0f0 , dims= 2 )) ≈
1277
+ accumulate (accum_fn, b; dims= 2 , init= 0.0f0 ) broken = RunningOnTPU
1278
+ @test @jit (accumulate (accum_fn, b_ra; init= 0.0f0 , dims= 3 )) ≈
1279
+ accumulate (accum_fn, b; dims= 3 , init= 0.0f0 ) broken = RunningOnTPU
1280
+
1281
+ @test begin
1282
+ z = similar (a_ra)
1283
+ @jit (accumulate! (accum_fn, z, a_ra; init= 0.0f0 ))
1284
+ z
1285
+ end ≈ accumulate (accum_fn, a; init= 0.0f0 ) broken = RunningOnTPU
1286
+
1287
+ @test begin
1288
+ z = similar (b_ra)
1289
+ @jit (accumulate! (accum_fn, z, b_ra; init= 0.0f0 , dims= 1 ))
1290
+ z
1291
+ end ≈ accumulate (accum_fn, b; dims= 1 , init= 0.0f0 ) broken = RunningOnTPU
1292
+ @test begin
1293
+ z = similar (b_ra)
1294
+ @jit (accumulate! (accum_fn, z, b_ra; init= 0.0f0 , dims= 2 ))
1295
+ z
1296
+ end ≈ accumulate (accum_fn, b; dims= 2 , init= 0.0f0 ) broken = RunningOnTPU
1297
+ @test begin
1298
+ z = similar (b_ra)
1299
+ @jit (accumulate! (accum_fn, z, b_ra; init= 0.0f0 , dims= 3 ))
1300
+ z
1301
+ end ≈ accumulate (accum_fn, b; dims= 3 , init= 0.0f0 ) broken = RunningOnTPU
1328
1302
end
1329
1303
end
1330
1304
0 commit comments