Skip to content

Commit 12d6e2e

Browse files
wsmosesavik-pal
andauthored
Reduce TPU errors (#1531)
* Reduce TPU errors * fix * fix * more * fix * fix * fix * even more complex * fix * fix * fix * more fx * more fixes * more fixups * chore: run formatter --------- Co-authored-by: Avik Pal <[email protected]>
1 parent 2061166 commit 12d6e2e

File tree

4 files changed

+238
-180
lines changed

4 files changed

+238
-180
lines changed

test/basic.jl

Lines changed: 51 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -794,8 +794,10 @@ end
794794
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
795795
@test @jit(isnan.(x)) == [false, true, false, false, true]
796796

797-
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
798-
@test @jit(isnan.(x)) == [false, true, false, false, true]
797+
if !contains(string(Reactant.devices()[1]), "TPU")
798+
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
799+
@test @jit(isnan.(x)) == [false, true, false, false, true]
800+
end
799801
end
800802

801803
@testset "isnan/isfinite" begin
@@ -818,9 +820,11 @@ end
818820
b = [6.6, -2.2, -8.8, 4.4, -10.1]
819821

820822
expected_mod = mod.(a, b)
821-
@test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod
822-
@test @jit(mod.(a, Reactant.to_rarray(b))) expected_mod
823-
@test @jit(mod.(Reactant.to_rarray(a), b)) expected_mod
823+
if !contains(string(Reactant.devices()[1]), "TPU")
824+
@test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod
825+
@test @jit(mod.(a, Reactant.to_rarray(b))) expected_mod
826+
@test @jit(mod.(Reactant.to_rarray(a), b)) expected_mod
827+
end
824828

825829
expected_rem = rem.(a, b)
826830
@test @jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_rem
@@ -834,17 +838,22 @@ end
834838
end
835839
end
836840

837-
@testset "signbit" begin
838-
for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
839-
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x)
841+
if !contains(string(Reactant.devices()[1]), "TPU")
842+
@testset "signbit" begin
843+
for x in (-4, -3.14, -0.0f0, 0.0, 0, 5, 6.28f0)
844+
@test @jit(signbit(ConcreteRNumber(x))) == signbit(x)
845+
end
840846
end
841847
end
842848

843-
@testset "copysign" begin
844-
for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14)
845-
# Make sure also the return type is correct
846-
@test Reactant.to_number(@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))) ===
847-
copysign(a, b)
849+
if !contains(string(Reactant.devices()[1]), "TPU")
850+
@testset "copysign" begin
851+
for a in (-3.14, -2, 0.0, 2.71, 42), b in (-7, -0.57, -0.0, 1, 3.14)
852+
# Make sure also the return type is correct
853+
@test Reactant.to_number(
854+
@jit(copysign(ConcreteRNumber(a), ConcreteRNumber(b)))
855+
) === copysign(a, b)
856+
end
848857
end
849858
end
850859

@@ -940,11 +949,13 @@ end
940949
ra[:a] (2.7 * 2) * ones(4)
941950
end
942951

943-
@testset "@code_xla" begin
944-
x_ra = Reactant.to_rarray(ones(4))
945-
hlo = repr(@code_xla(sin.(x_ra)))
946-
@test contains(hlo, "HloModule")
947-
@test contains(hlo, "sine")
952+
if !contains(string(Reactant.devices()[1]), "TPU")
953+
@testset "@code_xla" begin
954+
x_ra = Reactant.to_rarray(ones(4))
955+
hlo = repr(@code_xla(sin.(x_ra)))
956+
@test contains(hlo, "HloModule")
957+
@test contains(hlo, "sine")
958+
end
948959
end
949960

950961
@testset "Raise keyword" begin
@@ -1129,30 +1140,32 @@ end
11291140
end
11301141
end
11311142

1132-
@testset "Dump MLIR modules" begin
1133-
always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[]
1134-
dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[]
1143+
if !contains(string(Reactant.devices()[1]), "TPU")
1144+
@testset "Dump MLIR modules" begin
1145+
always_old = Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[]
1146+
dir_old = Reactant.MLIR.IR.DUMP_MLIR_DIR[]
1147+
1148+
mktempdir() do dir
1149+
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true
1150+
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1151+
@compile sin.(Reactant.to_rarray(Float32[1.0]))
1152+
for mod in readdir(dir; join=true)
1153+
@test contains(read(mod, String), "hlo.sine")
1154+
end
1155+
end
11351156

1136-
mktempdir() do dir
1137-
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true
1138-
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1139-
@compile sin.(Reactant.to_rarray(Float32[1.0]))
1140-
for mod in readdir(dir; join=true)
1141-
@test contains(read(mod, String), "hlo.sine")
1157+
mktempdir() do dir
1158+
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false
1159+
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1160+
@compile exp.(Reactant.to_rarray(Float32[1.0]))
1161+
# Make sure we don't save anything to file when compilation is
1162+
# successful and `DUMP_MLIR_ALWAYS=false`.
1163+
@test isempty(readdir(dir; join=true))
11421164
end
1143-
end
11441165

1145-
mktempdir() do dir
1146-
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = false
1147-
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir
1148-
@compile exp.(Reactant.to_rarray(Float32[1.0]))
1149-
# Make sure we don't save anything to file when compilation is
1150-
# successful and `DUMP_MLIR_ALWAYS=false`.
1151-
@test isempty(readdir(dir; join=true))
1166+
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old
1167+
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old
11521168
end
1153-
1154-
Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = always_old
1155-
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = dir_old
11561169
end
11571170

11581171
@testset "Allocator Stats" begin

test/indexing.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616
y2 = @jit update_on_copy(x_concrete)
1717
@test x == y
1818
@test x_concrete == y_concrete
19-
@test y1 == y2
19+
@test y1 y2
2020

2121
# function update_inplace(x)
2222
# y = view(x, 1:2, 1:2, :)
@@ -288,20 +288,22 @@ function issue_617(outf, fr, pr, I)
288288
return outf
289289
end
290290

291-
@testset "issue #617" begin
292-
N, M = 4, 6
291+
if !contains(string(Reactant.devices()[1]), "TPU")
292+
@testset "issue #617" begin
293+
N, M = 4, 6
293294

294-
f = rand(ComplexF64, N, N)
295-
p = rand(ComplexF64, N * N)
296-
I = 1:(N^2)
297-
out = rand(ComplexF64, M, M)
295+
f = rand(ComplexF64, N, N)
296+
p = rand(ComplexF64, N * N)
297+
I = 1:(N^2)
298+
out = rand(ComplexF64, M, M)
298299

299-
fr = Reactant.to_rarray(f)
300-
pr = Reactant.to_rarray(p)
301-
outr = Reactant.to_rarray(out)
302-
Ir = Reactant.to_rarray(I)
300+
fr = Reactant.to_rarray(f)
301+
pr = Reactant.to_rarray(p)
302+
outr = Reactant.to_rarray(out)
303+
Ir = Reactant.to_rarray(I)
303304

304-
@test @jit(issue_617(outr, fr, pr, Ir)) issue_617(out, f, p, I)
305+
@test @jit(issue_617(outr, fr, pr, Ir)) issue_617(out, f, p, I)
306+
end
305307
end
306308

307309
function scalar_setindex(x, idx, val)

0 commit comments

Comments
 (0)