Skip to content

Commit 10712d0

Browse files
committed
test: more tests broken
1 parent c7071d7 commit 10712d0

File tree

2 files changed

+81
-75
lines changed

2 files changed

+81
-75
lines changed

test/basic.jl

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,40 +1275,38 @@ accum_fn(x, y) = abs2(x) + abs2(y)
12751275
end cumprod(b; dims=3)
12761276
end
12771277

1278-
if !contains(string(Reactant.devices()[1]), "TPU")
1279-
@testset "accumulate" begin
1280-
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
1281-
accumulate(accum_fn, a; init=0.0f0)
1282-
1283-
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1))
1284-
accumulate(accum_fn, b; dims=1, init=0.0f0)
1285-
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2))
1286-
accumulate(accum_fn, b; dims=2, init=0.0f0)
1287-
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3))
1288-
accumulate(accum_fn, b; dims=3, init=0.0f0)
1289-
1290-
@test begin
1291-
z = similar(a_ra)
1292-
@jit(accumulate!(accum_fn, z, a_ra; init=0.0f0))
1293-
z
1294-
end accumulate(accum_fn, a; init=0.0f0)
1295-
1296-
@test begin
1297-
z = similar(b_ra)
1298-
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1))
1299-
z
1300-
end accumulate(accum_fn, b; dims=1, init=0.0f0)
1301-
@test begin
1302-
z = similar(b_ra)
1303-
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2))
1304-
z
1305-
end accumulate(accum_fn, b; dims=2, init=0.0f0)
1306-
@test begin
1307-
z = similar(b_ra)
1308-
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3))
1309-
z
1310-
end accumulate(accum_fn, b; dims=3, init=0.0f0)
1311-
end
1278+
@testset "accumulate" begin
1279+
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
1280+
accumulate(accum_fn, a; init=0.0f0)
1281+
1282+
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1))
1283+
accumulate(accum_fn, b; dims=1, init=0.0f0) broken = RunningOnTPU
1284+
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2))
1285+
accumulate(accum_fn, b; dims=2, init=0.0f0) broken = RunningOnTPU
1286+
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3))
1287+
accumulate(accum_fn, b; dims=3, init=0.0f0) broken = RunningOnTPU
1288+
1289+
@test begin
1290+
z = similar(a_ra)
1291+
@jit(accumulate!(accum_fn, z, a_ra; init=0.0f0))
1292+
z
1293+
end accumulate(accum_fn, a; init=0.0f0) broken = RunningOnTPU
1294+
1295+
@test begin
1296+
z = similar(b_ra)
1297+
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1))
1298+
z
1299+
end accumulate(accum_fn, b; dims=1, init=0.0f0) broken = RunningOnTPU
1300+
@test begin
1301+
z = similar(b_ra)
1302+
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2))
1303+
z
1304+
end accumulate(accum_fn, b; dims=2, init=0.0f0) broken = RunningOnTPU
1305+
@test begin
1306+
z = similar(b_ra)
1307+
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3))
1308+
z
1309+
end accumulate(accum_fn, b; dims=3, init=0.0f0) broken = RunningOnTPU
13121310
end
13131311
end
13141312

test/ops.jl

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -535,52 +535,60 @@ end
535535
@test [2 1; 4 3] == @jit g2(x)
536536
end
537537

538-
if !contains(string(Reactant.devices()[1]), "TPU")
539-
@testset "rng_bit_generator" begin
540-
genInt32(seed) = Ops.rng_bit_generator(Int32, seed, [2, 4])
541-
genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4])
542-
genUInt64(seed) = Ops.rng_bit_generator(UInt64, seed, [2, 4])
543-
genFloat32(seed) = Ops.rng_bit_generator(Float32, seed, [2, 4])
544-
genFloat64(seed) = Ops.rng_bit_generator(Float64, seed, [2, 4])
545-
546-
@testset for (alg, sz) in
547-
[("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)]
548-
seed = Reactant.to_rarray(zeros(UInt64, sz))
549-
538+
@testset "rng_bit_generator" begin
539+
genInt32(seed) = Ops.rng_bit_generator(Int32, seed, [2, 4])
540+
genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4])
541+
genUInt64(seed) = Ops.rng_bit_generator(UInt64, seed, [2, 4])
542+
genFloat32(seed) = Ops.rng_bit_generator(Float32, seed, [2, 4])
543+
genFloat64(seed) = Ops.rng_bit_generator(Float64, seed, [2, 4])
544+
545+
@testset for (alg, sz) in
546+
[("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)]
547+
seed = Reactant.to_rarray(zeros(UInt64, sz))
548+
549+
@test begin
550550
res = @jit genInt32(seed)
551-
@test res.output_state !== seed
552-
@test size(res.output_state) == (sz,)
553-
@test res.output isa ConcreteRArray{Int32,2}
554-
@test size(res.output) == (2, 4)
555-
556-
seed = res.output_state
551+
res.output_state !== seed &&
552+
size(res.output_state) == (sz,) &&
553+
res.output isa ConcreteRArray{Int32,2} &&
554+
size(res.output) == (2, 4)
555+
end broken = RunningOnTPU
556+
557+
seed = res.output_state
558+
@test begin
557559
res = @jit genInt64(seed)
558-
@test res.output_state !== seed
559-
@test size(res.output_state) == (sz,)
560-
@test res.output isa ConcreteRArray{Int64,2}
561-
@test size(res.output) == (2, 4)
562-
563-
seed = res.output_state
560+
res.output_state !== seed &&
561+
size(res.output_state) == (sz,) &&
562+
res.output isa ConcreteRArray{Int64,2} &&
563+
size(res.output) == (2, 4)
564+
end broken = RunningOnTPU
565+
566+
seed = res.output_state
567+
@test begin
564568
res = @jit genUInt64(seed)
565-
@test res.output_state !== seed
566-
@test size(res.output_state) == (sz,)
567-
@test res.output isa ConcreteRArray{UInt64,2}
568-
@test size(res.output) == (2, 4)
569-
570-
seed = res.output_state
569+
res.output_state !== seed &&
570+
size(res.output_state) == (sz,) &&
571+
res.output isa ConcreteRArray{UInt64,2} &&
572+
size(res.output) == (2, 4)
573+
end broken = RunningOnTPU
574+
575+
seed = res.output_state
576+
@test begin
571577
res = @jit genFloat32(seed)
572-
@test res.output_state !== seed
573-
@test size(res.output_state) == (sz,)
574-
@test res.output isa ConcreteRArray{Float32,2}
575-
@test size(res.output) == (2, 4)
576-
577-
seed = res.output_state
578+
res.output_state !== seed &&
579+
size(res.output_state) == (sz,) &&
580+
res.output isa ConcreteRArray{Float32,2} &&
581+
size(res.output) == (2, 4)
582+
end broken = RunningOnTPU
583+
584+
seed = res.output_state
585+
@test begin
578586
res = @jit genFloat64(seed)
579-
@test res.output_state !== seed
580-
@test size(res.output_state) == (sz,)
581-
@test res.output isa ConcreteRArray{Float64,2}
582-
@test size(res.output) == (2, 4)
583-
end
587+
res.output_state !== seed &&
588+
size(res.output_state) == (sz,) &&
589+
res.output isa ConcreteRArray{Float64,2} &&
590+
size(res.output) == (2, 4)
591+
end broken = RunningOnTPU
584592
end
585593
end
586594

0 commit comments

Comments
 (0)