Skip to content

Commit 4773238

Browse files
committed
test: more fixes
1 parent 3bda5b3 commit 4773238

File tree

5 files changed

+107
-80
lines changed

5 files changed

+107
-80
lines changed

src/xla/XLA.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ function __init__()
131131
XLA_REACTANT_GPU_MEM_FRACTION[] = parse(
132132
Float64, ENV["XLA_REACTANT_GPU_MEM_FRACTION"]
133133
)
134-
@debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[]
134+
@debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[] maxlog =
135+
1
135136
if XLA_REACTANT_GPU_MEM_FRACTION[] > 1 || XLA_REACTANT_GPU_MEM_FRACTION[] < 0
136137
error("XLA_REACTANT_GPU_MEM_FRACTION must be between 0 and 1")
137138
end
@@ -141,16 +142,18 @@ function __init__()
141142
XLA_REACTANT_GPU_PREALLOCATE[] = parse(
142143
Bool, ENV["XLA_REACTANT_GPU_PREALLOCATE"]
143144
)
144-
@debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[]
145+
@debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[] maxlog =
146+
1
145147
end
146148

147149
if haskey(ENV, "REACTANT_VISIBLE_GPU_DEVICES")
148150
global_state.local_gpu_device_ids =
149151
parse.(Int, split(ENV["REACTANT_VISIBLE_GPU_DEVICES"], ","))
150-
@debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_gpu_device_ids
152+
@debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_gpu_device_ids maxlog =
153+
1
151154
end
152155

153-
@debug "REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME
156+
@debug "REACTANT_XLA_RUNTIME: " REACTANT_XLA_RUNTIME maxlog = 1
154157

155158
@ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid
156159
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid

test/integration/linear_algebra.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ end
7777
@test @jit(muladd_5arg(A_ra, x_ra, b_ra)) muladd2(A, x, b)
7878

7979
C_ra = similar(A_ra, Float32, size(A, 1), size(x, 2))
80+
C = similar(A, Float32, size(A, 1), size(x, 2))
8081
@jit(mul!(C_ra, A_ra, x_ra))
81-
@test C_ra A * x atol = 1e-5 rtol = 1e-3
82+
mul!(C, A, x)
83+
@test C_ra C atol = 1e-3 rtol = 1e-2
8284
end
8385

8486
@testset "triu & tril" begin

test/integration/special_functions.jl

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using SpecialFunctions, Reactant
22

3+
const RunningOnTPU = contains(string(Reactant.devices()[1]), "TPU")
4+
35
macro (a, b)
46
return quote
57
isapprox($a, $b; atol=1e-14)
@@ -8,81 +10,92 @@ end
810

911
@testset "gamma" begin
1012
@test SpecialFunctions.gamma(0.5) @jit(SpecialFunctions.gamma(ConcreteRNumber(0.5)))
11-
@test SpecialFunctions.gamma(2) @jit(SpecialFunctions.gamma(ConcreteRNumber(2)))
13+
@test SpecialFunctions.gamma(Int32(2))
14+
@jit(SpecialFunctions.gamma(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3
1215
end
1316

1417
@testset "loggamma" begin
1518
@test SpecialFunctions.loggamma(0.5)
16-
@jit(SpecialFunctions.loggamma(ConcreteRNumber(0.5)))
17-
@test abs(SpecialFunctions.loggamma(2)) < 1e-10
18-
@test abs(@jit(SpecialFunctions.loggamma(ConcreteRNumber(2)))) < 1e-10
19+
@jit(SpecialFunctions.loggamma(ConcreteRNumber(0.5))) atol = 1e-5 rtol = 1e-3
20+
@test SpecialFunctions.loggamma(Int32(2))
21+
@jit(SpecialFunctions.loggamma(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3
1922
end
2023

2124
@testset "digamma" begin
2225
@test SpecialFunctions.digamma(0.5)
2326
@jit(SpecialFunctions.digamma(ConcreteRNumber(0.5)))
24-
@test SpecialFunctions.digamma(2) @jit(SpecialFunctions.digamma(ConcreteRNumber(2)))
27+
@test SpecialFunctions.digamma(Int32(2))
28+
@jit(SpecialFunctions.digamma(ConcreteRNumber(Int32(2))))
2529
end
2630

2731
@testset "trigamma" begin
2832
@test SpecialFunctions.trigamma(0.5)
2933
@jit(SpecialFunctions.trigamma(ConcreteRNumber(0.5)))
30-
@test SpecialFunctions.trigamma(2) @jit(SpecialFunctions.trigamma(ConcreteRNumber(2)))
34+
@test SpecialFunctions.trigamma(Int32(2))
35+
@jit(SpecialFunctions.trigamma(ConcreteRNumber(Int32(2))))
3136
end
3237

3338
@testset "beta" begin
3439
@test SpecialFunctions.beta(0.5, 0.6)
3540
@jit(SpecialFunctions.beta(ConcreteRNumber(0.5), ConcreteRNumber(0.6)))
36-
@test SpecialFunctions.beta(2, 4)
37-
@jit(SpecialFunctions.beta(ConcreteRNumber(2), ConcreteRNumber(4)))
41+
@test SpecialFunctions.beta(Int32(2), Int32(4))
42+
@jit(SpecialFunctions.beta(ConcreteRNumber(Int32(2)), ConcreteRNumber(Int32(4))))
3843
end
3944

4045
@testset "logbeta" begin
4146
@test SpecialFunctions.logbeta(0.5, 0.6)
4247
@jit(SpecialFunctions.logbeta(ConcreteRNumber(0.5), ConcreteRNumber(0.6)))
43-
@test SpecialFunctions.logbeta(2, 4)
44-
@jit(SpecialFunctions.logbeta(ConcreteRNumber(2), ConcreteRNumber(4)))
48+
@test SpecialFunctions.logbeta(Int32(2), Int32(4)) @jit(
49+
SpecialFunctions.logbeta(ConcreteRNumber(Int32(2)), ConcreteRNumber(Int32(4)))
50+
)
4551
end
4652

4753
@testset "erf" begin
4854
@test SpecialFunctions.erf(0.5) @jit(SpecialFunctions.erf(ConcreteRNumber(0.5)))
49-
@test SpecialFunctions.erf(2) @jit(SpecialFunctions.erf(ConcreteRNumber(2)))
55+
@test SpecialFunctions.erf(Int32(2))
56+
@jit(SpecialFunctions.erf(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3
5057
end
5158

5259
@testset "erf with 2 arguments" begin
5360
@test SpecialFunctions.erf(0.5, 0.6)
5461
@jit(SpecialFunctions.erf(ConcreteRNumber(0.5), ConcreteRNumber(0.6)))
55-
@test SpecialFunctions.erf(2, 4)
56-
@jit(SpecialFunctions.erf(ConcreteRNumber(2), ConcreteRNumber(4)))
62+
@test SpecialFunctions.erf(Int32(2), Int32(4))
63+
@jit(SpecialFunctions.erf(ConcreteRNumber(Int32(2)), ConcreteRNumber(Int32(4)))) atol =
64+
1e-5 rtol = 1e-3
5765
end
5866

5967
@testset "erfc" begin
6068
@test SpecialFunctions.erfc(0.5) @jit(SpecialFunctions.erfc(ConcreteRNumber(0.5)))
61-
@test SpecialFunctions.erfc(2) @jit(SpecialFunctions.erfc(ConcreteRNumber(2)))
69+
@test SpecialFunctions.erfc(Int32(2))
70+
@jit(SpecialFunctions.erfc(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3
6271
end
6372

6473
@testset "logerf" begin
6574
@test SpecialFunctions.logerf(0.5, 0.6)
6675
@jit(SpecialFunctions.logerf(ConcreteRNumber(0.5), ConcreteRNumber(0.6)))
67-
@test SpecialFunctions.logerf(2, 4)
68-
@jit(SpecialFunctions.logerf(ConcreteRNumber(2), ConcreteRNumber(4)))
76+
@test SpecialFunctions.logerf(Int32(2), Int32(4)) @jit(
77+
SpecialFunctions.logerf(ConcreteRNumber(Int32(2)), ConcreteRNumber(Int32(4)))
78+
) atol = 1e-5 rtol = 1e-3
6979
end
7080

7181
@testset "erfcx" begin
7282
@test SpecialFunctions.erfcx(0.5) @jit(SpecialFunctions.erfcx(ConcreteRNumber(0.5)))
73-
@test SpecialFunctions.erfcx(2) @jit(SpecialFunctions.erfcx(ConcreteRNumber(2)))
83+
@test SpecialFunctions.erfcx(Int32(2))
84+
@jit(SpecialFunctions.erfcx(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3
7485
end
7586

7687
@testset "logerfc" begin
7788
@test SpecialFunctions.logerfc(0.5)
7889
@jit(SpecialFunctions.logerfc(ConcreteRNumber(0.5)))
79-
@test SpecialFunctions.logerfc(2) @jit(SpecialFunctions.logerfc(ConcreteRNumber(2)))
90+
@test SpecialFunctions.logerfc(Int32(2))
91+
@jit(SpecialFunctions.logerfc(ConcreteRNumber(Int32(2))))
8092
end
8193

8294
@testset "logerfcx" begin
8395
@test SpecialFunctions.logerfcx(0.5)
8496
@jit(SpecialFunctions.logerfcx(ConcreteRNumber(0.5)))
85-
@test SpecialFunctions.logerfcx(2) @jit(SpecialFunctions.logerfcx(ConcreteRNumber(2)))
97+
@test SpecialFunctions.logerfcx(Int32(2))
98+
@jit(SpecialFunctions.logerfcx(ConcreteRNumber(Int32(2)))) atol = 1e-5 rtol = 1e-3
8699
end
87100

88101
@testset "loggamma1p" begin
@@ -91,8 +104,10 @@ end
91104
end
92105

93106
@testset "loggammadiv" begin
94-
@test SpecialFunctions.loggammadiv(150, 20)
95-
@jit SpecialFunctions.loggammadiv(ConcreteRNumber(150), ConcreteRNumber(20))
107+
@test SpecialFunctions.loggammadiv(Int32(150), Int32(20))
108+
@jit SpecialFunctions.loggammadiv(
109+
ConcreteRNumber(Int32(150)), ConcreteRNumber(Int32(20))
110+
)
96111
end
97112

98113
@testset "zeta" begin

test/nn/nnlib.jl

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -123,29 +123,34 @@ end
123123
end
124124

125125
@testset "Batched Matrix Multiplication" begin
126-
x = rand(Float32, 4, 3, 5)
127-
y = rand(Float32, 3, 2, 5)
126+
Reactant.with_config(;
127+
convolution_precision=PrecisionConfig.HIGHEST,
128+
dot_general_precision=PrecisionConfig.HIGHEST,
129+
) do
130+
x = rand(Float32, 4, 3, 5)
131+
y = rand(Float32, 3, 2, 5)
128132

129-
x_ra = Reactant.to_rarray(x)
130-
y_ra = Reactant.to_rarray(y)
133+
x_ra = Reactant.to_rarray(x)
134+
y_ra = Reactant.to_rarray(y)
131135

132-
@test @jit(batched_mul(x_ra, y_ra)) batched_mul(x, y) atol = 1e-5 rtol = 1e-3
136+
@test @jit(batched_mul(x_ra, y_ra)) batched_mul(x, y)
133137

134-
x = rand(Float32, 4, 3, 1)
135-
y = rand(Float32, 3, 2, 5)
138+
x = rand(Float32, 4, 3, 1)
139+
y = rand(Float32, 3, 2, 5)
136140

137-
x_ra = Reactant.to_rarray(x)
138-
y_ra = Reactant.to_rarray(y)
141+
x_ra = Reactant.to_rarray(x)
142+
y_ra = Reactant.to_rarray(y)
139143

140-
@test @jit(batched_mul(x_ra, y_ra)) batched_mul(x, y) atol = 1e-5 rtol = 1e-3
144+
@test @jit(batched_mul(x_ra, y_ra)) batched_mul(x, y)
141145

142-
x = rand(Float32, 4, 3, 5)
143-
y = rand(Float32, 3, 2, 1)
146+
x = rand(Float32, 4, 3, 5)
147+
y = rand(Float32, 3, 2, 1)
144148

145-
x_ra = Reactant.to_rarray(x)
146-
y_ra = Reactant.to_rarray(y)
149+
x_ra = Reactant.to_rarray(x)
150+
y_ra = Reactant.to_rarray(y)
147151

148-
@test @jit(batched_mul(x_ra, y_ra)) batched_mul(x, y) atol = 1e-5 rtol = 1e-3
152+
@test @jit(batched_mul(x_ra, y_ra)) batched_mul(x, y)
153+
end
149154
end
150155

151156
@testset "Constant Padding: NNlib.pad_constant" begin
@@ -649,16 +654,18 @@ end
649654
) in Iterators.product(
650655
(0, 2), (1, 2), (1,), (1,)
651656
)
652-
conv_dims = NNlib.DenseConvDims(x, w; padding, stride, dilation, groups)
657+
Reactant.with_config(; convolution_precision=PrecisionConfig.HIGHEST) do
658+
conv_dims = NNlib.DenseConvDims(x, w; padding, stride, dilation, groups)
653659

654-
output_size = (NNlib.output_size(conv_dims)..., n_out_features, batch_size)
655-
dy = randn(Float32, output_size)
656-
dy_reactant = Reactant.to_rarray(dy)
660+
output_size = (NNlib.output_size(conv_dims)..., n_out_features, batch_size)
661+
dy = randn(Float32, output_size)
662+
dy_reactant = Reactant.to_rarray(dy)
657663

658-
@test @jit(NNlib.∇conv_data(dy_reactant, w_reactant, conv_dims))
659-
NNlib.∇conv_data(dy, w, conv_dims) atol = 1e-5 rtol = 1e-3
660-
@test @jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims))
661-
NNlib.∇conv_filter(x, dy, conv_dims) atol = 1e-5 rtol = 1e-3
664+
@test @jit(NNlib.∇conv_data(dy_reactant, w_reactant, conv_dims))
665+
NNlib.∇conv_data(dy, w, conv_dims)
666+
@test @jit(NNlib.∇conv_filter(x_reactant, dy_reactant, conv_dims))
667+
NNlib.∇conv_filter(x, dy, conv_dims)
668+
end
662669
end
663670
end
664671

@@ -704,12 +711,12 @@ end
704711
end
705712

706713
@testset "Pixel shuffle" begin
707-
x = [10i + j + channel / 10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
714+
x = Int32[10i + j + channel / 10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
708715
x_ra = Reactant.to_rarray(x)
709716

710717
@test @jit(NNlib.pixel_shuffle(x_ra, 2)) NNlib.pixel_shuffle(x, 2)
711718

712-
y = [i + channel / 10 for i in 1:3, channel in 1:6, batch in 1:1]
719+
y = Int32[i + channel / 10 for i in 1:3, channel in 1:6, batch in 1:1]
713720
y_ra = Reactant.to_rarray(y)
714721

715722
@test @jit(NNlib.pixel_shuffle(y_ra, 2)) NNlib.pixel_shuffle(y, 2)

test/runtests.jl

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,31 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
1212
@safetestset "Metal Plugin" include("plugins/metal.jl")
1313
end
1414

15-
# @safetestset "Layout" include("layout.jl")
16-
# @safetestset "Tracing" include("tracing.jl")
17-
# @safetestset "Basic" include("basic.jl")
18-
# @safetestset "Constructor" include("constructor.jl")
19-
# @safetestset "Autodiff" include("autodiff.jl")
20-
# @safetestset "Complex" include("complex.jl")
21-
# @safetestset "Broadcast" include("bcast.jl")
22-
# @safetestset "Struct" include("struct.jl")
23-
# @safetestset "Closure" include("closure.jl")
24-
# @safetestset "Compile" include("compile.jl")
25-
# @safetestset "IR" include("ir.jl")
26-
# @safetestset "Buffer Donation" include("buffer_donation.jl")
27-
# @safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
28-
# @safetestset "Control Flow" include("control_flow.jl")
29-
# @safetestset "Sorting" include("sorting.jl")
30-
# @safetestset "Shortcuts to MLIR ops" include("ops.jl")
31-
# @safetestset "Indexing" include("indexing.jl")
32-
# if !Sys.isapple()
33-
# @safetestset "Custom Number Types" include("custom_number_types.jl")
34-
# end
35-
# @safetestset "Sharding" include("sharding.jl")
36-
# @safetestset "Comm Optimization" include("optimize_comm.jl")
37-
# @safetestset "Cluster Detection" include("cluster_detector.jl")
38-
# @safetestset "Config" include("config.jl")
39-
# @safetestset "Batching" include("batching.jl")
15+
@safetestset "Layout" include("layout.jl")
16+
@safetestset "Tracing" include("tracing.jl")
17+
@safetestset "Basic" include("basic.jl")
18+
@safetestset "Constructor" include("constructor.jl")
19+
@safetestset "Autodiff" include("autodiff.jl")
20+
@safetestset "Complex" include("complex.jl")
21+
@safetestset "Broadcast" include("bcast.jl")
22+
@safetestset "Struct" include("struct.jl")
23+
@safetestset "Closure" include("closure.jl")
24+
@safetestset "Compile" include("compile.jl")
25+
@safetestset "IR" include("ir.jl")
26+
@safetestset "Buffer Donation" include("buffer_donation.jl")
27+
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
28+
@safetestset "Control Flow" include("control_flow.jl")
29+
@safetestset "Sorting" include("sorting.jl")
30+
@safetestset "Shortcuts to MLIR ops" include("ops.jl")
31+
@safetestset "Indexing" include("indexing.jl")
32+
if !Sys.isapple()
33+
@safetestset "Custom Number Types" include("custom_number_types.jl")
34+
end
35+
@safetestset "Sharding" include("sharding.jl")
36+
@safetestset "Comm Optimization" include("optimize_comm.jl")
37+
@safetestset "Cluster Detection" include("cluster_detector.jl")
38+
@safetestset "Config" include("config.jl")
39+
@safetestset "Batching" include("batching.jl")
4040
end
4141

4242
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
@@ -55,15 +55,15 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5555
end
5656

5757
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
58-
@safetestset "NNlib Primitives" include("nn/nnlib.jl")
59-
@info "NNlib Primitives tests finished"
6058
@safetestset "Flux.jl Integration" include("nn/flux.jl")
6159
@info "Flux.jl Integration tests finished"
6260
if Sys.islinux()
63-
@safetestset "LuxLib Primitives" include("nn/luxlib.jl")
64-
@info "LuxLib Primitives tests finished"
6561
@safetestset "Lux Integration" include("nn/lux.jl")
6662
@info "Lux Integration tests finished"
63+
@safetestset "LuxLib Primitives" include("nn/luxlib.jl") # XXX: TPU takes too long
64+
@info "LuxLib Primitives tests finished"
6765
end
66+
@safetestset "NNlib Primitives" include("nn/nnlib.jl")
67+
@info "NNlib Primitives tests finished"
6868
end
6969
end

0 commit comments

Comments
 (0)