Skip to content

Commit 01d07f5

Browse files
committed
test: more fixes
1 parent 99069e5 commit 01d07f5

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

test/basic.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -824,16 +824,19 @@ end
824824
end
825825

826826
@testset "signbit" begin
827-
for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
828-
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x) broken = RunningOnTPU
827+
@testset "$(typeof(x))" for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
828+
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x) broken =
829+
RunningOnTPU && eltype(x) == Float64
829830
end
830831
end
831832

832833
@testset "copysign" begin
833-
for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14)
834+
@testset "$(typeof(a)) $(typeof(b))" for a in (-3.14, -2, 0.0, 2.71, 42),
835+
b in (-7, -0.57, -0.0, 1, 3.14)
834836
# Make sure also the return type is correct
835837
@test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))) ===
836-
copysign(a, b) broken = RunningOnTPU
838+
copysign(a, b) broken =
839+
RunningOnTPU && (eltype(a) == Float64 || eltype(b) == Float64)
837840
end
838841
end
839842

@@ -930,10 +933,10 @@ end
930933
end
931934

932935
@testset "@code_xla" begin
933-
x_ra = Reactant.to_rarray(ones(4))
936+
x_ra = Reactant.to_rarray(ones(Float32, 4))
934937
hlo = repr(@code_xla(sin.(x_ra)))
935-
@test contains(hlo, "HloModule") broken = RunningOnTPU
936-
@test contains(hlo, "sine") broken = RunningOnTPU
938+
@test contains(hlo, "HloModule")
939+
@test contains(hlo, "sine")
937940
end
938941

939942
@testset "Raise keyword" begin
@@ -1125,7 +1128,7 @@ end
11251128
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
11261129
@compile sin.(Reactant.to_rarray(Float32[1.0]))
11271130
for mod in readdir(dir; join=true)
1128-
@test contains(read(mod, String), "hlo.sine") broken = RunningOnTPU
1131+
@test contains(read(mod, String), "hlo.sine")
11291132
end
11301133
end
11311134

@@ -1135,7 +1138,7 @@ end
11351138
@compile exp.(Reactant.to_rarray(Float32[1.0]))
11361139
# Make sure we don't save anything to file when compilation is
11371140
# successful and `DUMP_MLIR_ALWAYS=false`.
1138-
@test isempty(readdir(dir; join=true)) broken = RunningOnTPU
1141+
@test isempty(readdir(dir; join=true))
11391142
end
11401143

11411144
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old
@@ -1267,7 +1270,7 @@ accum_fn(x, y) = abs2(x) + abs2(y)
12671270

12681271
@testset "accumulate" begin
12691272
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
1270-
accumulate(accum_fn, a; init=0.0f0)
1273+
accumulate(accum_fn, a; init=0.0f0) broken = RunningOnTPU
12711274

12721275
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1))
12731276
accumulate(accum_fn, b; dims=1, init=0.0f0) broken = RunningOnTPU

test/runtests.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,12 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
1313
end
1414

1515
# @safetestset "Layout" include("layout.jl")
16-
# @info "Layout tests finished"
1716
# @safetestset "Tracing" include("tracing.jl")
18-
# @info "Tracing tests finished"
1917
@safetestset "Basic" include("basic.jl")
2018
@info "Basic tests finished"
2119
# @safetestset "Constructor" include("constructor.jl")
22-
# @info "Constructor tests finished"
23-
@safetestset "Autodiff" include("autodiff.jl")
24-
@info "Autodiff tests finished"
25-
@safetestset "Complex" include("complex.jl")
26-
@info "Complex tests finished"
20+
# @safetestset "Autodiff" include("autodiff.jl")
21+
# @safetestset "Complex" include("complex.jl")
2722
@safetestset "Broadcast" include("bcast.jl")
2823
@info "Broadcast tests finished"
2924
@safetestset "Struct" include("struct.jl")

0 commit comments

Comments
 (0)