diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index eeacc5611e..03fbb39513 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -66,6 +66,16 @@ jobs: assertions: false test_group: core runtime: "IFRT" + - os: linux-x86-ct6e-180-4tpu + version: "1.11" + assertions: false + test_group: integration + runtime: "IFRT" + - os: linux-x86-ct6e-180-4tpu + version: "1.11" + assertions: false + test_group: neural_networks + runtime: "IFRT" - os: ubuntu-24.04 version: "1.10" assertions: true diff --git a/src/accelerators/TPU.jl b/src/accelerators/TPU.jl index 1918032d55..7a6901b8b6 100644 --- a/src/accelerators/TPU.jl +++ b/src/accelerators/TPU.jl @@ -43,7 +43,7 @@ function download_libtpu_if_needed(path=nothing) zip_file_path = joinpath(path, "tpu.zip") tmp_dir = joinpath(path, "tmp") Downloads.download( - "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250727+nightly-py3-none-manylinux_2_31_x86_64.whl", + "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250811+nightly-py3-none-manylinux_2_31_x86_64.whl", zip_file_path, ) run(`$(unzip()) -qq $(zip_file_path) -d $(tmp_dir)`) diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 57c74ca160..058e060ee3 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -131,7 +131,8 @@ function __init__() XLA_REACTANT_GPU_MEM_FRACTION[] = parse( Float64, ENV["XLA_REACTANT_GPU_MEM_FRACTION"] ) - @debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[] + @debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[] maxlog = + 1 if XLA_REACTANT_GPU_MEM_FRACTION[] > 1 || XLA_REACTANT_GPU_MEM_FRACTION[] < 0 error("XLA_REACTANT_GPU_MEM_FRACTION must be between 0 and 1") end @@ -141,16 +142,18 @@ function __init__() XLA_REACTANT_GPU_PREALLOCATE[] = parse( Bool, ENV["XLA_REACTANT_GPU_PREALLOCATE"] ) - @debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[] + @debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[] maxlog = + 1 end if haskey(ENV, "REACTANT_VISIBLE_GPU_DEVICES") global_state.local_gpu_device_ids = parse.(Int, split(ENV["REACTANT_VISIBLE_GPU_DEVICES"], ",")) - @debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_gpu_device_ids + @debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_gpu_device_ids maxlog = + 1 end - @debug "REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME + @debug "REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME maxlog = 1 @ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid @ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid diff --git a/test/autodiff.jl b/test/autodiff.jl index a11ec900e4..4330c99345 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -195,20 +195,18 @@ end @test res2 ≈ 4 * 3 * 3.1^2 end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "Seed initialization of Complex arrays on matmul: Issue #593" begin +@testset "Seed initialization of Complex arrays on matmul: Issue #593" begin + df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y) + @test begin a = ones(ComplexF64, 2, 2) b = 2.0 * ones(ComplexF64, 2, 2) a_re = Reactant.to_rarray(a) b_re = Reactant.to_rarray(b) - df(x, y) = Enzyme.gradient(ReverseWithPrimal, *, x, y) - @test begin - res = @jit df(a_re, b_re) # before, this segfaulted - (res.val ≈ 4ones(2, 2)) && - (res.derivs[1] ≈ 4ones(2, 2)) && - (res.derivs[2] ≈ 2ones(2, 2)) - end - end + res = @jit df(a_re, b_re) # before, this segfaulted + (res.val ≈ 4ones(2, 2)) && + (res.derivs[1] ≈ 4ones(2, 2)) && + (res.derivs[2] ≈ 2ones(2, 2)) + end skip = contains(string(Reactant.devices()[1]), "TPU") end @testset "onehot" begin @@ -257,7 +255,7 @@ end @testset "seed" begin x = Reactant.to_rarray(rand(2, 2)) - st = (; rng=Reactant.ConcreteRNG()) + st = (; rng=Reactant.ReactantRNG()) @test begin hlo = @code_hlo gradient_fn(x, st) diff --git a/test/basic.jl b/test/basic.jl index 1cc3f66d8a..e0d2bc5af7 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1,13 +1,9 @@ -using Reactant -using Test -using Enzyme -using Statistics -using Random +using Reactant, Test, Enzyme, Statistics, Random, InteractiveUtils Random.seed!(123) -fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf)) +const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU") -using InteractiveUtils +fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf)) @testset "2D sum" begin x = rand(2, 10) @@ -418,16 +414,6 @@ end @test eltype(f(y)) == eltype(x) end -@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64) - # complex f64 not supported on tpu - if CT == ComplexF32 || !contains(string(Reactant.devices()[1]), "TPU") - a = Reactant.to_rarray(ones(CT, 2)) - b = Reactant.to_rarray(ones(CT, 2)) - c = Reactant.compile(+, (a, b))(a, b) - @test c == ones(CT, 2) + ones(CT, 2) - end -end - @testset "Scalars" begin @testset "Only Scalars" begin x = (3, 3.14) @@ -784,20 +770,20 @@ end x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) @test @jit(isfinite.(x)) == [true, false, false, false, false] - if !contains(string(Reactant.devices()[1]), "TPU") + @test begin x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) - @test @jit(isfinite.(x)) == [true, false, false, false, false] - end + @jit(isfinite.(x)) == [true, false, false, false, false] + end skip = RunningOnTPU end @testset "isnan" begin x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN]) @test @jit(isnan.(x)) == [false, true, false, false, true] - if !contains(string(Reactant.devices()[1]), "TPU") + @test begin x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im) - @test @jit(isnan.(x)) == [false, true, false, false, true] - end + @jit(isnan.(x)) == [false, true, false, false, true] + end skip = RunningOnTPU end @testset "isnan/isfinite" begin @@ -820,11 +806,10 @@ end b = [6.6, -2.2, -8.8, 4.4, -10.1] expected_mod = mod.(a, b) - if !contains(string(Reactant.devices()[1]), "TPU") - @test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_mod - @test @jit(mod.(a, Reactant.to_rarray(b))) ≈ expected_mod - @test @jit(mod.(Reactant.to_rarray(a), b)) ≈ expected_mod - end + @test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_mod broken = + RunningOnTPU + @test @jit(mod.(a, Reactant.to_rarray(b))) ≈ expected_mod broken = RunningOnTPU + @test @jit(mod.(Reactant.to_rarray(a), b)) ≈ expected_mod broken = RunningOnTPU expected_rem = rem.(a, b) @test @jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) ≈ expected_rem @@ -838,22 +823,19 @@ end end end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "signbit" begin - for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0) - @test @jit(signbit(ConcreteRNumber(x))) == signbit(x) - end +@testset "signbit" begin + @testset "$(typeof(x))" for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0) + @test @jit(signbit(ConcreteRNumber(x))) == signbit(x) broken = + RunningOnTPU && eltype(x) == Float64 end end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "copysign" begin - for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14) - # Make sure also the return type is correct - @test Reactant.to_number( - @jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b))) - ) === copysign(a, b) - end +@testset "copysign" begin + @testset "$(typeof(a)) $(typeof(b))" for a in (-3.14, -2, 0.0, 2.71, 42), + b in (-7, -0.57, -0.0, 1, 3.14) + # Make sure also the return type is correct + @test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))) ≈ + copysign(a, b) broken = RunningOnTPU && eltype(b) == Float64 end end @@ -949,13 +931,11 @@ end ra[:a] ≈ (2.7 * 2) * ones(4) end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "@code_xla" begin - x_ra = Reactant.to_rarray(ones(4)) - hlo = repr(@code_xla(sin.(x_ra))) - @test contains(hlo, "HloModule") - @test contains(hlo, "sine") - end +@testset "@code_xla" begin + x_ra = Reactant.to_rarray(ones(Float32, 4)) + hlo = repr(@code_xla(sin.(x_ra))) + @test contains(hlo, "HloModule") + @test contains(hlo, "sine") end @testset "Raise keyword" begin @@ -999,14 +979,12 @@ end @test Array(x) ≈ Array(y) ./ 2 end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "Hlo Cost Analysis" begin - x_ra = Reactant.to_rarray(rand(4, 4)) - mul_comp = @compile x_ra * x_ra - cost = Reactant.XLA.cost_analysis(mul_comp) - - @test cost isa Reactant.XLA.HloCostAnalysisProperties - end +@testset "HLO Cost Analysis" begin + x_ra = Reactant.to_rarray(rand(4, 4)) + mul_comp = @compile x_ra * x_ra + @test begin + Reactant.XLA.cost_analysis(mul_comp) isa Reactant.XLA.HloCostAnalysisProperties + end broken = RunningOnTPU end function fractional_idx(times, t) @@ -1140,32 +1118,30 @@ end end end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "Dump MLIR modules" begin - always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] - dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[] +@testset "Dump MLIR modules" begin + always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] + dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[] - mktempdir() do dir - Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true - Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir - @compile sin.(Reactant.to_rarray(Float32[1.0])) - for mod in readdir(dir; join=true) - @test contains(read(mod, String), "hlo.sine") - end - end - - mktempdir() do dir - Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false - Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir - @compile exp.(Reactant.to_rarray(Float32[1.0])) - # Make sure we don't save anything to file when compilation is - # successful and `DUMP_MLIR_ALWAYS=false`. - @test isempty(readdir(dir; join=true)) + mktempdir() do dir + Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true + Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir + @compile sin.(Reactant.to_rarray(Float32[1.0])) + for mod in readdir(dir; join=true) + @test contains(read(mod, String), "hlo.sine") end + end - Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old - Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old + mktempdir() do dir + Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false + Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir + @compile exp.(Reactant.to_rarray(Float32[1.0])) + # Make sure we don't save anything to file when compilation is + # successful and `DUMP_MLIR_ALWAYS=false`. + @test isempty(readdir(dir; join=true)) end + + Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old + Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old end @testset "Allocator Stats" begin @@ -1291,40 +1267,38 @@ accum_fn(x, y) = abs2(x) + abs2(y) end ≈ cumprod(b; dims=3) end - if !contains(string(Reactant.devices()[1]), "TPU") - @testset "accumulate" begin - @test @jit(accumulate(accum_fn, a_ra; init=0.0f0)) ≈ - accumulate(accum_fn, a; init=0.0f0) - - @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1)) ≈ - accumulate(accum_fn, b; dims=1, init=0.0f0) - @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2)) ≈ - accumulate(accum_fn, b; dims=2, init=0.0f0) - @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3)) ≈ - accumulate(accum_fn, b; dims=3, init=0.0f0) - - @test begin - z = similar(a_ra) - @jit(accumulate!(accum_fn, z, a_ra; init=0.0f0)) - z - end ≈ accumulate(accum_fn, a; init=0.0f0) - - @test begin - z = similar(b_ra) - @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1)) - z - end ≈ accumulate(accum_fn, b; dims=1, init=0.0f0) - @test begin - z = similar(b_ra) - @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2)) - z - end ≈ accumulate(accum_fn, b; dims=2, init=0.0f0) - @test begin - z = similar(b_ra) - @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3)) - z - end ≈ accumulate(accum_fn, b; dims=3, init=0.0f0) - end + @testset "accumulate" begin + @test @jit(accumulate(accum_fn, a_ra; init=0.0f0)) ≈ + accumulate(accum_fn, a; init=0.0f0) broken = RunningOnTPU + + @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1)) ≈ + accumulate(accum_fn, b; dims=1, init=0.0f0) broken = RunningOnTPU + @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2)) ≈ + accumulate(accum_fn, b; dims=2, init=0.0f0) broken = RunningOnTPU + @test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3)) ≈ + accumulate(accum_fn, b; dims=3, init=0.0f0) broken = RunningOnTPU + + @test begin + z = similar(a_ra) + @jit(accumulate!(accum_fn, z, a_ra; init=0.0f0)) + z + end ≈ accumulate(accum_fn, a; init=0.0f0) broken = RunningOnTPU + + @test begin + z = similar(b_ra) + @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1)) + z + end ≈ accumulate(accum_fn, b; dims=1, init=0.0f0) broken = RunningOnTPU + @test begin + z = similar(b_ra) + @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2)) + z + end ≈ accumulate(accum_fn, b; dims=2, init=0.0f0) broken = RunningOnTPU + @test begin + z = similar(b_ra) + @jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3)) + z + end ≈ accumulate(accum_fn, b; dims=3, init=0.0f0) broken = RunningOnTPU end end diff --git a/test/complex.jl b/test/complex.jl index 1430306af4..c7a1b5ba68 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -1,122 +1,84 @@ using Test using Reactant -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "conj" begin - @testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im) - x_concrete = Reactant.to_rarray(x) - @test only(@jit(conj(x_concrete))) == conj(x) - end +const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU") - @testset "$(typeof(x))" for x in ( - fill(1.0 + 2.0im), - fill(1.0), - [1.0 + 2.0im; 3.0 + 4.0im], - [1.0; 3.0], - [1.0 + 2.0im 3.0 + 4.0im], - [1.0 2.0], - [1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im], - [1.0 3.0; 5.0 7.0], - ) - x_concrete = Reactant.to_rarray(x) - @test @jit(conj(x_concrete)) == conj(x) - end - end +@testset "Complex runtime: $CT" for CT in (ComplexF32, ComplexF64) + @test begin + a = Reactant.to_rarray(ones(CT, 2)) + b = Reactant.to_rarray(ones(CT, 2)) + c = Reactant.compile(+, (a, b))(a, b) + c == ones(CT, 2) + ones(CT, 2) + end skip = CT == ComplexF64 && RunningOnTPU +end - @testset "conj!" begin - @testset "$(typeof(x))" for x in ( - fill(1.0 + 2.0im), - fill(1.0), - [1.0 + 2.0im; 3.0 + 4.0im], - [1.0; 3.0], - [1.0 + 2.0im 3.0 + 4.0im], - [1.0 2.0], - [1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im], - [1.0 3.0; 5.0 7.0], - ) - x_concrete = Reactant.to_rarray(x) - @test @jit(conj!(x_concrete)) == conj(x) - @test x_concrete == conj(x) - end - end +const SCALAR_LIST = (1.0, 1.0 + 2.0im) - @testset "real" begin - @testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im) - x_concrete = Reactant.to_rarray(x) - @test only(@jit(real(x_concrete))) == real(x) - end +const ARRAY_LIST = ( + fill(1.0 + 2.0im), + fill(1.0), + [1.0 + 2.0im; 3.0 + 4.0im], + [1.0; 3.0], + [1.0 + 2.0im 3.0 + 4.0im], + [1.0 2.0], + [1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im], + [1.0 3.0; 5.0 7.0], +) - @testset "$(typeof(x))" for x in ( - fill(1.0 + 2.0im), - fill(1.0), - [1.0 + 2.0im; 3.0 + 4.0im], - [1.0; 3.0], - [1.0 + 2.0im 3.0 + 4.0im], - [1.0 2.0], - [1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im], - [1.0 3.0; 5.0 7.0], - ) - x_concrete = Reactant.to_rarray(x) - @test @jit(real(x_concrete)) == real(x) +@testset "$(string(fn))" for fn in (conj, conj!, real, imag) + if !endswith(string(fn), "!") + @testset "$(typeof(x))" for x in SCALAR_LIST + @test begin + x_concrete = Reactant.to_rarray(x) + only(@jit(fn(x_concrete))) == fn(x) + end skip = RunningOnTPU && eltype(x) == ComplexF64 end end - @testset "imag" begin - @testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im) + @testset "$(typeof(x))" for x in ARRAY_LIST + @test begin x_concrete = Reactant.to_rarray(x) - @test only(@jit(imag(x_concrete))) == imag(x) - end - - @testset "$(typeof(x))" for x in ( - fill(1.0 + 2.0im), - fill(1.0), - [1.0 + 2.0im; 3.0 + 4.0im], - [1.0; 3.0], - [1.0 + 2.0im 3.0 + 4.0im], - [1.0 2.0], - [1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im], - [1.0 3.0; 5.0 7.0], - ) - x_concrete = Reactant.to_rarray(x) - @test @jit(imag(x_concrete)) == imag(x) - end + @jit(fn(x_concrete)) == fn(x) + end skip = RunningOnTPU && eltype(x) == ComplexF64 end +end - @testset "abs: $T" for T in (Float32, ComplexF32) - x = randn(T, 10) - x_concrete = Reactant.to_rarray(x) - @test @jit(abs.(x_concrete)) ≈ abs.(x) - end +@testset "abs: $T" for T in (Float32, ComplexF32) + x = randn(T, 10) + x_concrete = Reactant.to_rarray(x) + @test @jit(abs.(x_concrete)) ≈ abs.(x) +end - @testset "promote_to Complex" begin - x = 1.0 + 2.0im - y = ConcreteRNumber(x) +@testset "promote_to Complex" begin + x = ComplexF32(1.0 + 2.0im) + y = ConcreteRNumber(x) - f = Reactant.compile((y,)) do z - z + Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{ComplexF64}, 1.0 - 3.0im) - end - - @test isapprox(f(y), 2.0 - 1.0im) + f = Reactant.compile((y,)) do z + z + Reactant.TracedUtils.promote_to( + Reactant.TracedRNumber{ComplexF32}, ComplexF32(1.0 - 3.0im) + ) end - @testset "complex reduction" begin - x = randn(ComplexF32, 10, 10) - x_ra = Reactant.to_rarray(x) - @test @jit(sum(abs2, x_ra)) ≈ sum(abs2, x) - end + @test isapprox(f(y), ComplexF32(2.0 - 1.0im)) +end - @testset "create complex numbers" begin - x = randn(ComplexF32) - x_ra = Reactant.to_rarray(x; track_numbers=true) - @test @jit(Complex(x_ra)) == x_ra +@testset "complex reduction" begin + x = randn(ComplexF32, 10, 10) + x_ra = Reactant.to_rarray(x) + @test @jit(sum(abs2, x_ra)) ≈ sum(abs2, x) +end - x = randn(Float32) - y = randn(Float64) - x_ra = Reactant.to_rarray(x; track_numbers=true) - y_ra = Reactant.to_rarray(y; track_numbers=true) - @test @jit(Complex(x_ra, y_ra)) == Complex(x, y) - @test @jit(Complex(x_ra, y)) == Complex(x, y) - @test @jit(Complex(x, y_ra)) == Complex(x, y) - @test @jit(Complex(x_ra)) == Complex(x) == @jit(Complex(x_ra, 0)) - end +@testset "create complex numbers" begin + x = randn(ComplexF32) + x_ra = Reactant.to_rarray(x; track_numbers=true) + @test @jit(Complex(x_ra)) == x_ra + + x = randn(Float32) + y = randn(Float64) + x_ra = Reactant.to_rarray(x; track_numbers=true) + y_ra = Reactant.to_rarray(y; track_numbers=true) + @test @jit(Complex(x_ra, y_ra)) == Complex(x, y) skip = RunningOnTPU + @test @jit(Complex(x_ra, y)) == Complex(x, y) skip = RunningOnTPU + @test @jit(Complex(x, y_ra)) == Complex(x, y) skip = RunningOnTPU + @test @jit(Complex(x_ra)) == Complex(x) == @jit(Complex(x_ra, 0)) end diff --git a/test/indexing.jl b/test/indexing.jl index 34106b4070..9e4f24b4d5 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -288,22 +288,20 @@ function issue_617(outf, fr, pr, I) return outf end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "issue #617" begin - N, M = 4, 6 +@testset "issue #617" begin + N, M = 4, 6 - f = rand(ComplexF64, N, N) - p = rand(ComplexF64, N * N) - I = 1:(N^2) - out = rand(ComplexF64, M, M) + f = rand(ComplexF32, N, N) + p = rand(ComplexF32, N * N) + I = 1:(N^2) + out = rand(ComplexF32, M, M) - fr = Reactant.to_rarray(f) - pr = Reactant.to_rarray(p) - outr = Reactant.to_rarray(out) - Ir = Reactant.to_rarray(I) + fr = Reactant.to_rarray(f) + pr = Reactant.to_rarray(p) + outr = Reactant.to_rarray(out) + Ir = Reactant.to_rarray(I) - @test @jit(issue_617(outr, fr, pr, Ir)) ≈ issue_617(out, f, p, I) - end + @test @jit(issue_617(outr, fr, pr, Ir)) ≈ issue_617(out, f, p, I) end function scalar_setindex(x, idx, val) diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl index cd3f4ba14b..a7848065ca 100644 --- a/test/integration/cuda.jl +++ b/test/integration/cuda.jl @@ -200,8 +200,10 @@ end oA = collect(Float64, 1:1:64) A = Reactant.to_rarray(oA) B = ConcreteRNumber(3.1) - @jit searchsorted!(A, B) - @test all(Array(A) .≈ 311) + @test begin + @jit searchsorted!(A, B) + all(Array(A) .≈ 311) + end broken = contains(string(Reactant.devices()[1]), "TPU") end function convert_mul_kernel!(Gu, w::FT) where {FT} diff --git a/test/integration/fft.jl b/test/integration/fft.jl index 96e11126ae..eafbef4c41 100644 --- a/test/integration/fft.jl +++ b/test/integration/fft.jl @@ -36,12 +36,12 @@ using FFTW, Reactant, Test end @testset "rfft" begin - x = rand(2, 2, 3, 4) + x = rand(Float32, 2, 2, 3, 4) x_ra = Reactant.to_rarray(x) @test_throws AssertionError @jit(rfft(x_ra)) # TODO: support this - x = rand(2, 3, 4) + x = rand(Float32, 2, 3, 4) x_ra = Reactant.to_rarray(x) @test @jit(rfft(x_ra)) ≈ rfft(x) diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 49e1d1d76d..9dcf0be903 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -1,5 +1,7 @@ using LinearAlgebra, Reactant, Test +const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU") + function muladd2(A, x, b) C = similar(A, promote_type(eltype(A), eltype(b)), size(A, 1), size(x, 2)) mul!(C, A, x) @@ -75,8 +77,10 @@ end @test @jit(muladd_5arg(A_ra, x_ra, b_ra)) ≈ muladd2(A, x, b) C_ra = similar(A_ra, Float32, size(A, 1), size(x, 2)) + C = similar(A, Float32, size(A, 1), size(x, 2)) @jit(mul!(C_ra, A_ra, x_ra)) - @test C_ra ≈ A * x + mul!(C, A, x) + @test C_ra ≈ C atol = 1e-3 rtol = 1e-2 end @testset "triu & tril" begin @@ -172,7 +176,7 @@ mul_symmetric(x) = Symmetric(x) * x end @testset "kron" begin - @testset for T in (Int64, Float64, ComplexF64) + @testset for T in (Int64, Float64, ComplexF32) @testset for (x_sz, y_sz) in [ ((3, 4), (2, 5)), ((3, 4), (2,)), ((3,), (2, 5)), ((3,), (5,)), ((10,), ()) ] @@ -315,6 +319,8 @@ end fn6(A, B) = B / transpose(A) @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + T == ComplexF64 && RunningOnTPU && continue + A = rand(T, 6, 6) B = rand(T, 6, 6) b = rand(T, 6) @@ -366,6 +372,8 @@ end @testset "LU Factorization" begin @testset "Un-batched" begin @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + (T == ComplexF64 || T == Float64) && RunningOnTPU && continue + A = rand(T, 4, 4) A_ra = Reactant.to_rarray(A) @@ -375,13 +383,17 @@ end B = rand(T, 4, 3) B_ra = Reactant.to_rarray(B) - @test @jit(solve_with_lu(A_ra, b_ra)) ≈ solve_with_lu(A, b) - @test @jit(solve_with_lu(A_ra, B_ra)) ≈ solve_with_lu(A, B) + @test @jit(solve_with_lu(A_ra, b_ra)) ≈ solve_with_lu(A, b) atol = 1e-4 rtol = + 1e-2 + @test @jit(solve_with_lu(A_ra, B_ra)) ≈ solve_with_lu(A, B) atol = 1e-4 rtol = + 1e-2 end end @testset "Batched" begin @testset for T in (Float32, Float64, ComplexF32, ComplexF64) + (T == ComplexF64 || T == Float64) && RunningOnTPU && continue + A = rand(T, 4, 4, 3, 2) A_ra = Reactant.to_rarray(A) @@ -391,8 +403,10 @@ end B = rand(T, 4, 5, 3, 2) B_ra = Reactant.to_rarray(B) - @test @jit(solve_with_lu(A_ra, b_ra)) ≈ solve_with_lu_batched(A, b) - @test @jit(solve_with_lu(A_ra, B_ra)) ≈ solve_with_lu_batched(A, B) + @test @jit(solve_with_lu(A_ra, b_ra)) ≈ solve_with_lu_batched(A, b) atol = 1e-4 rtol = + 1e-2 + @test @jit(solve_with_lu(A_ra, B_ra)) ≈ solve_with_lu_batched(A, B) atol = 1e-4 rtol = + 1e-2 end end @@ -402,6 +416,7 @@ end A_ra = Reactant.to_rarray(A) B_ra = Reactant.to_rarray(B) - @test @jit(solve_with_lu(A_ra, B_ra)) ≈ solve_with_lu_batched(A, B) + @test @jit(solve_with_lu(A_ra, B_ra)) ≈ solve_with_lu_batched(A, B) atol = 1e-4 rtol = + 1e-2 end end diff --git a/test/integration/special_functions.jl b/test/integration/special_functions.jl index c63f80b48d..d1c64f5533 100644 --- a/test/integration/special_functions.jl +++ b/test/integration/special_functions.jl @@ -1,5 +1,7 @@ using SpecialFunctions, Reactant +const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU") + macro ≈(a, b) return quote isapprox($a, $b; atol=1e-14) @@ -7,82 +9,94 @@ macro ≈(a, b) end @testset "gamma" begin - @test SpecialFunctions.gamma(0.5) ≈ @jit(SpecialFunctions.gamma(ConcreteRNumber(0.5))) - @test SpecialFunctions.gamma(2) ≈ @jit(SpecialFunctions.gamma(ConcreteRNumber(2))) + @test SpecialFunctions.gamma(0.5) ≈ @jit(SpecialFunctions.gamma(ConcreteRNumber(0.5))) atol = + 1e-5 rtol = 1e-3 + @test SpecialFunctions.gamma(Int32(2)) ≈ + @jit(SpecialFunctions.gamma(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3 end @testset "loggamma" begin @test SpecialFunctions.loggamma(0.5) ≈ - @jit(SpecialFunctions.loggamma(ConcreteRNumber(0.5))) - @test abs(SpecialFunctions.loggamma(2)) < 1e-10 - @test abs(@jit(SpecialFunctions.loggamma(ConcreteRNumber(2)))) < 1e-10 + @jit(SpecialFunctions.loggamma(ConcreteRNumber(0.5))) atol = 1e-5 rtol = 1e-3 + @test SpecialFunctions.loggamma(Int32(2)) ≈ + @jit(SpecialFunctions.loggamma(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3 end @testset "digamma" begin @test SpecialFunctions.digamma(0.5) ≈ @jit(SpecialFunctions.digamma(ConcreteRNumber(0.5))) - @test SpecialFunctions.digamma(2) ≈ @jit(SpecialFunctions.digamma(ConcreteRNumber(2))) + @test SpecialFunctions.digamma(Int32(2)) ≈ + @jit(SpecialFunctions.digamma(ConcreteRNumber(Int32(2)))) end @testset "trigamma" begin @test SpecialFunctions.trigamma(0.5) ≈ @jit(SpecialFunctions.trigamma(ConcreteRNumber(0.5))) - @test SpecialFunctions.trigamma(2) ≈ @jit(SpecialFunctions.trigamma(ConcreteRNumber(2))) + @test SpecialFunctions.trigamma(Int32(2)) ≈ + @jit(SpecialFunctions.trigamma(ConcreteRNumber(Int32(2)))) end @testset "beta" begin @test SpecialFunctions.beta(0.5, 0.6) ≈ @jit(SpecialFunctions.beta(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) - @test SpecialFunctions.beta(2, 4) ≈ - @jit(SpecialFunctions.beta(ConcreteRNumber(2), ConcreteRNumber(4))) + @test SpecialFunctions.beta(Int32(2), Int32(4)) ≈ + @jit(SpecialFunctions.beta(ConcreteRNumber(Int32(2)), ConcreteRNumber(Int32(4)))) end @testset "logbeta" begin @test SpecialFunctions.logbeta(0.5, 0.6) ≈ @jit(SpecialFunctions.logbeta(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) - @test SpecialFunctions.logbeta(2, 4) ≈ - @jit(SpecialFunctions.logbeta(ConcreteRNumber(2), ConcreteRNumber(4))) + @test SpecialFunctions.logbeta(Int32(2), Int32(4)) ≈ @jit( + SpecialFunctions.logbeta(ConcreteRNumber(Int32(2)), ConcreteRNumber(Int32(4))) + ) end @testset "erf" begin @test SpecialFunctions.erf(0.5) ≈ @jit(SpecialFunctions.erf(ConcreteRNumber(0.5))) - @test SpecialFunctions.erf(2) ≈ @jit(SpecialFunctions.erf(ConcreteRNumber(2))) + @test SpecialFunctions.erf(Int32(2)) ≈ + @jit(SpecialFunctions.erf(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3 end @testset "erf with 2 arguments" begin @test SpecialFunctions.erf(0.5, 0.6) ≈ @jit(SpecialFunctions.erf(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) - @test SpecialFunctions.erf(2, 4) ≈ - @jit(SpecialFunctions.erf(ConcreteRNumber(2), ConcreteRNumber(4))) + @test SpecialFunctions.erf(Int32(2), Int32(4)) ≈ + @jit(SpecialFunctions.erf(ConcreteRNumber(Int32(2)), ConcreteRNumber(Int32(4)))) atol = + 1e-5 rtol = 1e-3 end @testset "erfc" begin @test SpecialFunctions.erfc(0.5) ≈ @jit(SpecialFunctions.erfc(ConcreteRNumber(0.5))) - @test SpecialFunctions.erfc(2) ≈ @jit(SpecialFunctions.erfc(ConcreteRNumber(2))) + @test SpecialFunctions.erfc(Int32(2)) ≈ + @jit(SpecialFunctions.erfc(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3 end @testset "logerf" begin @test SpecialFunctions.logerf(0.5, 0.6) ≈ @jit(SpecialFunctions.logerf(ConcreteRNumber(0.5), ConcreteRNumber(0.6))) - @test SpecialFunctions.logerf(2, 4) ≈ - @jit(SpecialFunctions.logerf(ConcreteRNumber(2), ConcreteRNumber(4))) + @test SpecialFunctions.logerf(Int32(2), Int32(4)) ≈ @jit( + SpecialFunctions.logerf(ConcreteRNumber(Int32(2)), ConcreteRNumber(Int32(4))) + ) atol = 1e-5 rtol = 1e-3 end @testset "erfcx" begin @test SpecialFunctions.erfcx(0.5) ≈ @jit(SpecialFunctions.erfcx(ConcreteRNumber(0.5))) - @test SpecialFunctions.erfcx(2) ≈ @jit(SpecialFunctions.erfcx(ConcreteRNumber(2))) + @test SpecialFunctions.erfcx(Int32(2)) ≈ + @jit(SpecialFunctions.erfcx(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3 end @testset "logerfc" begin @test SpecialFunctions.logerfc(0.5) ≈ @jit(SpecialFunctions.logerfc(ConcreteRNumber(0.5))) - @test SpecialFunctions.logerfc(2) ≈ @jit(SpecialFunctions.logerfc(ConcreteRNumber(2))) + @test SpecialFunctions.logerfc(Int32(2)) ≈ + @jit(SpecialFunctions.logerfc(ConcreteRNumber(Int32(2)))) end @testset "logerfcx" begin @test SpecialFunctions.logerfcx(0.5) ≈ @jit(SpecialFunctions.logerfcx(ConcreteRNumber(0.5))) - @test SpecialFunctions.logerfcx(2) ≈ @jit(SpecialFunctions.logerfcx(ConcreteRNumber(2))) + @test SpecialFunctions.logerfcx(Int32(2)) ≈ + @jit(SpecialFunctions.logerfcx(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3 end @testset "loggamma1p" begin @@ -91,8 +105,10 @@ end end @testset "loggammadiv" begin - @test SpecialFunctions.loggammadiv(150, 20) ≈ - @jit SpecialFunctions.loggammadiv(ConcreteRNumber(150), ConcreteRNumber(20)) + @test SpecialFunctions.loggammadiv(Int32(150), Int32(20)) ≈ + @jit SpecialFunctions.loggammadiv( + ConcreteRNumber(Int32(150)), ConcreteRNumber(Int32(20)) + ) end @testset "zeta" begin diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 28dc94ff39..a4c35b4a7d 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -123,29 +123,34 @@ end end @testset "Batched Matrix Multiplication" begin - x = rand(Float32, 4, 3, 5) - y = rand(Float32, 3, 2, 5) + Reactant.with_config(; + convolution_precision=PrecisionConfig.HIGHEST, + dot_general_precision=PrecisionConfig.HIGHEST, + ) do + x = rand(Float32, 4, 3, 5) + y = rand(Float32, 3, 2, 5) - x_ra = Reactant.to_rarray(x) - y_ra = Reactant.to_rarray(y) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) - @test @jit(batched_mul(x_ra, y_ra)) ≈ batched_mul(x, y) + @test @jit(batched_mul(x_ra, y_ra)) ≈ batched_mul(x, y) - x = rand(Float32, 4, 3, 1) - y = rand(Float32, 3, 2, 5) + x = rand(Float32, 4, 3, 1) + y = rand(Float32, 3, 2, 5) - x_ra = Reactant.to_rarray(x) - y_ra = Reactant.to_rarray(y) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) - @test @jit(batched_mul(x_ra, y_ra)) ≈ batched_mul(x, y) + @test @jit(batched_mul(x_ra, y_ra)) ≈ batched_mul(x, y) - x = rand(Float32, 4, 3, 5) - y = rand(Float32, 3, 2, 1) + x = rand(Float32, 4, 3, 5) + y = rand(Float32, 3, 2, 1) - x_ra = Reactant.to_rarray(x) - y_ra = Reactant.to_rarray(y) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) - @test @jit(batched_mul(x_ra, y_ra)) ≈ batched_mul(x, y) + @test @jit(batched_mul(x_ra, y_ra)) ≈ batched_mul(x, y) + end end @testset "Constant Padding: NNlib.pad_constant" begin @@ -649,16 +654,18 @@ end ) in Iterators.product( (0, 2), (1, 2), (1,), (1,) ) - conv_dims = NNlib.DenseConvDims(x, w; padding, stride, dilation, groups) + Reactant.with_config(; convolution_precision=PrecisionConfig.HIGHEST) do + conv_dims = NNlib.DenseConvDims(x, w; padding, stride, dilation, groups) - output_size = (NNlib.output_size(conv_dims)..., n_out_features, batch_size) - dy = randn(Float32, output_size) - dy_reactant = Reactant.to_rarray(dy) + output_size = (NNlib.output_size(conv_dims)..., n_out_features, batch_size) + dy = randn(Float32, output_size) + dy_reactant = Reactant.to_rarray(dy) - @test @jit(NNlib.∇conv_data(dy_reactant, w_reactant, conv_dims)) ≈ - NNlib.∇conv_data(dy, w, conv_dims) - @test @jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims)) ≈ - NNlib.∇conv_filter(x, dy, conv_dims) + @test @jit(NNlib.∇conv_data(dy_reactant, w_reactant, conv_dims)) ≈ + NNlib.∇conv_data(dy, w, conv_dims) + @test @jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims)) ≈ + NNlib.∇conv_filter(x, dy, conv_dims) + end end end @@ -704,12 +711,12 @@ end end @testset "Pixel shuffle" begin - x = [10i + j + channel / 10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1] + x = Int32[10i + j + channel / 10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1] x_ra = Reactant.to_rarray(x) @test @jit(NNlib.pixel_shuffle(x_ra, 2)) ≈ NNlib.pixel_shuffle(x, 2) - y = [i + channel / 10 for i in 1:3, channel in 1:6, batch in 1:1] + y = Int32[i + channel / 10 for i in 1:3, channel in 1:6, batch in 1:1] y_ra = Reactant.to_rarray(y) @test @jit(NNlib.pixel_shuffle(y_ra, 2)) ≈ NNlib.pixel_shuffle(y, 2) diff --git a/test/ops.jl b/test/ops.jl index 5e871254f0..6be96585d8 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -3,6 +3,9 @@ using Reactant: Ops using LinearAlgebra using SpecialFunctions: SpecialFunctions +const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU") +const RunningOnAppleX86 = Sys.isapple() && Sys.ARCH === :x86_64 + @testset "abs" begin x = Reactant.to_rarray([1, -1]) @test [1, 1] ≈ @jit Ops.abs(x) @@ -10,16 +13,14 @@ using SpecialFunctions: SpecialFunctions x = Reactant.to_rarray([1.0, -1.0]) @test [1.0, 1.0] ≈ @jit Ops.abs(x) - if !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([ - 3.0+4im -3.0+4im - 3.0-4im -3.0-4im - ]) - @test [ - 5.0 5.0 - 5.0 5.0 - ] ≈ @jit Ops.abs(x) - end + x = Reactant.to_rarray(ComplexF32[ + 3.0+4im -3.0+4im + 3.0-4im -3.0-4im + ]) + @test [ + 5.0 5.0 + 5.0 5.0 + ] ≈ @jit(Ops.abs(x)) end @testset "add" begin @@ -35,13 +36,11 @@ end b = Reactant.to_rarray([5.5, 6.6, -7.7, -8.8]) @test Array(a) .+ Array(b) ≈ @jit Ops.add(a, b) - if !contains(string(Reactant.devices()[1]), "TPU") - a = Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) - b = Reactant.to_rarray([ - 9.9 + 10.10im, 11.11 + 12.12im, -13.13 + -14.14im, -15.15 + -16.16im - ]) - @test Array(a) .+ Array(b) ≈ @jit Ops.add(a, b) - end + a = Reactant.to_rarray(ComplexF32[1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) + b = Reactant.to_rarray( + ComplexF32[9.9 + 10.10im, 11.11 + 12.12im, -13.13 + -14.14im, -15.15 + -16.16im] + ) + @test Array(a) .+ Array(b) ≈ @jit(Ops.add(a, b)) end @testset "after_all" begin @@ -99,18 +98,15 @@ end @test cholesky(Array(x)).U ≈ @jit g1(x) @test transpose(cholesky(Array(x)).U) ≈ @jit g2(x) - if !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray( - [ - 10.0+0.0im 2.0-3.0im 3.0-4.0im - 2.0+3.0im 5.0+0.0im 3.0-2.0im - 3.0+4.0im 3.0+2.0im 9.0+0.0im - ], - ) - - @test cholesky(Array(x)).U ≈ @jit g1(x) - @test adjoint(cholesky(Array(x)).U) ≈ @jit g2(x) - end + x = Reactant.to_rarray( + ComplexF32[ + 10.0+0.0im 2.0-3.0im 3.0-4.0im + 2.0+3.0im 5.0+0.0im 3.0-2.0im + 3.0+4.0im 3.0+2.0im 9.0+0.0im + ], + ) + @test cholesky(Array(x)).U ≈ @jit(g1(x)) + @test adjoint(cholesky(Array(x)).U) ≈ @jit(g2(x)) end @testset "clamp" begin @@ -146,16 +142,15 @@ end end end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "complex" begin - x = Reactant.to_rarray(1.1; track_numbers=true) - y = Reactant.to_rarray(2.2; track_numbers=true) - @test 1.1 + 2.2im ≈ @jit Ops.complex(x, y) +@testset "complex" begin + x = Reactant.to_rarray(1.1f0; track_numbers=true) + y = Reactant.to_rarray(2.2f0; track_numbers=true) + @test ComplexF32(1.1 + 2.2im) ≈ @jit(Ops.complex(x, y)) - x = Reactant.to_rarray([1.1, 2.2, 3.3, 4.4]) - y = Reactant.to_rarray([5.5, 6.6, -7.7, -8.8]) - @test [1.1 + 5.5im, 2.2 + 6.6im, 3.3 - 7.7im, 4.4 - 8.8im] ≈ @jit Ops.complex(x, y) - end + x = Reactant.to_rarray([1.1f0, 2.2f0, 3.3f0, 4.4f0]) + y = Reactant.to_rarray([5.5f0, 6.6f0, -7.7f0, -8.8f0]) + @test ComplexF32[1.1 + 5.5im, 2.2 + 6.6im, 3.3 - 7.7im, 4.4 - 8.8im] ≈ + @jit(Ops.complex(x, y)) end @testset "constant" begin @@ -172,21 +167,22 @@ end @testset "cosine" begin # it crashes in apple x86_64 and it's a deprecated platform so we don't need to care a lot... - if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = Reactant.to_rarray([0, π / 2, π, 3π / 2, 2π]) - @test [1, 0, -1, 0, 1] ≈ @jit Ops.cosine(x) - - x = Reactant.to_rarray([0.0, π / 2, π, 3π / 2, 2π]) - @test [1.0, 0.0, -1.0, 0.0, 1.0] ≈ @jit Ops.cosine(x) - - if !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([ - 0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im - ]) - @test [1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im, 1.0 + 0.0im] ≈ - @jit Ops.cosine(x) - end - end + x = Reactant.to_rarray([0, π / 2, π, 3π / 2, 2π]) + @test [1, 0, -1, 0, 1] ≈ @jit(Ops.cosine(x)) broken = RunningOnAppleX86 + + x = Reactant.to_rarray([0.0, π / 2, π, 3π / 2, 2π]) + @test [1.0, 0.0, -1.0, 0.0, 1.0] ≈ @jit(Ops.cosine(x)) broken = RunningOnAppleX86 + + @test ComplexF32[1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im, 1.0 + 0.0im] ≈ + @jit( + Ops.cosine( + Reactant.to_rarray( + ComplexF32[ + 0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im + ], + ), + ) + ) skip = RunningOnAppleX86 end @testset "count_leading_zeros" begin @@ -228,13 +224,10 @@ end ) for (a, b) in [ - ([1, 2, 3, 4], [5, 6, -7, -8]), + (Int32[1, 2, 3, 4], Int32[5, 6, -7, -8]), ([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, -7.0, -8.0]), - ([1.0, 2.0im, 3.0, 4.0im], [5.0, 6.0im, -7.0im, -8.0]), + (ComplexF32[1.0, 2.0im, 3.0, 4.0im], ComplexF32[5.0, 6.0im, -7.0im, -8.0]), ] - if contains(string(Reactant.devices()[1]), "TPU") - continue - end a = Reactant.to_rarray(a) b = Reactant.to_rarray(b) # NOTE `LinearAlgebra.dot` is not equal to `sum(a .* b)` on complex numbers due to conjugation @@ -244,58 +237,49 @@ end @test a .* b ≈ @jit fouter_batch1(a, b) end - if !contains(string(Reactant.devices()[1]), "TPU") - a = Reactant.to_rarray([1 2; 3 4]) - b = Reactant.to_rarray([5 6; -7 -8]) - @test Array(a)' * Array(b) == @jit f1(a, b) - end + a = Reactant.to_rarray(Int32[1 2; 3 4]) + b = Reactant.to_rarray(Int32[5 6; -7 -8]) + @test Array(a)' * Array(b) == @jit(f1(a, b)) end @testset "exponential" begin - x = Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]) - @test exp.(Array(x)) ≈ @jit Ops.exponential(x) + x = [1.0, 2.0, 3.0, 4.0] + @test exp.(x) ≈ @jit Ops.exponential(Reactant.to_rarray(x)) - if !(Sys.isapple() && Sys.ARCH === :x86_64) && - !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([1.0 + 2.0im, 3.0 + 4.0im, 5.0 + 6.0im, 7.0 + 8.0im]) - @test exp.(Array(x)) ≈ @jit Ops.exponential(x) - end + x = ComplexF32[1.0 + 2.0im, 3.0 + 4.0im, 5.0 + 6.0im, 7.0 + 8.0im] + @test exp.(x) ≈ @jit(Ops.exponential(Reactant.to_rarray(x))) skip = RunningOnAppleX86 end @testset "exponential_minus_one" begin - x = Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]) - @test expm1.(Array(x)) ≈ @jit Ops.exponential_minus_one(x) + x = [1.0, 2.0, 3.0, 4.0] + @test expm1.(x) ≈ @jit Ops.exponential_minus_one(Reactant.to_rarray(x)) - if !(Sys.isapple() && Sys.ARCH === :x86_64) && - !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([1.0 + 2.0im, 3.0 + 4.0im, 5.0 + 6.0im, 7.0 + 8.0im]) - @test expm1.(Array(x)) ≈ @jit Ops.exponential_minus_one(x) - end + x = ComplexF32[1.0 + 2.0im, 3.0 + 4.0im, 5.0 + 6.0im, 7.0 + 8.0im] + @test expm1.(x) ≈ @jit(Ops.exponential_minus_one(Reactant.to_rarray(x))) skip = + RunningOnAppleX86 end @testset "fft" begin grfft(x) = Ops.fft(x; type="RFFT", length=[4]) gfft(x) = Ops.fft(x; type="FFT", length=[4]) - if !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([1.0, 1.0, 1.0, 1.0]) - @test ComplexF64[4.0, 0.0, 0.0] ≈ @jit grfft(x) + x = Reactant.to_rarray(Float32[1.0, 1.0, 1.0, 1.0]) + @test ComplexF32[4.0, 0.0, 0.0] ≈ @jit(grfft(x)) - x = Reactant.to_rarray([0.0, 1.0, 0.0, -1.0]) - @test ComplexF64[0.0, -2.0im, 0.0] ≈ @jit grfft(x) + x = Reactant.to_rarray(Float32[0.0, 1.0, 0.0, -1.0]) + @test ComplexF32[0.0, -2.0im, 0.0] ≈ @jit(grfft(x)) - x = Reactant.to_rarray([1.0, -1.0, 1.0, -1.0]) - @test ComplexF64[0.0, 0.0, 4.0] ≈ @jit grfft(x) + x = Reactant.to_rarray(Float32[1.0, -1.0, 1.0, -1.0]) + @test ComplexF32[0.0, 0.0, 4.0] ≈ @jit(grfft(x)) - x = Reactant.to_rarray(ComplexF64[1.0, 1.0, 1.0, 1.0]) - @test ComplexF64[4.0, 0.0, 0.0, 0.0] ≈ @jit gfft(x) + x = Reactant.to_rarray(ComplexF32[1.0, 1.0, 1.0, 1.0]) + @test ComplexF32[4.0, 0.0, 0.0, 0.0] ≈ @jit(gfft(x)) - x = Reactant.to_rarray(ComplexF64[0.0, 1.0, 0.0, -1.0]) - @test ComplexF64[0.0, -2.0im, 0.0, 2.0im] ≈ @jit gfft(x) + x = Reactant.to_rarray(ComplexF32[0.0, 1.0, 0.0, -1.0]) + @test ComplexF32[0.0, -2.0im, 0.0, 2.0im] ≈ @jit(gfft(x)) - x = Reactant.to_rarray(ComplexF64[1.0, -1.0, 1.0, -1.0]) - @test ComplexF64[0.0, 0.0, 4.0, 0.0] ≈ @jit gfft(x) - end + x = Reactant.to_rarray(ComplexF32[1.0, -1.0, 1.0, -1.0]) + @test ComplexF32[0.0, 0.0, 4.0, 0.0] ≈ @jit(gfft(x)) # TODO test with complex numbers and inverse FFT end @@ -317,11 +301,9 @@ end end end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "imag" begin - x = Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) - @test [2.2, 4.4, 6.6, 8.8] ≈ @jit Ops.imag(x) - end +@testset "imag" begin + x = Reactant.to_rarray(ComplexF32[1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) + @test Float32[2.2, 4.4, 6.6, 8.8] ≈ @jit(Ops.imag(x)) end @testset "iota" begin @@ -351,20 +333,16 @@ end x = Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]) @test log.(Array(x)) ≈ @jit Ops.log(x) - if !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([1.0 + 0.0im, 2.0 + 0.0im, -3.0 + 0.0im, -4.0 + 0.0im]) - @test log.(Array(x)) ≈ @jit Ops.log(x) - end + x = Reactant.to_rarray(ComplexF32[1.0 + 0.0im, 2.0 + 0.0im, -3.0 + 0.0im, -4.0 + 0.0im]) + @test log.(Array(x)) ≈ @jit(Ops.log(x)) end @testset "log_plus_one" begin x = Reactant.to_rarray([1.0, 2.0, 3.0, 4.0]) @test log.(Array(x)) ≈ @jit Ops.log(x) - if !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([1.0 + 0.0im, 2.0 + 0.0im, -3.0 + 0.0im, -4.0 + 0.0im]) - @test log.(Array(x)) ≈ @jit Ops.log(x) - end + x = Reactant.to_rarray(ComplexF32[1.0 + 0.0im, 2.0 + 0.0im, -3.0 + 0.0im, -4.0 + 0.0im]) + @test log.(Array(x)) ≈ @jit(Ops.log(x)) end @testset "logistic" begin @@ -435,10 +413,8 @@ end x = Reactant.to_rarray([-1.0, 0.0, 1.0, 10.0]) @test [1.0, 0.0, -1.0, -10.0] ≈ @jit Ops.negate(x) - if !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([-1.0 + 2im, 0.0 - 3im, 1.0 + 4im, 10.0 - 5im]) - @test [1.0 - 2im, 0.0 + 3im, -1.0 - 4im, -10.0 + 5im] ≈ @jit Ops.negate(x) - end + x = Reactant.to_rarray(ComplexF32[-1.0 + 2im, 0.0 - 3im, 1.0 + 4im, 10.0 - 5im]) + @test [1.0 - 2im, 0.0 + 3im, -1.0 - 4im, -10.0 + 5im] ≈ @jit(Ops.negate(x)) end @testset "not" begin @@ -508,19 +484,14 @@ end p = Reactant.to_rarray([0, 1, 2, 3]) @test Array(x) .^ Array(p) == @jit Ops.power(x, p) - if !(Sys.isapple() && Sys.ARCH === :x86_64) && - !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([0.0 + 1.0im, 0.0 + 1.0im, 0.0 + 1.0im, 0.0 + 1.0im]) - p = Reactant.to_rarray([0.0 + 0.0im, 1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 0.0im]) - @test Array(x) .^ Array(p) ≈ @jit Ops.power(x, p) - end + x = Reactant.to_rarray(ComplexF32[0.0 + 1.0im, 0.0 + 1.0im, 0.0 + 1.0im, 0.0 + 1.0im]) + p = Reactant.to_rarray(ComplexF32[0.0 + 0.0im, 1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 0.0im]) + @test Array(x) .^ Array(p) ≈ @jit(Ops.power(x, p)) end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "real" begin - x = Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) - @test [1.1, 3.3, 5.5, 7.7] ≈ @jit Ops.real(x) - end +@testset "real" begin + x = Reactant.to_rarray(ComplexF32[1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im]) + @test [1.1, 3.3, 5.5, 7.7] ≈ @jit(Ops.real(x)) end @testset "recv" begin end @@ -560,52 +531,52 @@ end @test [2 1; 4 3] == @jit g2(x) end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "rng_bit_generator" begin - genInt32(seed) = Ops.rng_bit_generator(Int32, seed, [2, 4]) - genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4]) - genUInt64(seed) = Ops.rng_bit_generator(UInt64, seed, [2, 4]) - genFloat32(seed) = Ops.rng_bit_generator(Float32, seed, [2, 4]) - genFloat64(seed) = Ops.rng_bit_generator(Float64, seed, [2, 4]) - - @testset for (alg, sz) in - [("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)] - seed = Reactant.to_rarray(zeros(UInt64, sz)) - - res = @jit genInt32(seed) - @test res.output_state !== seed - @test size(res.output_state) == (sz,) - @test res.output isa ConcreteRArray{Int32,2} - @test size(res.output) == (2, 4) - - seed = res.output_state - res = @jit genInt64(seed) - @test res.output_state !== seed - @test size(res.output_state) == (sz,) - @test res.output isa ConcreteRArray{Int64,2} - @test size(res.output) == (2, 4) - - seed = res.output_state - res = @jit genUInt64(seed) - @test res.output_state !== seed - @test size(res.output_state) == (sz,) - @test res.output isa ConcreteRArray{UInt64,2} - @test size(res.output) == (2, 4) - - seed = res.output_state - res = @jit genFloat32(seed) - @test res.output_state !== seed - @test size(res.output_state) == (sz,) - @test res.output isa ConcreteRArray{Float32,2} - @test size(res.output) == (2, 4) - - seed = res.output_state - res = @jit genFloat64(seed) - @test res.output_state !== seed - @test size(res.output_state) == (sz,) - @test res.output isa ConcreteRArray{Float64,2} - @test size(res.output) == (2, 4) - end +@testset "rng_bit_generator" begin + @testset for (alg, sz) in + [("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)] + seed = Reactant.to_rarray(zeros(UInt64, sz)) + + res = @jit Ops.rng_bit_generator(Int32, seed, [2, 4]; algorithm=alg) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{Int32,2} + @test size(res.output) == (2, 4) + seed = res.output_state + + res = @jit Ops.rng_bit_generator(Int64, seed, [2, 4]; algorithm=alg) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{Int64,2} + @test size(res.output) == (2, 4) + seed = res.output_state + + res = @jit Ops.rng_bit_generator(UInt64, seed, [2, 4]; algorithm=alg) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{UInt64,2} + @test size(res.output) == (2, 4) + seed = res.output_state + + res = @jit Ops.rng_bit_generator(Float32, seed, [2, 4]; algorithm=alg) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{Float32,2} + @test size(res.output) == (2, 4) + seed = res.output_state + + res = @jit Ops.rng_bit_generator(Float64, seed, [2, 4]; algorithm=alg) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{Float64,2} + @test size(res.output) == (2, 4) + seed = res.output_state + + res = @jit Ops.rng_bit_generator(Float32, seed, [2, 4]; algorithm=alg) + @test res.output_state !== seed + @test size(res.output_state) == (sz,) + @test res.output isa ConcreteRArray{Float32,2} + @test size(res.output) == (2, 4) + seed = res.output_state end end @@ -623,11 +594,8 @@ end x = Reactant.to_rarray([1.0 4.0; 9.0 25.0]) @test 1 ./ sqrt.(Array(x)) ≈ @jit Ops.rsqrt(x) - if !(Sys.isapple() && Sys.ARCH === :x86_64) && - !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([1.0+1im 4.0+2im; 9.0+3im 25.0+4im]) - @test 1 ./ sqrt.(Array(x)) ≈ @jit Ops.rsqrt(x) - end + x = Reactant.to_rarray(ComplexF32[1.0+1im 4.0+2im; 9.0+3im 25.0+4im]) + @test 1 ./ sqrt.(Array(x)) ≈ @jit(Ops.rsqrt(x)) end @testset "select" begin @@ -685,37 +653,33 @@ end x = Reactant.to_rarray([Inf, -Inf, NaN, -NaN, -1.0, -0.0, +0.0, 1.0]) @test [1.0, -1.0, NaN, NaN, -1.0, -0.0, 0.0, 1.0] ≈ @jit(Ops.sign(x)) nans = true - if !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([ + x = Reactant.to_rarray( + ComplexF32[ NaN + 1.0im, 1.0 + NaN, 0.0 + 0.0im, -1.0 + 2.0im, 0.0 - 3.0im, 1.0 + 4.0im - ]) - @test [ - NaN + NaN * im, - NaN + NaN * im, - 0.0 + 0.0im, - -0.4472135954999579 + 0.8944271909999159im, - 0.0 - 1.0im, - 0.24253562503633297 + 0.9701425001453319im, - ] ≈ @jit(Ops.sign(x)) nans = true - end + ], + ) + @test ComplexF32[ + NaN + NaN * im, + NaN + NaN * im, + 0.0 + 0.0im, + -0.4472135954999579 + 0.8944271909999159im, + 0.0 - 1.0im, + 0.24253562503633297 + 0.9701425001453319im, + ] ≈ @jit(Ops.sign(x)) nans = true end @testset "sine" begin - if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = Reactant.to_rarray([0, π / 2, π, 3π / 2, 2π]) - @test [0, 1, 0, -1, 0] ≈ @jit Ops.sine(x) - - x = Reactant.to_rarray([0.0, π / 2, π, 3π / 2, 2π]) - @test [0.0, 1.0, 0.0, -1.0, 0.0] ≈ @jit Ops.sine(x) - - if !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([ - 0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im - ]) - @test [0.0 + 0.0im, 1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im] ≈ - @jit Ops.sine(x) - end - end + x = Reactant.to_rarray([0, π / 2, π, 3π / 2, 2π]) + @test [0, 1, 0, -1, 0] ≈ @jit(Ops.sine(x)) broken = RunningOnAppleX86 + + x = Reactant.to_rarray([0.0, π / 2, π, 3π / 2, 2π]) + @test [0.0, 1.0, 0.0, -1.0, 0.0] ≈ @jit(Ops.sine(x)) + + x = Reactant.to_rarray( + ComplexF32[0.0 + 0.0im, π / 2 + 0.0im, π + 0.0im, 3π / 2 + 0.0im, 2π + 0.0im] + ) + @test ComplexF32[0.0 + 0.0im, 1.0 + 0.0im, 0.0 + 0.0im, -1.0 + 0.0im, 0.0 + 0.0im] ≈ + @jit(Ops.sine(x)) broken = RunningOnAppleX86 end @testset "sort" begin @@ -741,11 +705,9 @@ end x = Reactant.to_rarray([1.0, 4.0, 9.0, 16.0]) @test [1.0, 2.0, 3.0, 4.0] ≈ @jit Ops.sqrt(x) - if !(Sys.isapple() && Sys.ARCH === :x86_64) && - !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([1.0 + 0im, 0.0 + 1im]) - @test [1.0 + 0im, 1 / √2 * (1 + im)] ≈ @jit Ops.sqrt(x) - end + x = Reactant.to_rarray(ComplexF32[1.0 + 0im, 0.0 + 1im]) + @test ComplexF32[1.0 + 0im, 1 / √2 * (1 + im)] ≈ @jit(Ops.sqrt(x)) broken = + RunningOnAppleX86 end @testset "subtract" begin @@ -767,32 +729,26 @@ end end @testset "tan" begin - if !(Sys.isapple() && Sys.ARCH === :x86_64) - # TODO tan(π/2) is Inf but it returns 1.633123935319537e16 - x = Reactant.to_rarray([0, π / 4, π / 2, 3π / 4, π]) + # TODO: tan(π/2) is not inf + x = Reactant.to_rarray([0, π / 4, π / 3, 3π / 4, π]) - if !contains(string(Reactant.devices()[1]), "TPU") - @test [0.0, 1.0, 1.633123935319537e16, -1.0, 0.0] ≈ @jit Ops.tan(x) - end + @test [0.0, 1.0, 1.73205, -1.0, 0.0] ≈ @jit(Ops.tan(x)) atol = 1e-5 rtol = 1e-3 broken = + RunningOnAppleX86 - if !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray([ - 0.0 + 0.0im, π / 4 + 0.0im, π / 2 + 0.0im, 3π / 4 + 0.0im, π + 0.0im - ]) - @test ComplexF64[0.0, 1.0, 1.633123935319537e16, -1.0, 0.0] ≈ @jit Ops.tan(x) - end - end + x = Reactant.to_rarray( + ComplexF32[0.0 + 0.0im, π / 4 + 0.0im, π / 3 + 0.0im, 3π / 4 + 0.0im, π + 0.0im] + ) + @test ComplexF32[0.0, 1.0, 1.73205, -1.0, 0.0] ≈ @jit(Ops.tan(x)) atol = 1e-5 rtol = + 1e-3 broken = RunningOnAppleX86 end @testset "tanh" begin x = Reactant.to_rarray([-1.0, 0.0, 1.0]) @test [-0.7615941559557649, 0.0, 0.7615941559557649] ≈ @jit Ops.tanh(x) - if !(Sys.isapple() && Sys.ARCH === :x86_64) && - !contains(string(Reactant.devices()[1]), "TPU") - x = Reactant.to_rarray(ComplexF64[-1.0, 0.0, 1.0]) - @test ComplexF64[-0.7615941559557649, 0.0, 0.7615941559557649] ≈ @jit Ops.tanh(x) - end + x = Reactant.to_rarray(ComplexF32[-1.0, 0.0, 1.0]) + @test ComplexF32[-0.7615941559557649, 0.0, 0.7615941559557649] ≈ @jit(Ops.tanh(x)) skip = + RunningOnAppleX86 end @testset "transpose" begin @@ -868,11 +824,9 @@ end @test SpecialFunctions.besselix.(1, Array(x)) ≈ @jit Ops.bessel_i1e(x) end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "conj" begin - x = Reactant.to_rarray([-1.0 + 2im, 0.0 - 1im, 1.0 + 4im]) - @test conj(Array(x)) ≈ @jit Ops.conj(x) - end +@testset "conj" begin + x = Reactant.to_rarray(ComplexF32[-1.0 + 2im, 0.0 - 1im, 1.0 + 4im]) + @test conj(Array(x)) ≈ @jit(Ops.conj(x)) end @testset "cosh" begin @@ -883,10 +837,9 @@ end @testset "digamma" begin # small divergence between chlo.digamma and SpecialFunctions.digamma: # on <=0, chlo.digamma returns NaN, SpecialFunctions.digamma returns Inf - if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = Reactant.to_rarray([-1.0, 0.0, 1.0]) - @test [NaN, NaN, SpecialFunctions.digamma(1.0)] ≈ @jit(Ops.digamma(x)) nans = true - end + x = Reactant.to_rarray([-1.0, 0.0, 1.0]) + @test [NaN, NaN, SpecialFunctions.digamma(1.0)] ≈ @jit(Ops.digamma(x)) nans = true skip = + RunningOnAppleX86 end @testset "erf_inv" begin @@ -919,41 +872,34 @@ end @test [false, true, false, false, false, false, false] ≈ @jit Ops.is_pos_inf(x) end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "lgamma" begin - if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = Reactant.to_rarray([-1.0, 0.0, 1.0, 2.5]) - lgamma(x) = (SpecialFunctions.logabsgamma(x))[1] - @test lgamma.(Array(x)) ≈ @jit Ops.lgamma(x) - end - end +@testset "lgamma" begin + x = Reactant.to_rarray([-1.0, 0.0, 1.0, 2.5]) + lgamma(x) = (SpecialFunctions.logabsgamma(x))[1] + @test lgamma.(Array(x)) ≈ @jit(Ops.lgamma(x)) atol = 1e-5 rtol = 1e-3 skip = + RunningOnAppleX86 end -if !contains(string(Reactant.devices()[1]), "TPU") - @testset "next_after" begin - x = Reactant.to_rarray([-1.0, 0.0, 1.0, 1.0, 2.5, 1e18, 1e18, 3e-9, 3e-9]) - y = Reactant.to_rarray([-2.0, 0.0, 1.0, 2.0, 3.0, 0.0, 1e19, 0, 1]) - @test [ - prevfloat(-1.0), - 0.0, - 1.0, - nextfloat(1.0), - nextfloat(2.5), - prevfloat(1e18), - nextfloat(1e18), - prevfloat(3e-9), - nextfloat(3e-9), - ] == @jit Ops.next_after(x, y) - end +@testset "next_after" begin + x = Reactant.to_rarray([-1.0, 0.0, 1.0, 1.0, 2.5, 1e18, 1e18, 3e-9, 3e-9]) + y = Reactant.to_rarray([-2.0, 0.0, 1.0, 2.0, 3.0, 0.0, 1e19, 0, 1]) + @test [ + prevfloat(-1.0), + 0.0, + 1.0, + nextfloat(1.0), + nextfloat(2.5), + prevfloat(1e18), + nextfloat(1e18), + prevfloat(3e-9), + nextfloat(3e-9), + ] == @jit(Ops.next_after(x, y)) skip = RunningOnTPU end @testset "polygamma" begin - if !(Sys.isapple() && Sys.ARCH === :x86_64) - x = Reactant.to_rarray([-1.0, 0.0, 1.0, 1.0, 2.5]) - m = Reactant.to_rarray([3.0, 3.0, 2.0, 3.0, 4.0]) - @test SpecialFunctions.polygamma.(Int.(Array(m)), Array(x)) ≈ - @jit Ops.polygamma(m, x) - end + x = Reactant.to_rarray([-1.0, 0.0, 1.0, 1.0, 2.5]) + m = Reactant.to_rarray([3.0, 3.0, 2.0, 3.0, 4.0]) + @test SpecialFunctions.polygamma.(Int.(Array(m)), Array(x)) ≈ @jit(Ops.polygamma(m, x)) broken = + RunningOnAppleX86 end @testset "sinh" begin diff --git a/test/optimize_comm.jl b/test/optimize_comm.jl index 1b081d72a7..332f301416 100644 --- a/test/optimize_comm.jl +++ b/test/optimize_comm.jl @@ -75,7 +75,6 @@ if length(addressable_devices) ≥ 8 ry = Reactant.to_rarray(y; sharding) hlo = repr(@code_xla shardy_passes = :to_mhlo_shardings dus(rx, ry)) - println(hlo) @test !contains(hlo, "all-to-all") @test !contains(hlo, "all-gather") @test contains(hlo, "collective-permute") diff --git a/test/runtests.jl b/test/runtests.jl index 411cf443ea..d56832cff9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,10 +24,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Compile" include("compile.jl") @safetestset "IR" include("ir.jl") @safetestset "Buffer Donation" include("buffer_donation.jl") - @safetestset "Shortcuts to MLIR ops" include("ops.jl") @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") @safetestset "Sorting" include("sorting.jl") + @safetestset "Shortcuts to MLIR ops" include("ops.jl") @safetestset "Indexing" include("indexing.jl") if !Sys.isapple() @safetestset "Custom Number Types" include("custom_number_types.jl") @@ -40,24 +40,30 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" - @safetestset "CUDA" include("integration/cuda.jl") - @safetestset "KernelAbstractions" include("integration/kernelabstractions.jl") + # @safetestset "CUDA" include("integration/cuda.jl") + # @safetestset "KernelAbstractions" include("integration/kernelabstractions.jl") @safetestset "Linear Algebra" include("integration/linear_algebra.jl") - @safetestset "OffsetArrays" include("integration/offsetarrays.jl") - @safetestset "OneHotArrays" include("integration/onehotarrays.jl") - @safetestset "AbstractFFTs" include("integration/fft.jl") + @info "Linear Algebra tests finished" + # @safetestset "OffsetArrays" include("integration/offsetarrays.jl") + # @safetestset "OneHotArrays" include("integration/onehotarrays.jl") + # @safetestset "AbstractFFTs" include("integration/fft.jl") @safetestset "SpecialFunctions" include("integration/special_functions.jl") - @safetestset "Random" include("integration/random.jl") - @safetestset "Python" include("integration/python.jl") - @safetestset "Optimisers" include("integration/optimisers.jl") + @info "SpecialFunctions tests finished" + # @safetestset "Random" include("integration/random.jl") + # @safetestset "Python" include("integration/python.jl") + # @safetestset "Optimisers" include("integration/optimisers.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" - @safetestset "NNlib Primitives" include("nn/nnlib.jl") @safetestset "Flux.jl Integration" include("nn/flux.jl") + @info "Flux.jl Integration tests finished" if Sys.islinux() - @safetestset "LuxLib Primitives" include("nn/luxlib.jl") @safetestset "Lux Integration" include("nn/lux.jl") + @info "Lux Integration tests finished" + @safetestset "LuxLib Primitives" include("nn/luxlib.jl") # XXX: TPU takes too long + @info "LuxLib Primitives tests finished" end + @safetestset "NNlib Primitives" include("nn/nnlib.jl") + @info "NNlib Primitives tests finished" end end diff --git a/test/sharding.jl b/test/sharding.jl index 668c597ec1..8fcfbdbd18 100644 --- a/test/sharding.jl +++ b/test/sharding.jl @@ -1,6 +1,7 @@ using Reactant, Test const addressable_devices = Reactant.addressable_devices() +const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU") function fn_test1(x) y = x .+ x @@ -461,14 +462,13 @@ end end @testset "Compile-Only with More Devices" begin - if !contains(string(Reactant.devices()[1]), "TPU") - mesh = Sharding.Mesh(zeros(Int64, 2, 4), (:x, :y)) + mesh = Sharding.Mesh(zeros(Int64, 2, 4), (:x, :y)) + @test begin x_ra = Reactant.to_rarray( rand(Float32, 32, 32); sharding=Sharding.NamedSharding(mesh, (:x, :y)) ) - hlo = @code_xla sum(x_ra) - @test contains(repr(hlo), "num_partitions=8") - end + contains(repr(hlo), "num_partitions=8") + end skip = RunningOnTPU end