Skip to content

Commit 575decf

Browse files
committed
chore: run formatter
1 parent 79091da commit 575decf

File tree

3 files changed

+86
-82
lines changed

3 files changed

+86
-82
lines changed

src/TracedRArray.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,7 +1281,6 @@ function scan_impl!(
12811281

12821282
dims > ndims(input) && return copyto!(output, input)
12831283

1284-
12851284
if init === nothing
12861285
op_in_T = Core.Compiler.return_type(op, Tuple{T,T})
12871286
op_in_T === Union{} && (op_in_T = T)
@@ -1292,10 +1291,14 @@ function scan_impl!(
12921291
end
12931292
else
12941293
# TODO: fix this for TPUs
1295-
if contains(string(first(Reactant.devices())), "TPU")
1294+
if contains(string(first(Reactant.devices())), "TPU")
12961295
initT = __default_init(T, op)
12971296
if initT != init && initT != something(init)
1298-
throw(AssertionError("Currently, `init` is not supported on TPUs, provided value $init does not match identity $initT."))
1297+
throw(
1298+
AssertionError(
1299+
"Currently, `init` is not supported on TPUs, provided value $init does not match identity $initT.",
1300+
),
1301+
)
12991302
end
13001303
end
13011304
end

test/basic.jl

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,39 +1292,39 @@ accum_fn(x, y) = abs2(x) + abs2(y)
12921292
end
12931293

12941294
if !contains(string(Reactant.devices()[1]), "TPU")
1295-
@testset "accumulate" begin
1296-
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
1297-
accumulate(accum_fn, a; init=0.0f0)
1298-
1299-
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1))
1300-
accumulate(accum_fn, b; dims=1, init=0.0f0)
1301-
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2))
1302-
accumulate(accum_fn, b; dims=2, init=0.0f0)
1303-
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3))
1304-
accumulate(accum_fn, b; dims=3, init=0.0f0)
1305-
1306-
@test begin
1307-
z = similar(a_ra)
1308-
@jit(accumulate!(accum_fn, z, a_ra; init=0.0f0))
1309-
z
1310-
end accumulate(accum_fn, a; init=0.0f0)
1311-
1312-
@test begin
1313-
z = similar(b_ra)
1314-
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1))
1315-
z
1316-
end accumulate(accum_fn, b; dims=1, init=0.0f0)
1317-
@test begin
1318-
z = similar(b_ra)
1319-
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2))
1320-
z
1321-
end accumulate(accum_fn, b; dims=2, init=0.0f0)
1322-
@test begin
1323-
z = similar(b_ra)
1324-
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3))
1325-
z
1326-
end accumulate(accum_fn, b; dims=3, init=0.0f0)
1327-
end
1295+
@testset "accumulate" begin
1296+
@test @jit(accumulate(accum_fn, a_ra; init=0.0f0))
1297+
accumulate(accum_fn, a; init=0.0f0)
1298+
1299+
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=1))
1300+
accumulate(accum_fn, b; dims=1, init=0.0f0)
1301+
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=2))
1302+
accumulate(accum_fn, b; dims=2, init=0.0f0)
1303+
@test @jit(accumulate(accum_fn, b_ra; init=0.0f0, dims=3))
1304+
accumulate(accum_fn, b; dims=3, init=0.0f0)
1305+
1306+
@test begin
1307+
z = similar(a_ra)
1308+
@jit(accumulate!(accum_fn, z, a_ra; init=0.0f0))
1309+
z
1310+
end accumulate(accum_fn, a; init=0.0f0)
1311+
1312+
@test begin
1313+
z = similar(b_ra)
1314+
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=1))
1315+
z
1316+
end accumulate(accum_fn, b; dims=1, init=0.0f0)
1317+
@test begin
1318+
z = similar(b_ra)
1319+
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=2))
1320+
z
1321+
end accumulate(accum_fn, b; dims=2, init=0.0f0)
1322+
@test begin
1323+
z = similar(b_ra)
1324+
@jit(accumulate!(accum_fn, z, b_ra; init=0.0f0, dims=3))
1325+
z
1326+
end accumulate(accum_fn, b; dims=3, init=0.0f0)
1327+
end
13281328
end
13291329
end
13301330

test/ops.jl

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -561,53 +561,53 @@ end
561561
end
562562

563563
if !contains(string(Reactant.devices()[1]), "TPU")
564-
@testset "rng_bit_generator" begin
565-
genInt32(seed) = Ops.rng_bit_generator(Int32, seed, [2, 4])
566-
genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4])
567-
genUInt64(seed) = Ops.rng_bit_generator(UInt64, seed, [2, 4])
568-
genFloat32(seed) = Ops.rng_bit_generator(Float32, seed, [2, 4])
569-
genFloat64(seed) = Ops.rng_bit_generator(Float64, seed, [2, 4])
570-
571-
@testset for (alg, sz) in
572-
[("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)]
573-
seed = Reactant.to_rarray(zeros(UInt64, sz))
574-
575-
res = @jit genInt32(seed)
576-
@test res.output_state !== seed
577-
@test size(res.output_state) == (sz,)
578-
@test res.output isa ConcreteRArray{Int32,2}
579-
@test size(res.output) == (2, 4)
580-
581-
seed = res.output_state
582-
res = @jit genInt64(seed)
583-
@test res.output_state !== seed
584-
@test size(res.output_state) == (sz,)
585-
@test res.output isa ConcreteRArray{Int64,2}
586-
@test size(res.output) == (2, 4)
587-
588-
seed = res.output_state
589-
res = @jit genUInt64(seed)
590-
@test res.output_state !== seed
591-
@test size(res.output_state) == (sz,)
592-
@test res.output isa ConcreteRArray{UInt64,2}
593-
@test size(res.output) == (2, 4)
594-
595-
seed = res.output_state
596-
res = @jit genFloat32(seed)
597-
@test res.output_state !== seed
598-
@test size(res.output_state) == (sz,)
599-
@test res.output isa ConcreteRArray{Float32,2}
600-
@test size(res.output) == (2, 4)
601-
602-
seed = res.output_state
603-
res = @jit genFloat64(seed)
604-
@test res.output_state !== seed
605-
@test size(res.output_state) == (sz,)
606-
@test res.output isa ConcreteRArray{Float64,2}
607-
@test size(res.output) == (2, 4)
564+
@testset "rng_bit_generator" begin
565+
genInt32(seed) = Ops.rng_bit_generator(Int32, seed, [2, 4])
566+
genInt64(seed) = Ops.rng_bit_generator(Int64, seed, [2, 4])
567+
genUInt64(seed) = Ops.rng_bit_generator(UInt64, seed, [2, 4])
568+
genFloat32(seed) = Ops.rng_bit_generator(Float32, seed, [2, 4])
569+
genFloat64(seed) = Ops.rng_bit_generator(Float64, seed, [2, 4])
570+
571+
@testset for (alg, sz) in
572+
[("DEFAULT", 2), ("PHILOX", 2), ("PHILOX", 3), ("THREE_FRY", 2)]
573+
seed = Reactant.to_rarray(zeros(UInt64, sz))
574+
575+
res = @jit genInt32(seed)
576+
@test res.output_state !== seed
577+
@test size(res.output_state) == (sz,)
578+
@test res.output isa ConcreteRArray{Int32,2}
579+
@test size(res.output) == (2, 4)
580+
581+
seed = res.output_state
582+
res = @jit genInt64(seed)
583+
@test res.output_state !== seed
584+
@test size(res.output_state) == (sz,)
585+
@test res.output isa ConcreteRArray{Int64,2}
586+
@test size(res.output) == (2, 4)
587+
588+
seed = res.output_state
589+
res = @jit genUInt64(seed)
590+
@test res.output_state !== seed
591+
@test size(res.output_state) == (sz,)
592+
@test res.output isa ConcreteRArray{UInt64,2}
593+
@test size(res.output) == (2, 4)
594+
595+
seed = res.output_state
596+
res = @jit genFloat32(seed)
597+
@test res.output_state !== seed
598+
@test size(res.output_state) == (sz,)
599+
@test res.output isa ConcreteRArray{Float32,2}
600+
@test size(res.output) == (2, 4)
601+
602+
seed = res.output_state
603+
res = @jit genFloat64(seed)
604+
@test res.output_state !== seed
605+
@test size(res.output_state) == (sz,)
606+
@test res.output isa ConcreteRArray{Float64,2}
607+
@test size(res.output) == (2, 4)
608+
end
608609
end
609610
end
610-
end
611611

612612
@testset "round_nearest_afz" begin
613613
x = Reactant.to_rarray([-2.5, 0.4, 0.5, 0.6, 2.5])
@@ -1238,7 +1238,8 @@ end
12381238
@test size(perm) == (4, 3, 6)
12391239
@test size(info) == (4, 3)
12401240

1241-
@test @jit(recon_from_lu(lu_ra)) @jit(apply_permutation(x_ra, perm)) atol = 1e-5 rtol = 1e-2
1241+
@test @jit(recon_from_lu(lu_ra)) @jit(apply_permutation(x_ra, perm)) atol = 1e-5 rtol =
1242+
1e-2
12421243
end
12431244
end
12441245

0 commit comments

Comments
 (0)