Skip to content

Commit b933b95

Browse files
wsmosesavik-pal
authored andcommitted
Even more tpu work
1 parent d0b6471 commit b933b95

File tree

3 files changed

+47
-43
lines changed

3 files changed

+47
-43
lines changed

src/TracedRArray.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,21 +1281,23 @@ function scan_impl!(
12811281

12821282
dims > ndims(input) && return copyto!(output, input)
12831283

1284-
if init === nothing
1285-
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
1286-
op_in_T === Union{} && (op_in_T = T)
1284+
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
1285+
op_in_T === Union{} && (op_in_T = T)
12871286

1287+
if init === nothing
12881288
init = __default_init(T, op)
1289-
if typeof(init) != op_in_T
1290-
op_in_T = typeof(init)
1291-
input = typeof(init).(input)
1292-
end
12931289
else
1290+
initT = __default_init(T, op)
12941291
# TODO: fix this for TPUs
1295-
if contains(string(first(Reactant.devices())), "TPU")
1292+
if initT != init && contains(string(first(Reactant.devices())), "TPU")
12961293
throw(AssertionError("Currently, `init` is not supported on TPUs."))
12971294
end
12981295
end
1296+
1297+
if typeof(init) != op_in_T
1298+
op_in_T = typeof(init)
1299+
input = typeof(init).(input)
1300+
end
12991301
init = something(init) # unwrap Some
13001302
init = TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(init)}, init)
13011303

test/ops.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@ end
560560
@test [2 1; 4 3] == @jit g2(x)
561561
end
562562

563+
if !contains(string(Reactant.devices()[1]), "TPU")
563564
@testset "rng_bit_generator" begin
564565
genInt32(seed) = Ops.rng_bit_generator(Int32, seed, [2, 4])
565566
genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4])
@@ -606,6 +607,7 @@ end
606607
@test size(res.output) == (2, 4)
607608
end
608609
end
610+
end
609611

610612
@testset "round_nearest_afz" begin
611613
x = Reactant.to_rarray([-2.5, 0.4, 0.5, 0.6, 2.5])

test/sorting.jl

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,25 @@ using Reactant, Test, Random, StableRNGs
1111
srt_lt(x) = sort(x; lt=(a, b) -> a > b)
1212
srtperm_lt(x) = sortperm(x; lt=(a, b) -> a > b)
1313

14-
@test @jit(sort(x_ra)) == sort(x)
15-
@test @jit(srt_rev(x_ra)) == srt_rev(x)
16-
@test @jit(srt_lt(x_ra)) == srt_lt(x)
17-
@test @jit(srt_by(x_ra)) == srt_by(x)
18-
@test @jit(sortperm(x_ra)) == sortperm(x)
19-
@test @jit(srtperm_rev(x_ra)) == srtperm_rev(x)
20-
@test @jit(srtperm_lt(x_ra)) == srtperm_lt(x)
21-
@test @jit(srtperm_by(x_ra)) == srtperm_by(x)
14+
@test @jit(sort(x_ra)) sort(x)
15+
@test @jit(srt_rev(x_ra)) srt_rev(x)
16+
@test @jit(srt_lt(x_ra)) srt_lt(x)
17+
@test @jit(srt_by(x_ra)) srt_by(x)
18+
@test @jit(sortperm(x_ra)) sortperm(x)
19+
@test @jit(srtperm_rev(x_ra)) srtperm_rev(x)
20+
@test @jit(srtperm_lt(x_ra)) srtperm_lt(x)
21+
@test @jit(srtperm_by(x_ra)) srtperm_by(x)
2222

2323
x = rand(10)
2424
x_ra = Reactant.to_rarray(x)
2525
@jit sort!(x_ra)
26-
@test x_ra == sort(x)
26+
@test x_ra sort(x)
2727

2828
x = rand(10)
2929
x_ra = Reactant.to_rarray(x)
3030
ix = similar(x_ra, Int)
3131
@jit sortperm!(ix, x_ra)
32-
@test ix == sortperm(x)
32+
@test ix sortperm(x)
3333

3434
x = rand(10, 4, 3)
3535
x_ra = Reactant.to_rarray(x)
@@ -44,51 +44,51 @@ using Reactant, Test, Random, StableRNGs
4444
srtperm_lt(x, d) = sortperm(x; dims=d, lt=(a, b) -> a > b)
4545

4646
@testset for d in 1:ndims(x)
47-
@test @jit(srt(x_ra, d)) == srt(x, d)
48-
@test @jit(srtperm(x_ra, d)) == srtperm(x, d)
49-
@test @jit(srt_rev(x_ra, d)) == srt_rev(x, d)
50-
@test @jit(srtperm_rev(x_ra, d)) == srtperm_rev(x, d)
51-
@test @jit(srt_by(x_ra, d)) == srt_by(x, d)
52-
@test @jit(srtperm_by(x_ra, d)) == srtperm_by(x, d)
53-
@test @jit(srt_lt(x_ra, d)) == srt_lt(x, d)
54-
@test @jit(srtperm_lt(x_ra, d)) == srtperm_lt(x, d)
47+
@test @jit(srt(x_ra, d)) srt(x, d)
48+
@test @jit(srtperm(x_ra, d)) srtperm(x, d)
49+
@test @jit(srt_rev(x_ra, d)) srt_rev(x, d)
50+
@test @jit(srtperm_rev(x_ra, d)) srtperm_rev(x, d)
51+
@test @jit(srt_by(x_ra, d)) srt_by(x, d)
52+
@test @jit(srtperm_by(x_ra, d)) srtperm_by(x, d)
53+
@test @jit(srt_lt(x_ra, d)) srt_lt(x, d)
54+
@test @jit(srtperm_lt(x_ra, d)) srtperm_lt(x, d)
5555
end
5656
end
5757

5858
@testset "partialsort & partialsortperm" begin
5959
x = randn(10)
6060
x_ra = Reactant.to_rarray(x)
6161

62-
@test @jit(partialsort(x_ra, 1:5)) == partialsort(x, 1:5)
63-
@test @jit(partialsort(x_ra, 1:5; rev=true)) == partialsort(x, 1:5; rev=true)
64-
@test @jit(partialsortperm(x_ra, 1:5)) == partialsortperm(x, 1:5)
65-
@test @jit(partialsortperm(x_ra, 1:5; rev=true)) == partialsortperm(x, 1:5; rev=true)
66-
@test @jit(partialsort(x_ra, 3:6)) == partialsort(x, 3:6)
67-
@test @jit(partialsort(x_ra, 3:6; rev=true)) == partialsort(x, 3:6; rev=true)
68-
@test @jit(partialsortperm(x_ra, 3:6)) == partialsortperm(x, 3:6)
69-
@test @jit(partialsortperm(x_ra, 3:6; rev=true)) == partialsortperm(x, 3:6; rev=true)
70-
@test @jit(partialsort(x_ra, 4)) == partialsort(x, 4)
71-
@test @jit(partialsort(x_ra, 4; rev=true)) == partialsort(x, 4; rev=true)
72-
@test @jit(partialsortperm(x_ra, 4)) == partialsortperm(x, 4)
73-
@test @jit(partialsortperm(x_ra, 4; rev=true)) == partialsortperm(x, 4; rev=true)
62+
@test @jit(partialsort(x_ra, 1:5)) partialsort(x, 1:5)
63+
@test @jit(partialsort(x_ra, 1:5; rev=true)) partialsort(x, 1:5; rev=true)
64+
@test @jit(partialsortperm(x_ra, 1:5)) partialsortperm(x, 1:5)
65+
@test @jit(partialsortperm(x_ra, 1:5; rev=true)) partialsortperm(x, 1:5; rev=true)
66+
@test @jit(partialsort(x_ra, 3:6)) partialsort(x, 3:6)
67+
@test @jit(partialsort(x_ra, 3:6; rev=true)) partialsort(x, 3:6; rev=true)
68+
@test @jit(partialsortperm(x_ra, 3:6)) partialsortperm(x, 3:6)
69+
@test @jit(partialsortperm(x_ra, 3:6; rev=true)) partialsortperm(x, 3:6; rev=true)
70+
@test @jit(partialsort(x_ra, 4)) partialsort(x, 4)
71+
@test @jit(partialsort(x_ra, 4; rev=true)) partialsort(x, 4; rev=true)
72+
@test @jit(partialsortperm(x_ra, 4)) partialsortperm(x, 4)
73+
@test @jit(partialsortperm(x_ra, 4; rev=true)) partialsortperm(x, 4; rev=true)
7474

7575
x = randn(10)
7676
x_ra = Reactant.to_rarray(x)
7777
@jit partialsort!(x_ra, 1:5)
7878
partialsort!(x, 1:5)
79-
@test Array(x_ra)[1:5] == x[1:5]
79+
@test Array(x_ra)[1:5] x[1:5]
8080

8181
x = randn(10)
8282
x_ra = Reactant.to_rarray(x)
8383
@jit partialsort!(x_ra, 3:5; rev=true)
8484
partialsort!(x, 3:5; rev=true)
85-
@test Array(x_ra)[3:5] == x[3:5]
85+
@test Array(x_ra)[3:5] x[3:5]
8686

8787
x = randn(10)
8888
x_ra = Reactant.to_rarray(x)
8989
@jit partialsort!(x_ra, 3)
9090
partialsort!(x, 3)
91-
@test @allowscalar(x_ra[3]) == x[3]
91+
@test @allowscalar(x_ra[3]) x[3]
9292

9393
x = randn(10)
9494
x_ra = Reactant.to_rarray(x)
@@ -97,13 +97,13 @@ end
9797
ix_ra = Reactant.to_rarray(ix)
9898
@jit partialsortperm!(ix_ra, x_ra, 1:5)
9999
partialsortperm!(ix, x, 1:5)
100-
@test Array(ix_ra)[1:5] == ix[1:5]
100+
@test Array(ix_ra)[1:5] ix[1:5]
101101

102102
ix = similar(x, Int)
103103
ix_ra = Reactant.to_rarray(ix)
104104
@jit partialsortperm!(ix_ra, x_ra, 3)
105105
partialsortperm!(ix, x, 3)
106-
@test @allowscalar(ix_ra[3]) == ix[3]
106+
@test @allowscalar(ix_ra[3]) ix[3]
107107
end
108108

109109
@testset "argmin / argmax" begin

0 commit comments

Comments
 (0)