Skip to content

Commit 7537a68

Browse files
committed
test: mark tests as broken/skip
1 parent 00c92a9 commit 7537a68

File tree

13 files changed

+436
-539
lines changed

13 files changed

+436
-539
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ jobs:
6464
- os: linux-x86-ct6e-180-4tpu
6565
version: "1.11"
6666
assertions: false
67-
test_group: core
67+
test_group: all
6868
runtime: "IFRT"
6969
- os: ubuntu-24.04
7070
version: "1.10"

src/accelerators/TPU.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function download_libtpu_if_needed(path=nothing)
4343
zip_file_path = joinpath(path, "tpu.zip")
4444
tmp_dir = joinpath(path, "tmp")
4545
Downloads.download(
46-
"https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250727+nightly-py3-none-manylinux_2_31_x86_64.whl",
46+
"https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250811+nightly-py3-none-manylinux_2_31_x86_64.whl",
4747
zip_file_path,
4848
)
4949
run(`$(unzip()) -qq $(zip_file_path) -d $(tmp_dir)`)

test/autodiff.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -195,20 +195,18 @@ end
195195
@test res2 4 * 3 * 3.1^2
196196
end
197197

198-
if !contains(string(Reactant.devices()[1]), "TPU")
199-
@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin
198+
@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin
199+
df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y)
200+
@test begin
200201
a = ones(ComplexF64, 2, 2)
201202
b = 2.0 * ones(ComplexF64, 2, 2)
202203
a_re = Reactant.to_rarray(a)
203204
b_re = Reactant.to_rarray(b)
204-
df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y)
205-
@test begin
206-
res = @jit df(a_re, b_re) # before, this segfaulted
207-
(res.val 4ones(2, 2)) &&
208-
(res.derivs[1] 4ones(2, 2)) &&
209-
(res.derivs[2] 2ones(2, 2))
210-
end
211-
end
205+
res = @jit df(a_re, b_re) # before, this segfaulted
206+
(res.val 4ones(2, 2)) &&
207+
(res.derivs[1] 4ones(2, 2)) &&
208+
(res.derivs[2] 2ones(2, 2))
209+
end skip = contains(string(Reactant.devices()[1]), "TPU")
212210
end
213211

214212
@testset "onehot" begin
@@ -257,7 +255,7 @@ end
257255

258256
@testset "seed" begin
259257
x = Reactant.to_rarray(rand(2, 2))
260-
st = (; rng=Reactant.ConcreteRNG())
258+
st = (; rng=Reactant.ReactantRNG())
261259

262260
@test begin
263261
hlo = @code_hlo gradient_fn(x, st)

test/basic.jl

Lines changed: 86 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
using Reactant
2-
using Test
3-
using Enzyme
4-
using Statistics
5-
using Random
1+
using Reactant, Test, Enzyme, Statistics, Random, InteractiveUtils
62
Random.seed!(123)
73

8-
fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))
4+
const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU")
95

10-
using InteractiveUtils
6+
fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))
117

128
@testset "2D sum" begin
139
x = rand(2, 10)
@@ -418,16 +414,6 @@ end
418414
@test eltype(f(y)) == eltype(x)
419415
end
420416

421-
@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64)
422-
# complex f64 not supported on tpu
423-
if CT == ComplexF32 || !contains(string(Reactant.devices()[1]), "TPU")
424-
a = Reactant.to_rarray(ones(CT, 2))
425-
b = Reactant.to_rarray(ones(CT, 2))
426-
c = Reactant.compile(+, (a, b))(a, b)
427-
@test c == ones(CT, 2) + ones(CT, 2)
428-
end
429-
end
430-
431417
@testset "Scalars" begin
432418
@testset "Only Scalars" begin
433419
x = (3, 3.14)
@@ -784,20 +770,20 @@ end
784770
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
785771
@test @jit(isfinite.(x)) == [true, false, false, false, false]
786772

787-
if !contains(string(Reactant.devices()[1]), "TPU")
773+
@test begin
788774
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
789-
@test @jit(isfinite.(x)) == [true, false, false, false, false]
790-
end
775+
@jit(isfinite.(x)) == [true, false, false, false, false]
776+
end skip = RunningOnTPU
791777
end
792778

793779
@testset "isnan" begin
794780
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
795781
@test @jit(isnan.(x)) == [false, true, false, false, true]
796782

797-
if !contains(string(Reactant.devices()[1]), "TPU")
783+
@test begin
798784
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
799-
@test @jit(isnan.(x)) == [false, true, false, false, true]
800-
end
785+
@jit(isnan.(x)) == [false, true, false, false, true]
786+
end skip = RunningOnTPU
801787
end
802788

803789
@testset "isnan/isfinite" begin
@@ -820,11 +806,10 @@ end
820806
b = [6.6, -2.2, -8.8, 4.4, -10.1]
821807

822808
expected_mod = mod.(a, b)
823-
if !contains(string(Reactant.devices()[1]), "TPU")
824-
@test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod
825-
@test @jit(mod.(a, Reactant.to_rarray(b))) expected_mod
826-
@test @jit(mod.(Reactant.to_rarray(a), b)) expected_mod
827-
end
809+
@test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod broken =
810+
RunningOnTPU
811+
@test @jit(mod.(a, Reactant.to_rarray(b))) expected_mod broken = RunningOnTPU
812+
@test @jit(mod.(Reactant.to_rarray(a), b)) expected_mod broken = RunningOnTPU
828813

829814
expected_rem = rem.(a, b)
830815
@test @jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_rem
@@ -838,22 +823,19 @@ end
838823
end
839824
end
840825

841-
if !contains(string(Reactant.devices()[1]), "TPU")
842-
@testset "signbit" begin
843-
for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
844-
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x)
845-
end
826+
@testset "signbit" begin
827+
@testset "$(typeof(x))" for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
828+
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x) broken =
829+
RunningOnTPU && eltype(x) == Float64
846830
end
847831
end
848832

849-
if !contains(string(Reactant.devices()[1]), "TPU")
850-
@testset "copysign" begin
851-
for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14)
852-
# Make sure also the return type is correct
853-
@test Reactant.to_number(
854-
@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))
855-
) === copysign(a, b)
856-
end
833+
@testset "copysign" begin
834+
@testset "$(typeof(a)) $(typeof(b))" for a in (-3.14, -2, 0.0, 2.71, 42),
835+
b in (-7, -0.57, -0.0, 1, 3.14)
836+
# Make sure also the return type is correct
837+
@test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b))))
838+
copysign(a, b) broken = RunningOnTPU && eltype(b) == Float64
857839
end
858840
end
859841

@@ -949,13 +931,11 @@ end
949931
ra[:a] (2.7 * 2) * ones(4)
950932
end
951933

952-
if !contains(string(Reactant.devices()[1]), "TPU")
953-
@testset "@code_xla" begin
954-
x_ra = Reactant.to_rarray(ones(4))
955-
hlo = repr(@code_xla(sin.(x_ra)))
956-
@test contains(hlo, "HloModule")
957-
@test contains(hlo, "sine")
958-
end
934+
@testset "@code_xla" begin
935+
x_ra = Reactant.to_rarray(ones(Float32, 4))
936+
hlo = repr(@code_xla(sin.(x_ra)))
937+
@test contains(hlo, "HloModule")
938+
@test contains(hlo, "sine")
959939
end
960940

961941
@testset "Raise keyword" begin
@@ -999,14 +979,12 @@ end
999979
@test Array(x) Array(y) ./ 2
1000980
end
1001981

1002-
if !contains(string(Reactant.devices()[1]), "TPU")
1003-
@testset "Hlo Cost Analysis" begin
1004-
x_ra = Reactant.to_rarray(rand(4, 4))
1005-
mul_comp = @compile x_ra * x_ra
1006-
cost = Reactant.XLA.cost_analysis(mul_comp)
1007-
1008-
@test cost isa Reactant.XLA.HloCostAnalysisProperties
1009-
end
982+
@testset "HLO Cost Analysis" begin
983+
x_ra = Reactant.to_rarray(rand(4, 4))
984+
mul_comp = @compile x_ra * x_ra
985+
@test begin
986+
Reactant.XLA.cost_analysis(mul_comp) isa Reactant.XLA.HloCostAnalysisProperties
987+
end broken = RunningOnTPU
1010988
end
1011989

1012990
function fractional_idx(times, t)
@@ -1140,32 +1118,30 @@ end
11401118
end
11411119
end
11421120

1143-
if !contains(string(Reactant.devices()[1]), "TPU")
1144-
@testset "Dump MLIR modules" begin
1145-
always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[]
1146-
dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[]
1121+
@testset "Dump MLIR modules" begin
1122+
always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[]
1123+
dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[]
11471124

1148-
mktempdir() do dir
1149-
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true
1150-
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1151-
@compile sin.(Reactant.to_rarray(Float32[1.0]))
1152-
for mod in readdir(dir; join=true)
1153-
@test contains(read(mod, String), "hlo.sine")
1154-
end
1155-
end
1156-
1157-
mktempdir() do dir
1158-
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false
1159-
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1160-
@compile exp.(Reactant.to_rarray(Float32[1.0]))
1161-
# Make sure we don't save anything to file when compilation is
1162-
# successful and `DUMP_MLIR_ALWAYS=false`.
1163-
@test isempty(readdir(dir; join=true))
1125+
mktempdir() do dir
1126+
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true
1127+
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1128+
@compile sin.(Reactant.to_rarray(Float32[1.0]))
1129+
for mod in readdir(dir; join=true)
1130+
@test contains(read(mod, String), "hlo.sine")
11641131
end
1132+
end
11651133

1166-
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old
1167-
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old
1134+
mktempdir() do dir
1135+
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false
1136+
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1137+
@compile exp.(Reactant.to_rarray(Float32[1.0]))
1138+
# Make sure we don't save anything to file when compilation is
1139+
# successful and `DUMP_MLIR_ALWAYS=false`.
1140+
@test isempty(readdir(dir; join=true))
11681141
end
1142+
1143+
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old
1144+
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old
11691145
end
11701146

11711147
@testset "Allocator Stats" begin
@@ -1291,40 +1267,38 @@ accum_fn(x, y) = abs2(x) + abs2(y)
12911267
end cumprod(b; dims=3)
12921268
end
12931269

1294-
if !contains(string(Reactant.devices()[1]), "TPU")
1295-
@testset "accumulate" begin
1296-
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
1297-
accumulate(accum_fn, a; init=0.0f0)
1298-
1299-
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1))
1300-
accumulate(accum_fn, b; dims=1, init=0.0f0)
1301-
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2))
1302-
accumulate(accum_fn, b; dims=2, init=0.0f0)
1303-
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3))
1304-
accumulate(accum_fn, b; dims=3, init=0.0f0)
1305-
1306-
@test begin
1307-
z = similar(a_ra)
1308-
@jit(accumulate!(accum_fn, z, a_ra; init=0.0f0))
1309-
z
1310-
end accumulate(accum_fn, a; init=0.0f0)
1311-
1312-
@test begin
1313-
z = similar(b_ra)
1314-
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1))
1315-
z
1316-
end accumulate(accum_fn, b; dims=1, init=0.0f0)
1317-
@test begin
1318-
z = similar(b_ra)
1319-
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2))
1320-
z
1321-
end accumulate(accum_fn, b; dims=2, init=0.0f0)
1322-
@test begin
1323-
z = similar(b_ra)
1324-
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3))
1325-
z
1326-
end accumulate(accum_fn, b; dims=3, init=0.0f0)
1327-
end
1270+
@testset "accumulate" begin
1271+
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
1272+
accumulate(accum_fn, a; init=0.0f0) broken = RunningOnTPU
1273+
1274+
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1))
1275+
accumulate(accum_fn, b; dims=1, init=0.0f0) broken = RunningOnTPU
1276+
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2))
1277+
accumulate(accum_fn, b; dims=2, init=0.0f0) broken = RunningOnTPU
1278+
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3))
1279+
accumulate(accum_fn, b; dims=3, init=0.0f0) broken = RunningOnTPU
1280+
1281+
@test begin
1282+
z = similar(a_ra)
1283+
@jit(accumulate!(accum_fn, z, a_ra; init=0.0f0))
1284+
z
1285+
end accumulate(accum_fn, a; init=0.0f0) broken = RunningOnTPU
1286+
1287+
@test begin
1288+
z = similar(b_ra)
1289+
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1))
1290+
z
1291+
end accumulate(accum_fn, b; dims=1, init=0.0f0) broken = RunningOnTPU
1292+
@test begin
1293+
z = similar(b_ra)
1294+
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2))
1295+
z
1296+
end accumulate(accum_fn, b; dims=2, init=0.0f0) broken = RunningOnTPU
1297+
@test begin
1298+
z = similar(b_ra)
1299+
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3))
1300+
z
1301+
end accumulate(accum_fn, b; dims=3, init=0.0f0) broken = RunningOnTPU
13281302
end
13291303
end
13301304

0 commit comments

Comments
 (0)