Skip to content

Commit 8897299

Browse files
committed
test: more test fixes
1 parent 117fe86 commit 8897299

File tree

4 files changed

+16
-11
lines changed

4 files changed

+16
-11
lines changed

test/autodiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ end
255255

256256
@testset "seed" begin
257257
x = Reactant.to_rarray(rand(2, 2))
258-
st = (; rng=Reactant.ConcreteRNG())
258+
st = (; rng=Reactant.ReactantRNG())
259259

260260
@test begin
261261
hlo = @code_hlo gradient_fn(x, st)

test/basic.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -416,9 +416,9 @@ end
416416

417417
@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64)
418418
# complex f64 not supported on tpu
419-
a = Reactant.to_rarray(ones(CT, 2))
420-
b = Reactant.to_rarray(ones(CT, 2))
421419
@test begin
420+
a = Reactant.to_rarray(ones(CT, 2))
421+
b = Reactant.to_rarray(ones(CT, 2))
422422
c = Reactant.compile(+, (a, b))(a, b)
423423
c == ones(CT, 2) + ones(CT, 2)
424424
end broken = CT != ComplexF32 && RunningOnTPU
@@ -987,11 +987,12 @@ end
987987
@test Array(x) Array(y) ./ 2
988988
end
989989

990-
@test "HLO Cost Analysis" begin
990+
@testset "HLO Cost Analysis" begin
991991
x_ra = Reactant.to_rarray(rand(4, 4))
992992
mul_comp = @compile x_ra * x_ra
993-
@test Reactant.XLA.cost_analysis(mul_comp) isa Reactant.XLA.HloCostAnalysisProperties broken =
994-
RunningOnTPU
993+
@test begin
994+
Reactant.XLA.cost_analysis(mul_comp) isa Reactant.XLA.HloCostAnalysisProperties
995+
end broken = RunningOnTPU
995996
end
996997

997998
function fractional_idx(times, t)

test/indexing.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using LinearAlgebra, Reactant, Test
22

3+
const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU")
4+
35
function update_on_copy(x)
46
y = x[1:2, 2:4, :]
57
y[1:1, 1:1, :] = ones(1, 1, 3)

test/ops.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ const RunningOnAppleX86 = Sys.isapple() && Sys.ARCH === :x86_64
1717
3.0+4im -3.0+4im
1818
3.0-4im -3.0-4im
1919
])
20-
@test [
21-
5.0 5.0
22-
5.0 5.0
23-
] @jit Ops.abs(x) broken = RunningOnTPU
20+
@test begin
21+
[
22+
5.0 5.0
23+
5.0 5.0
24+
] @jit(Ops.abs(x))
25+
end broken = RunningOnTPU
2426
end
2527

2628
@testset "add" begin
@@ -40,7 +42,7 @@ end
4042
b = Reactant.to_rarray([
4143
9.9 + 10.10im, 11.11 + 12.12im, -13.13 + -14.14im, -15.15 + -16.16im
4244
])
43-
@test Array(a) .+ Array(b) @jit Ops.add(a, b) broken = RunningOnTPU
45+
@test Array(a) .+ Array(b) @jit(Ops.add(a, b)) broken = RunningOnTPU
4446
end
4547

4648
@testset "after_all" begin

0 commit comments

Comments
 (0)