Skip to content

test: mark tests as broken instead of skipping them #1536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ jobs:
assertions: false
test_group: core
runtime: "IFRT"
- os: linux-x86-ct6e-180-4tpu
version: "1.11"
assertions: false
test_group: integration
runtime: "IFRT"
- os: linux-x86-ct6e-180-4tpu
version: "1.11"
assertions: false
test_group: neural_networks
runtime: "IFRT"
- os: ubuntu-24.04
version: "1.10"
assertions: true
Expand Down
2 changes: 1 addition & 1 deletion src/accelerators/TPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function download_libtpu_if_needed(path=nothing)
zip_file_path = joinpath(path, "tpu.zip")
tmp_dir = joinpath(path, "tmp")
Downloads.download(
"https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250727+nightly-py3-none-manylinux_2_31_x86_64.whl",
"https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250811+nightly-py3-none-manylinux_2_31_x86_64.whl",
zip_file_path,
)
run(`$(unzip()) -qq $(zip_file_path) -d $(tmp_dir)`)
Expand Down
11 changes: 7 additions & 4 deletions src/xla/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ function __init__()
XLA_REACTANT_GPU_MEM_FRACTION[] = parse(
Float64, ENV["XLA_REACTANT_GPU_MEM_FRACTION"]
)
@debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[]
@debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[] maxlog =
1
if XLA_REACTANT_GPU_MEM_FRACTION[] > 1 || XLA_REACTANT_GPU_MEM_FRACTION[] < 0
error("XLA_REACTANT_GPU_MEM_FRACTION must be between 0 and 1")
end
Expand All @@ -141,16 +142,18 @@ function __init__()
XLA_REACTANT_GPU_PREALLOCATE[] = parse(
Bool, ENV["XLA_REACTANT_GPU_PREALLOCATE"]
)
@debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[]
@debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[] maxlog =
1
end

if haskey(ENV, "REACTANT_VISIBLE_GPU_DEVICES")
global_state.local_gpu_device_ids =
parse.(Int, split(ENV["REACTANT_VISIBLE_GPU_DEVICES"], ","))
@debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_gpu_device_ids
@debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_gpu_device_ids maxlog =
1
end

@debug "REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME
@debug "REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME maxlog = 1

@ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid
Expand Down
20 changes: 9 additions & 11 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,20 +195,18 @@ end
@test res2 4 * 3 * 3.1^2
end

if !contains(string(Reactant.devices()[1]), "TPU")
@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin
@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin
df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y)
@test begin
a = ones(ComplexF64, 2, 2)
b = 2.0 * ones(ComplexF64, 2, 2)
a_re = Reactant.to_rarray(a)
b_re = Reactant.to_rarray(b)
df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y)
@test begin
res = @jit df(a_re, b_re) # before, this segfaulted
(res.val 4ones(2, 2)) &&
(res.derivs[1] 4ones(2, 2)) &&
(res.derivs[2] 2ones(2, 2))
end
end
res = @jit df(a_re, b_re) # before, this segfaulted
(res.val 4ones(2, 2)) &&
(res.derivs[1] 4ones(2, 2)) &&
(res.derivs[2] 2ones(2, 2))
end skip = contains(string(Reactant.devices()[1]), "TPU")
end

@testset "onehot" begin
Expand Down Expand Up @@ -257,7 +255,7 @@ end

@testset "seed" begin
x = Reactant.to_rarray(rand(2, 2))
st = (; rng=Reactant.ConcreteRNG())
st = (; rng=Reactant.ReactantRNG())

@test begin
hlo = @code_hlo gradient_fn(x, st)
Expand Down
198 changes: 86 additions & 112 deletions test/basic.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
using Reactant
using Test
using Enzyme
using Statistics
using Random
using Reactant, Test, Enzyme, Statistics, Random, InteractiveUtils
Random.seed!(123)

fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))
const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU")

using InteractiveUtils
fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))

@testset "2D sum" begin
x = rand(2, 10)
Expand Down Expand Up @@ -418,16 +414,6 @@ end
@test eltype(f(y)) == eltype(x)
end

@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64)
# complex f64 not supported on tpu
if CT == ComplexF32 || !contains(string(Reactant.devices()[1]), "TPU")
a = Reactant.to_rarray(ones(CT, 2))
b = Reactant.to_rarray(ones(CT, 2))
c = Reactant.compile(+, (a, b))(a, b)
@test c == ones(CT, 2) + ones(CT, 2)
end
end

@testset "Scalars" begin
@testset "Only Scalars" begin
x = (3, 3.14)
Expand Down Expand Up @@ -784,20 +770,20 @@ end
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
@test @jit(isfinite.(x)) == [true, false, false, false, false]

if !contains(string(Reactant.devices()[1]), "TPU")
@test begin
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
@test @jit(isfinite.(x)) == [true, false, false, false, false]
end
@jit(isfinite.(x)) == [true, false, false, false, false]
end skip = RunningOnTPU
end

@testset "isnan" begin
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
@test @jit(isnan.(x)) == [false, true, false, false, true]

if !contains(string(Reactant.devices()[1]), "TPU")
@test begin
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
@test @jit(isnan.(x)) == [false, true, false, false, true]
end
@jit(isnan.(x)) == [false, true, false, false, true]
end skip = RunningOnTPU
end

@testset "isnan/isfinite" begin
Expand All @@ -820,11 +806,10 @@ end
b = [6.6, -2.2, -8.8, 4.4, -10.1]

expected_mod = mod.(a, b)
if !contains(string(Reactant.devices()[1]), "TPU")
@test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod
@test @jit(mod.(a, Reactant.to_rarray(b))) expected_mod
@test @jit(mod.(Reactant.to_rarray(a), b)) expected_mod
end
@test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod broken =
RunningOnTPU
@test @jit(mod.(a, Reactant.to_rarray(b))) expected_mod broken = RunningOnTPU
@test @jit(mod.(Reactant.to_rarray(a), b)) expected_mod broken = RunningOnTPU

expected_rem = rem.(a, b)
@test @jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_rem
Expand All @@ -838,22 +823,19 @@ end
end
end

if !contains(string(Reactant.devices()[1]), "TPU")
@testset "signbit" begin
for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x)
end
@testset "signbit" begin
@testset "$(typeof(x))" for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x) broken =
RunningOnTPU && eltype(x) == Float64
end
end

if !contains(string(Reactant.devices()[1]), "TPU")
@testset "copysign" begin
for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14)
# Make sure also the return type is correct
@test Reactant.to_number(
@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))
) === copysign(a, b)
end
@testset "copysign" begin
@testset "$(typeof(a)) $(typeof(b))" for a in (-3.14, -2, 0.0, 2.71, 42),
b in (-7, -0.57, -0.0, 1, 3.14)
# Make sure also the return type is correct
@test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b))))
copysign(a, b) broken = RunningOnTPU && eltype(b) == Float64
end
end

Expand Down Expand Up @@ -949,13 +931,11 @@ end
ra[:a] (2.7 * 2) * ones(4)
end

if !contains(string(Reactant.devices()[1]), "TPU")
@testset "@code_xla" begin
x_ra = Reactant.to_rarray(ones(4))
hlo = repr(@code_xla(sin.(x_ra)))
@test contains(hlo, "HloModule")
@test contains(hlo, "sine")
end
@testset "@code_xla" begin
x_ra = Reactant.to_rarray(ones(Float32, 4))
hlo = repr(@code_xla(sin.(x_ra)))
@test contains(hlo, "HloModule")
@test contains(hlo, "sine")
end

@testset "Raise keyword" begin
Expand Down Expand Up @@ -999,14 +979,12 @@ end
@test Array(x) Array(y) ./ 2
end

if !contains(string(Reactant.devices()[1]), "TPU")
@testset "Hlo Cost Analysis" begin
x_ra = Reactant.to_rarray(rand(4, 4))
mul_comp = @compile x_ra * x_ra
cost = Reactant.XLA.cost_analysis(mul_comp)

@test cost isa Reactant.XLA.HloCostAnalysisProperties
end
@testset "HLO Cost Analysis" begin
x_ra = Reactant.to_rarray(rand(4, 4))
mul_comp = @compile x_ra * x_ra
@test begin
Reactant.XLA.cost_analysis(mul_comp) isa Reactant.XLA.HloCostAnalysisProperties
end broken = RunningOnTPU
end

function fractional_idx(times, t)
Expand Down Expand Up @@ -1140,32 +1118,30 @@ end
end
end

if !contains(string(Reactant.devices()[1]), "TPU")
@testset "Dump MLIR modules" begin
always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[]
dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[]
@testset "Dump MLIR modules" begin
always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[]
dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[]

mktempdir() do dir
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
@compile sin.(Reactant.to_rarray(Float32[1.0]))
for mod in readdir(dir; join=true)
@test contains(read(mod, String), "hlo.sine")
end
end

mktempdir() do dir
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
@compile exp.(Reactant.to_rarray(Float32[1.0]))
# Make sure we don't save anything to file when compilation is
# successful and `DUMP_MLIR_ALWAYS=false`.
@test isempty(readdir(dir; join=true))
mktempdir() do dir
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
@compile sin.(Reactant.to_rarray(Float32[1.0]))
for mod in readdir(dir; join=true)
@test contains(read(mod, String), "hlo.sine")
end
end

Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old
mktempdir() do dir
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
@compile exp.(Reactant.to_rarray(Float32[1.0]))
# Make sure we don't save anything to file when compilation is
# successful and `DUMP_MLIR_ALWAYS=false`.
@test isempty(readdir(dir; join=true))
end

Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old
end

@testset "Allocator Stats" begin
Expand Down Expand Up @@ -1291,40 +1267,38 @@ accum_fn(x, y) = abs2(x) + abs2(y)
end cumprod(b; dims=3)
end

if !contains(string(Reactant.devices()[1]), "TPU")
@testset "accumulate" begin
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
accumulate(accum_fn, a; init=0.0f0)

@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1))
accumulate(accum_fn, b; dims=1, init=0.0f0)
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2))
accumulate(accum_fn, b; dims=2, init=0.0f0)
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3))
accumulate(accum_fn, b; dims=3, init=0.0f0)

@test begin
z = similar(a_ra)
@jit(accumulate!(accum_fn, z, a_ra; init=0.0f0))
z
end accumulate(accum_fn, a; init=0.0f0)

@test begin
z = similar(b_ra)
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1))
z
end accumulate(accum_fn, b; dims=1, init=0.0f0)
@test begin
z = similar(b_ra)
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2))
z
end accumulate(accum_fn, b; dims=2, init=0.0f0)
@test begin
z = similar(b_ra)
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3))
z
end accumulate(accum_fn, b; dims=3, init=0.0f0)
end
@testset "accumulate" begin
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
accumulate(accum_fn, a; init=0.0f0) broken = RunningOnTPU

@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1))
accumulate(accum_fn, b; dims=1, init=0.0f0) broken = RunningOnTPU
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2))
accumulate(accum_fn, b; dims=2, init=0.0f0) broken = RunningOnTPU
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3))
accumulate(accum_fn, b; dims=3, init=0.0f0) broken = RunningOnTPU

@test begin
z = similar(a_ra)
@jit(accumulate!(accum_fn, z, a_ra; init=0.0f0))
z
end accumulate(accum_fn, a; init=0.0f0) broken = RunningOnTPU

@test begin
z = similar(b_ra)
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1))
z
end accumulate(accum_fn, b; dims=1, init=0.0f0) broken = RunningOnTPU
@test begin
z = similar(b_ra)
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2))
z
end accumulate(accum_fn, b; dims=2, init=0.0f0) broken = RunningOnTPU
@test begin
z = similar(b_ra)
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3))
z
end accumulate(accum_fn, b; dims=3, init=0.0f0) broken = RunningOnTPU
end
end

Expand Down
Loading
Loading