Skip to content

Commit 104e6fa

Browse files
committed
test: mark tests as broken instead of skipping them
1 parent 76068e3 commit 104e6fa

File tree

6 files changed

+335
-410
lines changed

6 files changed

+335
-410
lines changed

test/autodiff.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,20 +195,18 @@ end
195195
@test res2 4 * 3 * 3.1^2
196196
end
197197

198-
if !contains(string(Reactant.devices()[1]), "TPU")
199-
@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin
200-
a = ones(ComplexF64, 2, 2)
201-
b = 2.0 * ones(ComplexF64, 2, 2)
202-
a_re = Reactant.to_rarray(a)
203-
b_re = Reactant.to_rarray(b)
204-
df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y)
205-
@test begin
206-
res = @jit df(a_re, b_re) # before, this segfaulted
207-
(res.val 4ones(2, 2)) &&
208-
(res.derivs[1] 4ones(2, 2)) &&
209-
(res.derivs[2] 2ones(2, 2))
210-
end
211-
end
198+
@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin
199+
a = ones(ComplexF64, 2, 2)
200+
b = 2.0 * ones(ComplexF64, 2, 2)
201+
a_re = Reactant.to_rarray(a)
202+
b_re = Reactant.to_rarray(b)
203+
df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y)
204+
@test begin
205+
res = @jit df(a_re, b_re) # before, this segfaulted
206+
(res.val 4ones(2, 2)) &&
207+
(res.derivs[1] 4ones(2, 2)) &&
208+
(res.derivs[2] 2ones(2, 2))
209+
end broken=contains(string(Reactant.devices()[1]), "TPU")
212210
end
213211

214212
@testset "onehot" begin

test/basic.jl

Lines changed: 56 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
using Reactant
2-
using Test
3-
using Enzyme
4-
using Statistics
5-
using Random
1+
using Reactant, Test, Enzyme, Statistics, Random, InteractiveUtils
62
Random.seed!(123)
73

8-
fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))
4+
const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU")
95

10-
using InteractiveUtils
6+
fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))
117

128
@testset "2D sum" begin
139
x = rand(2, 10)
@@ -420,12 +416,12 @@ end
420416

421417
@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64)
422418
# complex f64 not supported on tpu
423-
if CT == ComplexF32 || !contains(string(Reactant.devices()[1]), "TPU")
424-
a = Reactant.to_rarray(ones(CT, 2))
425-
b = Reactant.to_rarray(ones(CT, 2))
419+
a = Reactant.to_rarray(ones(CT, 2))
420+
b = Reactant.to_rarray(ones(CT, 2))
421+
@test begin
426422
c = Reactant.compile(+, (a, b))(a, b)
427-
@test c == ones(CT, 2) + ones(CT, 2)
428-
end
423+
c == ones(CT, 2) + ones(CT, 2)
424+
end broken = CT != ComplexF32 && RunningOnTPU
429425
end
430426

431427
@testset "Scalars" begin
@@ -784,20 +780,20 @@ end
784780
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
785781
@test @jit(isfinite.(x)) == [true, false, false, false, false]
786782

787-
if !contains(string(Reactant.devices()[1]), "TPU")
783+
@test begin
788784
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
789-
@test @jit(isfinite.(x)) == [true, false, false, false, false]
790-
end
785+
@jit(isfinite.(x)) == [true, false, false, false, false]
786+
end broken = RunningOnTPU
791787
end
792788

793789
@testset "isnan" begin
794790
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
795791
@test @jit(isnan.(x)) == [false, true, false, false, true]
796792

797-
if !contains(string(Reactant.devices()[1]), "TPU")
793+
@test begin
798794
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
799-
@test @jit(isnan.(x)) == [false, true, false, false, true]
800-
end
795+
@jit(isnan.(x)) == [false, true, false, false, true]
796+
end broken = RunningOnTPU
801797
end
802798

803799
@testset "isnan/isfinite" begin
@@ -820,11 +816,10 @@ end
820816
b = [6.6, -2.2, -8.8, 4.4, -10.1]
821817

822818
expected_mod = mod.(a, b)
823-
if !contains(string(Reactant.devices()[1]), "TPU")
824-
@test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod
825-
@test @jit(mod.(a, Reactant.to_rarray(b))) expected_mod
826-
@test @jit(mod.(Reactant.to_rarray(a), b)) expected_mod
827-
end
819+
@test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod broken =
820+
RunningOnTPU
821+
@test @jit(mod.(a, Reactant.to_rarray(b))) expected_mod broken = RunningOnTPU
822+
@test @jit(mod.(Reactant.to_rarray(a), b)) expected_mod broken = RunningOnTPU
828823

829824
expected_rem = rem.(a, b)
830825
@test @jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_rem
@@ -838,22 +833,17 @@ end
838833
end
839834
end
840835

841-
if !contains(string(Reactant.devices()[1]), "TPU")
842-
@testset "signbit" begin
843-
for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
844-
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x)
845-
end
836+
@testset "signbit" begin
837+
for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
838+
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x) broken = RunningOnTPU
846839
end
847840
end
848841

849-
if !contains(string(Reactant.devices()[1]), "TPU")
850-
@testset "copysign" begin
851-
for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14)
852-
# Make sure also the return type is correct
853-
@test Reactant.to_number(
854-
@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))
855-
) === copysign(a, b)
856-
end
842+
@testset "copysign" begin
843+
for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14)
844+
# Make sure also the return type is correct
845+
@test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))) ===
846+
copysign(a, b) broken = RunningOnTPU
857847
end
858848
end
859849

@@ -949,13 +939,11 @@ end
949939
ra[:a] (2.7 * 2) * ones(4)
950940
end
951941

952-
if !contains(string(Reactant.devices()[1]), "TPU")
953-
@testset "@code_xla" begin
954-
x_ra = Reactant.to_rarray(ones(4))
955-
hlo = repr(@code_xla(sin.(x_ra)))
956-
@test contains(hlo, "HloModule")
957-
@test contains(hlo, "sine")
958-
end
942+
@testset "@code_xla" begin
943+
x_ra = Reactant.to_rarray(ones(4))
944+
hlo = repr(@code_xla(sin.(x_ra)))
945+
@test contains(hlo, "HloModule") broken = RunningOnTPU
946+
@test contains(hlo, "sine") broken = RunningOnTPU
959947
end
960948

961949
@testset "Raise keyword" begin
@@ -999,14 +987,11 @@ end
999987
@test Array(x) Array(y) ./ 2
1000988
end
1001989

1002-
if !contains(string(Reactant.devices()[1]), "TPU")
1003-
@testset "Hlo Cost Analysis" begin
1004-
x_ra = Reactant.to_rarray(rand(4, 4))
1005-
mul_comp = @compile x_ra * x_ra
1006-
cost = Reactant.XLA.cost_analysis(mul_comp)
1007-
1008-
@test cost isa Reactant.XLA.HloCostAnalysisProperties
1009-
end
990+
@test "HLO Cost Analysis" begin
991+
x_ra = Reactant.to_rarray(rand(4, 4))
992+
mul_comp = @compile x_ra * x_ra
993+
@test Reactant.XLA.cost_analysis(mul_comp) isa Reactant.XLA.HloCostAnalysisProperties broken =
994+
RunningOnTPU
1010995
end
1011996

1012997
function fractional_idx(times, t)
@@ -1140,32 +1125,30 @@ end
11401125
end
11411126
end
11421127

1143-
if !contains(string(Reactant.devices()[1]), "TPU")
1144-
@testset "Dump MLIR modules" begin
1145-
always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[]
1146-
dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[]
1128+
@testset "Dump MLIR modules" begin
1129+
always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[]
1130+
dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[]
11471131

1148-
mktempdir() do dir
1149-
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true
1150-
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1151-
@compile sin.(Reactant.to_rarray(Float32[1.0]))
1152-
for mod in readdir(dir; join=true)
1153-
@test contains(read(mod, String), "hlo.sine")
1154-
end
1155-
end
1156-
1157-
mktempdir() do dir
1158-
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false
1159-
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1160-
@compile exp.(Reactant.to_rarray(Float32[1.0]))
1161-
# Make sure we don't save anything to file when compilation is
1162-
# successful and `DUMP_MLIR_ALWAYS=false`.
1163-
@test isempty(readdir(dir; join=true))
1132+
mktempdir() do dir
1133+
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true
1134+
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1135+
@compile sin.(Reactant.to_rarray(Float32[1.0]))
1136+
for mod in readdir(dir; join=true)
1137+
@test contains(read(mod, String), "hlo.sine") broken = RunningOnTPU
11641138
end
1139+
end
11651140

1166-
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old
1167-
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old
1141+
mktempdir() do dir
1142+
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false
1143+
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1144+
@compile exp.(Reactant.to_rarray(Float32[1.0]))
1145+
# Make sure we don't save anything to file when compilation is
1146+
# successful and `DUMP_MLIR_ALWAYS=false`.
1147+
@test isempty(readdir(dir; join=true)) broken = RunningOnTPU
11681148
end
1149+
1150+
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old
1151+
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old
11691152
end
11701153

11711154
@testset "Allocator Stats" begin

0 commit comments

Comments
 (0)