Skip to content

Commit 0ea96f2

Browse files
committed
test: more test fixes
1 parent ba34ba1 commit 0ea96f2

File tree

4 files changed

+14
-9
lines changed

4 files changed

+14
-9
lines changed

test/autodiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ end
248248

249249
@testset "seed" begin
250250
x = Reactant.to_rarray(rand(2, 2))
251-
st = (; rng=Reactant.ConcreteRNG())
251+
st = (; rng=Reactant.ReactantRNG())
252252

253253
@test begin
254254
hlo = @code_hlo gradient_fn(x, st)

test/basic.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -416,9 +416,9 @@ end
416416

417417
@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64)
418418
# complex f64 not supported on tpu
419-
a = Reactant.to_rarray(ones(CT, 2))
420-
b = Reactant.to_rarray(ones(CT, 2))
421419
@test begin
420+
a = Reactant.to_rarray(ones(CT, 2))
421+
b = Reactant.to_rarray(ones(CT, 2))
422422
c = Reactant.compile(+, (a, b))(a, b)
423423
c == ones(CT, 2) + ones(CT, 2)
424424
end broken = CT != ComplexF32 && RunningOnTPU
@@ -990,8 +990,9 @@ end
990990
@test "HLO Cost Analysis" begin
991991
x_ra = Reactant.to_rarray(rand(4, 4))
992992
mul_comp = @compile x_ra * x_ra
993-
@test Reactant.XLA.cost_analysis(mul_comp) isa Reactant.XLA.HloCostAnalysisProperties broken =
994-
RunningOnTPU
993+
@test begin
994+
Reactant.XLA.cost_analysis(mul_comp) isa Reactant.XLA.HloCostAnalysisProperties
995+
end broken = RunningOnTPU
995996
end
996997

997998
function fractional_idx(times, t)

test/indexing.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using LinearAlgebra, Reactant, Test
22

3+
const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU")
4+
35
function update_on_copy(x)
46
y = x[1:2, 2:4, :]
57
y[1:1, 1:1, :] = ones(1, 1, 3)

test/ops.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ const RunningOnAppleX86 = Sys.isapple() && Sys.ARCH === :x86_64
1717
3.0+4im -3.0+4im
1818
3.0-4im -3.0-4im
1919
])
20-
@test [
21-
5.0 5.0
22-
5.0 5.0
23-
] @jit Ops.abs(x) broken = RunningOnTPU
20+
@test begin
21+
[
22+
5.0 5.0
23+
5.0 5.0
24+
] @jit(Ops.abs(x))
25+
end broken = RunningOnTPU
2426
end
2527

2628
@testset "add" begin

0 commit comments

Comments
 (0)