Skip to content

Commit 7c820db

Browse files
wsmosesavik-pal
authored andcommitted
more testing
1 parent 48ca5eb commit 7c820db

File tree

4 files changed

+22
-21
lines changed

4 files changed

+22
-21
lines changed

test/basic.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,7 @@ end
10271027
times = 0:0.01:4.5
10281028
@test times isa Base.StepRangeLen
10291029
res = @jit fractional_idx(times, ConcreteRNumber(2.143))
1030-
@test res[1] == 0.29999999999997334
1030+
@test res[1] 0.29999999999997334
10311031
@test res[2] == 215
10321032
@test res[3] == 216
10331033
end
@@ -1036,7 +1036,7 @@ end
10361036
times = Reactant.to_rarray(0:0.01:4.5; track_numbers=Number)
10371037
@test times isa Reactant.TracedRNumberOverrides.TracedStepRangeLen
10381038
res = @jit fractional_idx(times, ConcreteRNumber(2.143))
1039-
@test res[1] == 0.29999999999997334
1039+
@test res[1] 0.29999999999997334
10401040
@test res[2] == 215
10411041
@test res[3] == 216
10421042
end
@@ -1224,7 +1224,6 @@ end
12241224
end
12251225

12261226
accum_fn(x, y) = abs2(x) + abs2(y)
1227-
Base.reduce_empty(::typeof(accum_fn), ::Type{T}) where {T} = zero(T)
12281227

12291228
@testset "accumulate" begin
12301229
a = collect(Float32, 1:10) ./ 10
@@ -1292,6 +1291,7 @@ Base.reduce_empty(::typeof(accum_fn), ::Type{T}) where {T} = zero(T)
12921291
end cumprod(b; dims=3)
12931292
end
12941293

1294+
if !contains(string(Reactant.devices()[1]), "TPU")
12951295
@testset "accumulate" begin
12961296
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
12971297
accumulate(accum_fn, a; init=0.0f0)
@@ -1325,6 +1325,7 @@ Base.reduce_empty(::typeof(accum_fn), ::Type{T}) where {T} = zero(T)
13251325
z
13261326
end accumulate(accum_fn, b; dims=3, init=0.0f0)
13271327
end
1328+
end
13281329
end
13291330

13301331
sameunitrange(x, y) = first(x) == first(y) && last(x) == last(y)

test/ops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ end
12381238
@test size(perm) == (4, 3, 6)
12391239
@test size(info) == (4, 3)
12401240

1241-
@test @jit(recon_from_lu(lu_ra)) @jit(apply_permutation(x_ra, perm))
1241+
@test @jit(recon_from_lu(lu_ra)) @jit(apply_permutation(x_ra, perm)) atol = 1e-5 rtol = 1e-2
12421242
end
12431243
end
12441244

test/sorting.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ end
134134
x = randn(2, 3, 4)
135135
x_ra = Reactant.to_rarray(x)
136136

137-
@test argmin(abs2, x) == @jit(argmin(abs2, x_ra))
138-
@test argmax(abs2, x) == @jit(argmax(abs2, x_ra))
137+
@test argmin(abs2, x) @jit(argmin(abs2, x_ra))
138+
@test argmax(abs2, x) @jit(argmax(abs2, x_ra))
139139
end
140140

141141
@testset "findmin / findmax" begin
@@ -152,22 +152,22 @@ end
152152

153153
@test fwithlinindices(findmin, identity, x) == @jit(findmin(x_ra))
154154
@test fwithlinindices(findmax, identity, x) == @jit(findmax(x_ra))
155-
@test fwithlinindices(findmin, identity, xvec) == @jit(findmin(xvec_ra))
156-
@test fwithlinindices(findmax, identity, xvec) == @jit(findmax(xvec_ra))
155+
@test fwithlinindices(findmin, identity, xvec) @jit(findmin(xvec_ra))
156+
@test fwithlinindices(findmax, identity, xvec) @jit(findmax(xvec_ra))
157157

158158
fmindims(x, d) = findmin(x; dims=d)
159159
fmindims(f, x, d) = findmin(f, x; dims=d)
160160
fmaxdims(x, d) = findmax(x; dims=d)
161161
fmaxdims(f, x, d) = findmax(f, x; dims=d)
162162

163-
@test fwithlinindices(findmin, identity, x; dims=1) == @jit(fmindims(x_ra, 1))
164-
@test fwithlinindices(findmax, identity, x; dims=1) == @jit(fmaxdims(x_ra, 1))
165-
@test fwithlinindices(findmin, identity, x; dims=2) == @jit(fmindims(x_ra, 2))
166-
@test fwithlinindices(findmax, identity, x; dims=2) == @jit(fmaxdims(x_ra, 2))
167-
@test fwithlinindices(findmin, abs2, x; dims=1) == @jit(fmindims(abs2, x_ra, 1))
168-
@test fwithlinindices(findmax, abs2, x; dims=1) == @jit(fmaxdims(abs2, x_ra, 1))
169-
@test fwithlinindices(findmin, abs2, x; dims=2) == @jit(fmindims(abs2, x_ra, 2))
170-
@test fwithlinindices(findmax, abs2, x; dims=2) == @jit(fmaxdims(abs2, x_ra, 2))
163+
@test fwithlinindices(findmin, identity, x; dims=1) @jit(fmindims(x_ra, 1))
164+
@test fwithlinindices(findmax, identity, x; dims=1) @jit(fmaxdims(x_ra, 1))
165+
@test fwithlinindices(findmin, identity, x; dims=2) @jit(fmindims(x_ra, 2))
166+
@test fwithlinindices(findmax, identity, x; dims=2) @jit(fmaxdims(x_ra, 2))
167+
@test fwithlinindices(findmin, abs2, x; dims=1) @jit(fmindims(abs2, x_ra, 1))
168+
@test fwithlinindices(findmax, abs2, x; dims=1) @jit(fmaxdims(abs2, x_ra, 1))
169+
@test fwithlinindices(findmin, abs2, x; dims=2) @jit(fmindims(abs2, x_ra, 2))
170+
@test fwithlinindices(findmax, abs2, x; dims=2) @jit(fmaxdims(abs2, x_ra, 2))
171171
end
172172

173173
@testset "findfirst / findlast" begin
@@ -183,8 +183,8 @@ end
183183
flastlinindices(x) = LinearIndices(x)[findlast(x)]
184184
flastlinindices(f, x) = LinearIndices(x)[findlast(f, x)]
185185

186-
@test ffirstlinindices(x) == @jit(findfirst(x_ra))
187-
@test flastlinindices(x) == @jit(findlast(x_ra))
186+
@test ffirstlinindices(x) @jit(findfirst(x_ra))
187+
@test flastlinindices(x) @jit(findlast(x_ra))
188188

189189
x = Int64[
190190
3 5 7 9
@@ -193,8 +193,8 @@ end
193193
]
194194
x_ra = Reactant.to_rarray(x)
195195

196-
@test ffirstlinindices(iseven, x) == @jit(findfirst(iseven, x_ra))
197-
@test flastlinindices(iseven, x) == @jit(findlast(iseven, x_ra))
196+
@test ffirstlinindices(iseven, x) @jit(findfirst(iseven, x_ra))
197+
@test flastlinindices(iseven, x) @jit(findlast(iseven, x_ra))
198198
end
199199

200200
@testset "approx top k lowering" begin

test/wrapped_arrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ end
241241
@testset "reshaped subarray indexing" begin
242242
fn(x) = view(x, 1:2) .+ 1
243243
x_ra = Reactant.to_rarray(rand(3, 4, 3))
244-
@test @jit(fn(x_ra)) == fn(Array(x_ra))
244+
@test @jit(fn(x_ra)) fn(Array(x_ra))
245245
end
246246

247247
function reshape_getindex(x)

0 commit comments

Comments
 (0)