Skip to content

Commit 0c3073c

Browse files
authored
Define GPUArrays tests dynamically (#865)
1 parent 66e1de8 commit 0c3073c

File tree

3 files changed

+31
-89
lines changed

3 files changed

+31
-89
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ docs/package-lock.json
1212
docs/node_modules
1313
Manifest*.toml
1414
LocalPreferences.toml
15+
16+
test/gpuarrays_generated_tests.jl

test/gpuarrays_tests.jl

Lines changed: 0 additions & 89 deletions
This file was deleted.

test/runtests.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ using LinearAlgebra
1010
using ReTestItems
1111
using Test
1212

13+
import GPUArrays
14+
include(joinpath(pkgdir(GPUArrays), "test", "testsuite.jl"))
15+
1316
macro grab_output(ex, io=stdout)
1417
quote
1518
mktemp() do fname, fout
@@ -113,6 +116,32 @@ PrettyTables.pretty_table(data; column_labels=["Workers", "Device", "Tests"],
113116
fit_table_in_display_vertically=false,
114117
fit_table_in_display_horizontally=false)
115118

119+
# Hack to define GPUArrays `@testitems` dynamically - by writing them to a file.
120+
function write_gpuarrays_tests()
121+
template = """
122+
@testsetup module TSGPUArrays
123+
export gpuarrays_test
124+
125+
import GPUArrays, AMDGPU
126+
include(joinpath(pkgdir(GPUArrays), "test", "testsuite.jl"))
127+
128+
gpuarrays_test(test_name::String) = TestSuite.tests[test_name](AMDGPU.ROCArray)
129+
end
130+
"""
131+
for test_name in keys(TestSuite.tests)
132+
template = """
133+
$template
134+
@testitem "gpuarrays - $test_name" setup=[TSGPUArrays] begin gpuarrays_test("$test_name") end
135+
"""
136+
end
137+
138+
test_file = joinpath(dirname(@__FILE__), "gpuarrays_generated_tests.jl")
139+
open(io -> write(io, template), test_file, "w")
140+
@info "Writing GPUArrays test file: `$test_file`."
141+
return
142+
end
143+
write_gpuarrays_tests()
144+
116145
runtests(AMDGPU; nworkers=np, nworker_threads=1, testitem_timeout=60 * 30) do ti
117146
for tt in TARGET_TESTS
118147
startswith(ti.name, tt) && return true

0 commit comments

Comments
 (0)