@@ -560,50 +560,52 @@ end
560
560
@test [2 1 ; 4 3 ] == @jit g2 (x)
561
561
end
562
562
563
- @testset " rng_bit_generator" begin
564
- genInt32 (seed) = Ops. rng_bit_generator (Int32, seed, [2 , 4 ])
565
- genInt64 (seed) = Ops. rng_bit_generator (Int64, seed, [2 , 4 ])
566
- genUInt64 (seed) = Ops. rng_bit_generator (UInt64, seed, [2 , 4 ])
567
- genFloat32 (seed) = Ops. rng_bit_generator (Float32, seed, [2 , 4 ])
568
- genFloat64 (seed) = Ops. rng_bit_generator (Float64, seed, [2 , 4 ])
569
-
570
- @testset for (alg, sz) in
571
- [(" DEFAULT" , 2 ), (" PHILOX" , 2 ), (" PHILOX" , 3 ), (" THREE_FRY" , 2 )]
572
- seed = Reactant. to_rarray (zeros (UInt64, sz))
573
-
574
- res = @jit genInt32 (seed)
575
- @test res. output_state != = seed
576
- @test size (res. output_state) == (sz,)
577
- @test res. output isa ConcreteRArray{Int32,2 }
578
- @test size (res. output) == (2 , 4 )
579
-
580
- seed = res. output_state
581
- res = @jit genInt64 (seed)
582
- @test res. output_state != = seed
583
- @test size (res. output_state) == (sz,)
584
- @test res. output isa ConcreteRArray{Int64,2 }
585
- @test size (res. output) == (2 , 4 )
586
-
587
- seed = res. output_state
588
- res = @jit genUInt64 (seed)
589
- @test res. output_state != = seed
590
- @test size (res. output_state) == (sz,)
591
- @test res. output isa ConcreteRArray{UInt64,2 }
592
- @test size (res. output) == (2 , 4 )
593
-
594
- seed = res. output_state
595
- res = @jit genFloat32 (seed)
596
- @test res. output_state != = seed
597
- @test size (res. output_state) == (sz,)
598
- @test res. output isa ConcreteRArray{Float32,2 }
599
- @test size (res. output) == (2 , 4 )
600
-
601
- seed = res. output_state
602
- res = @jit genFloat64 (seed)
603
- @test res. output_state != = seed
604
- @test size (res. output_state) == (sz,)
605
- @test res. output isa ConcreteRArray{Float64,2 }
606
- @test size (res. output) == (2 , 4 )
563
+ if ! contains (string (Reactant. devices ()[1 ]), " TPU" )
564
+ @testset " rng_bit_generator" begin
565
+ genInt32 (seed) = Ops. rng_bit_generator (Int32, seed, [2 , 4 ])
566
+ genInt64 (seed) = Ops. rng_bit_generator (Int64, seed, [2 , 4 ])
567
+ genUInt64 (seed) = Ops. rng_bit_generator (UInt64, seed, [2 , 4 ])
568
+ genFloat32 (seed) = Ops. rng_bit_generator (Float32, seed, [2 , 4 ])
569
+ genFloat64 (seed) = Ops. rng_bit_generator (Float64, seed, [2 , 4 ])
570
+
571
+ @testset for (alg, sz) in
572
+ [(" DEFAULT" , 2 ), (" PHILOX" , 2 ), (" PHILOX" , 3 ), (" THREE_FRY" , 2 )]
573
+ seed = Reactant. to_rarray (zeros (UInt64, sz))
574
+
575
+ res = @jit genInt32 (seed)
576
+ @test res. output_state != = seed
577
+ @test size (res. output_state) == (sz,)
578
+ @test res. output isa ConcreteRArray{Int32,2 }
579
+ @test size (res. output) == (2 , 4 )
580
+
581
+ seed = res. output_state
582
+ res = @jit genInt64 (seed)
583
+ @test res. output_state != = seed
584
+ @test size (res. output_state) == (sz,)
585
+ @test res. output isa ConcreteRArray{Int64,2 }
586
+ @test size (res. output) == (2 , 4 )
587
+
588
+ seed = res. output_state
589
+ res = @jit genUInt64 (seed)
590
+ @test res. output_state != = seed
591
+ @test size (res. output_state) == (sz,)
592
+ @test res. output isa ConcreteRArray{UInt64,2 }
593
+ @test size (res. output) == (2 , 4 )
594
+
595
+ seed = res. output_state
596
+ res = @jit genFloat32 (seed)
597
+ @test res. output_state != = seed
598
+ @test size (res. output_state) == (sz,)
599
+ @test res. output isa ConcreteRArray{Float32,2 }
600
+ @test size (res. output) == (2 , 4 )
601
+
602
+ seed = res. output_state
603
+ res = @jit genFloat64 (seed)
604
+ @test res. output_state != = seed
605
+ @test size (res. output_state) == (sz,)
606
+ @test res. output isa ConcreteRArray{Float64,2 }
607
+ @test size (res. output) == (2 , 4 )
608
+ end
607
609
end
608
610
end
609
611
@@ -1225,7 +1227,8 @@ end
1225
1227
x_ra = Reactant. to_rarray (randn (Float32, 6 , 6 ))
1226
1228
lu_ra, ipiv, perm, info = @jit Ops. lu (x_ra)
1227
1229
1228
- @test @jit (recon_from_lu (lu_ra)) ≈ @jit (getindex (x_ra, perm, :))
1230
+ @test @jit (recon_from_lu (lu_ra)) ≈ @jit (getindex (x_ra, perm, :)) atol = 1e-5 rtol =
1231
+ 1e-2
1229
1232
end
1230
1233
1231
1234
@testset " batched" begin
@@ -1236,7 +1239,8 @@ end
1236
1239
@test size (perm) == (4 , 3 , 6 )
1237
1240
@test size (info) == (4 , 3 )
1238
1241
1239
- @test @jit (recon_from_lu (lu_ra)) ≈ @jit (apply_permutation (x_ra, perm))
1242
+ @test @jit (recon_from_lu (lu_ra)) ≈ @jit (apply_permutation (x_ra, perm)) atol = 1e-5 rtol =
1243
+ 1e-2
1240
1244
end
1241
1245
end
1242
1246
0 commit comments