Skip to content

Commit b9ea1ae

Browse files
committed
test: more fixes
1 parent 99069e5 commit b9ea1ae

File tree

3 files changed

+42
-48
lines changed

3 files changed

+42
-48
lines changed

test/basic.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -824,16 +824,19 @@ end
824824
end
825825

826826
@testset "signbit" begin
827-
for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
828-
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x) broken = RunningOnTPU
827+
@testset "$(typeof(x))" for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
828+
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x) broken =
829+
RunningOnTPU && eltype(x) == Float64
829830
end
830831
end
831832

832833
@testset "copysign" begin
833-
for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14)
834+
@testset "$(typeof(a)) $(typeof(b))" for a in (-3.14, -2, 0.0, 2.71, 42),
835+
b in (-7, -0.57, -0.0, 1, 3.14)
834836
# Make sure also the return type is correct
835837
@test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))) ===
836-
copysign(a, b) broken = RunningOnTPU
838+
copysign(a, b) broken =
839+
RunningOnTPU && (eltype(a) == Float64 || eltype(b) == Float64)
837840
end
838841
end
839842

@@ -930,10 +933,10 @@ end
930933
end
931934

932935
@testset "@code_xla" begin
933-
x_ra = Reactant.to_rarray(ones(4))
936+
x_ra = Reactant.to_rarray(ones(Float32, 4))
934937
hlo = repr(@code_xla(sin.(x_ra)))
935-
@test contains(hlo, "HloModule") broken = RunningOnTPU
936-
@test contains(hlo, "sine") broken = RunningOnTPU
938+
@test contains(hlo, "HloModule")
939+
@test contains(hlo, "sine")
937940
end
938941

939942
@testset "Raise keyword" begin
@@ -1125,7 +1128,7 @@ end
11251128
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
11261129
@compile sin.(Reactant.to_rarray(Float32[1.0]))
11271130
for mod in readdir(dir; join=true)
1128-
@test contains(read(mod, String), "hlo.sine") broken = RunningOnTPU
1131+
@test contains(read(mod, String), "hlo.sine")
11291132
end
11301133
end
11311134

@@ -1135,7 +1138,7 @@ end
11351138
@compile exp.(Reactant.to_rarray(Float32[1.0]))
11361139
# Make sure we don't save anything to file when compilation is
11371140
# successful and `DUMP_MLIR_ALWAYS=false`.
1138-
@test isempty(readdir(dir; join=true)) broken = RunningOnTPU
1141+
@test isempty(readdir(dir; join=true))
11391142
end
11401143

11411144
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old
@@ -1267,7 +1270,7 @@ accum_fn(x, y) = abs2(x) + abs2(y)
12671270

12681271
@testset "accumulate" begin
12691272
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
1270-
accumulate(accum_fn, a; init=0.0f0)
1273+
accumulate(accum_fn, a; init=0.0f0) broken = RunningOnTPU
12711274

12721275
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1))
12731276
accumulate(accum_fn, b; dims=1, init=0.0f0) broken = RunningOnTPU

test/ops.jl

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ const RunningOnAppleX86 = Sys.isapple() && Sys.ARCH === :x86_64
1313
x = Reactant.to_rarray([1.0, -1.0])
1414
@test [1.0, 1.0] @jit Ops.abs(x)
1515

16-
x = Reactant.to_rarray([
17-
3.0+4im -3.0+4im
18-
3.0-4im -3.0-4im
19-
])
2016
@test begin
17+
x = Reactant.to_rarray([
18+
3.0+4im -3.0+4im
19+
3.0-4im -3.0-4im
20+
])
2121
[
2222
5.0 5.0
2323
5.0 5.0
@@ -38,11 +38,13 @@ end
3838
b = Reactant.to_rarray([5.5, 6.6, -7.7, -8.8])
3939
@test Array(a) .+ Array(b) @jit Ops.add(a, b)
4040

41-
a = Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im])
42-
b = Reactant.to_rarray([
43-
9.9 + 10.10im, 11.11 + 12.12im, -13.13 + -14.14im, -15.15 + -16.16im
44-
])
45-
@test Array(a) .+ Array(b) @jit(Ops.add(a, b)) broken = RunningOnTPU
41+
@test begin
42+
a = Reactant.to_rarray([1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im, 7.7 + 8.8im])
43+
b = Reactant.to_rarray([
44+
9.9 + 10.10im, 11.11 + 12.12im, -13.13 + -14.14im, -15.15 + -16.16im
45+
])
46+
Array(a) .+ Array(b) @jit(Ops.add(a, b))
47+
end broken = RunningOnTPU
4648
end
4749

4850
@testset "after_all" begin
@@ -100,16 +102,16 @@ end
100102
@test cholesky(Array(x)).U @jit g1(x)
101103
@test transpose(cholesky(Array(x)).U) @jit g2(x)
102104

103-
x = Reactant.to_rarray(
104-
[
105-
10.0+0.0im 2.0-3.0im 3.0-4.0im
106-
2.0+3.0im 5.0+0.0im 3.0-2.0im
107-
3.0+4.0im 3.0+2.0im 9.0+0.0im
108-
],
109-
)
110-
111-
@test cholesky(Array(x)).U @jit(g1(x)) broken = RunningOnTPU
112-
@test adjoint(cholesky(Array(x)).U) @jit(g2(x)) broken = RunningOnTPU
105+
@test begin
106+
x = Reactant.to_rarray(
107+
[
108+
10.0+0.0im 2.0-3.0im 3.0-4.0im
109+
2.0+3.0im 5.0+0.0im 3.0-2.0im
110+
3.0+4.0im 3.0+2.0im 9.0+0.0im
111+
],
112+
)
113+
cholesky(Array(x)).U @jit(g1(x)) && adjoint(cholesky(Array(x)).U) @jit(g2(x))
114+
end broken = RunningOnTPU
113115
end
114116

115117
@testset "clamp" begin

test/runtests.jl

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,18 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
1313
end
1414

1515
# @safetestset "Layout" include("layout.jl")
16-
# @info "Layout tests finished"
1716
# @safetestset "Tracing" include("tracing.jl")
18-
# @info "Tracing tests finished"
1917
@safetestset "Basic" include("basic.jl")
2018
@info "Basic tests finished"
2119
# @safetestset "Constructor" include("constructor.jl")
22-
# @info "Constructor tests finished"
23-
@safetestset "Autodiff" include("autodiff.jl")
24-
@info "Autodiff tests finished"
25-
@safetestset "Complex" include("complex.jl")
26-
@info "Complex tests finished"
27-
@safetestset "Broadcast" include("bcast.jl")
28-
@info "Broadcast tests finished"
29-
@safetestset "Struct" include("struct.jl")
30-
@info "Struct tests finished"
31-
@safetestset "Closure" include("closure.jl")
32-
@info "Closure tests finished"
33-
@safetestset "Compile" include("compile.jl")
34-
@info "Compile tests finished"
35-
@safetestset "IR" include("ir.jl")
36-
@info "IR tests finished"
37-
@safetestset "Buffer Donation" include("buffer_donation.jl")
38-
@info "Buffer Donation tests finished"
20+
# @safetestset "Autodiff" include("autodiff.jl")
21+
# @safetestset "Complex" include("complex.jl")
22+
# @safetestset "Broadcast" include("bcast.jl")
23+
# @safetestset "Struct" include("struct.jl")
24+
# @safetestset "Closure" include("closure.jl")
25+
# @safetestset "Compile" include("compile.jl")
26+
# @safetestset "IR" include("ir.jl")
27+
# @safetestset "Buffer Donation" include("buffer_donation.jl")
3928
@safetestset "Shortcuts to MLIR ops" include("ops.jl")
4029
@info "Shortcuts to MLIR ops tests finished"
4130
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")

0 commit comments

Comments
 (0)