Skip to content

Commit f0c5735

Browse files
giordanowsmoses
andauthored
Make tests even quieter (#2609)
* Remove debugging print from tests * Capture and test warning * Replace `size(T::VectorType)` with `length(T)` as required by LLVM.jl The deprecation was introduced in LLVM.jl v9.1, we require GPUCompiler.jl v1.6 which require LLVM.jl v9.3, so the requirement bump, together with dropping previous versions, is more than safe. * Capture and test more warnings, emitted during compilation * Update test/runtests.jl --------- Co-authored-by: William Moses <[email protected]>
1 parent 4a74662 commit f0c5735

File tree

4 files changed

+19
-22
lines changed

4 files changed

+19
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ EnzymeCore = "0.8.13"
4848
Enzyme_jll = "0.0.200"
4949
GPUArraysCore = "0.1.6, 0.2"
5050
GPUCompiler = "1.6"
51-
LLVM = "6.1, 7, 8, 9"
51+
LLVM = "9.1"
5252
LogExpFunctions = "0.3"
5353
ObjectFile = "0.4, 0.5"
5454
PrecompileTools = "1"

src/typeutils/lltypes.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,12 @@ function CountTrackedPointers(@nospecialize(T::LLVM.LLVMType))
2929
res.all &= sub.all
3030
res.derived |= sub.derived
3131
end
32-
elseif isa(T, LLVM.ArrayType)
32+
elseif isa(T, LLVM.ArrayType) || isa(T, LLVM.VectorType)
3333
sub = CountTrackedPointers(eltype(T))
3434
res.count += sub.count
3535
res.all &= sub.all
3636
res.derived |= sub.derived
3737
res.count *= length(T)
38-
elseif isa(T, LLVM.VectorType)
39-
sub = CountTrackedPointers(eltype(T))
40-
res.count += sub.count
41-
res.all &= sub.all
42-
res.derived |= sub.derived
43-
res.count *= size(T)
4438
end
4539
if res.count == 0
4640
res.all = false

test/optimize.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ struct MvLocationScale{
7272
end
7373

7474
@noinline function law(dist, flat::AbstractVector)
75-
ccall(:jl_, Cvoid, (Any,), flat)
7675
n_dims = div(length(flat), 2)
7776
data = first(flat, n_dims)
7877
scale = Diagonal(data)
@@ -225,14 +224,14 @@ function fwd(x, y)
225224
end
226225

227226
@testset "Parameter removal" begin
228-
# Test that we do not remove parameters, or replace with undef, any parameters from externally linked code (even if replaced via blas)
229-
fn = sprint() do io
230-
Enzyme.Compiler.enzyme_code_llvm(io, fwd, Const, Tuple{Const{Vector{ComplexF64}},Const{Vector{ComplexF64}}}; dump_module=true)
231-
end
232-
233-
for s in split(fn, "\n")
234-
if occursin(s, "ejlstr")
235-
@test !(occursin(" undef",s) || occursin(" poison",s))
236-
end
237-
end
227+
# Test that we do not remove parameters, or replace with undef, any parameters from externally linked code (even if replaced via blas)
228+
fn = sprint() do io
229+
@test_warn r"Using fallback BLAS replacements for" Enzyme.Compiler.enzyme_code_llvm(io, fwd, Const, Tuple{Const{Vector{ComplexF64}}, Const{Vector{ComplexF64}}}; dump_module = true)
230+
end
231+
232+
for s in split(fn, "\n")
233+
if occursin(s, "ejlstr")
234+
@test !(occursin(" undef", s) || occursin(" poison", s))
235+
end
236+
end
238237
end

test/runtests.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,6 @@ end
606606
end
607607

608608
@testset "Simple Exception" begin
609-
f_simple_exc(x, i) = ccall(:jl_, Cvoid, (Any,), x[i])
610609
y = [1.0, 2.0]
611610
f_x = zero.(y)
612611
@test_throws BoundsError autodiff(Reverse, f_simple_exc, Duplicated(y, f_x), Const(0))
@@ -1663,7 +1662,12 @@ end
16631662
end
16641663

16651664
@testset "Vector to Number" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
1666-
test_matrix_to_number(f, y; rtol=1e-6, atol=1e-6)
1665+
# `test_matrix_to_number` contains a `@generated` function, we wrap it in a
1666+
# `Ref{Any}` container only to be able to catch and test the warnings emitted during
1667+
# compilation in the body of the function.
1668+
test_mat2num = Ref{Any}(test_matrix_to_number)
1669+
warn_msg = f === DiffTests.vec2num_3 ? r"Using fallback BLAS replacements for" : ""
1670+
@test_warn warn_msg test_mat2num[](f, y; rtol = 1.0e-6, atol = 1.0e-6)
16671671
end
16681672

16691673
@testset "Matrix to Number" for f in DiffTests.MATRIX_TO_NUMBER_FUNCS
@@ -3423,7 +3427,7 @@ function uns_sum2(x::Array{T})::T where T
34233427
end
34243428

34253429
function uns_ad_forward(scale_diag::Vector{T}, c) where T
3426-
ccall(:jl_, Cvoid, (Any,), scale_diag)
3430+
ccall(:jl_, Cvoid, (Any,), scale_diag)
34273431
res = uns_mymean(uns_sum2, [scale_diag,], T, c)
34283432
return res
34293433
end

0 commit comments

Comments
 (0)