diff --git a/src/ParallelTestRunner.jl b/src/ParallelTestRunner.jl index bec95ff..122ac19 100644 --- a/src/ParallelTestRunner.jl +++ b/src/ParallelTestRunner.jl @@ -1,6 +1,6 @@ module ParallelTestRunner -export runtests +export runtests, addworkers, addworker using Distributed using Dates @@ -213,6 +213,35 @@ function default_njobs(; cpu_threads = Sys.CPU_THREADS, free_memory = Sys.free_m return max(1, min(jobs, memory_jobs)) end +""" + addworkers(X; kwargs...) + +Add `X` worker processes, with additional keyword arguments passed to `addprocs`. +""" +test_exeflags = Base.julia_cmd() +filter!(test_exeflags.exec) do c + return !(startswith(c, "--depwarn") || startswith(c, "--check-bounds")) +end +push!(test_exeflags.exec, "--check-bounds=yes") +push!(test_exeflags.exec, "--startup-file=no") +push!(test_exeflags.exec, "--depwarn=yes") +push!(test_exeflags.exec, "--project=$(Base.active_project())") +test_exename = popfirst!(test_exeflags.exec) +function addworkers(X; kwargs...) + exename = test_exename + + return withenv("JULIA_NUM_THREADS" => 1, "OPENBLAS_NUM_THREADS" => 1) do + procs = addprocs(X; exename = exename, exeflags = test_exeflags, kwargs...) + Distributed.remotecall_eval( + Main, procs, quote + import ParallelTestRunner + end + ) + procs + end +end +addworker(; kwargs...) = addworkers(1; kwargs...)[1] + """ runtests(ARGS; testfilter = Returns(true), RecordType = TestRecord, custom_tests = Dict()) @@ -230,6 +259,8 @@ Several keyword arguments are also supported: - `custom_tests`: Optional dictionary of custom tests, mapping test names to expressions. - `init_code`: Code use to initialize each test's sandbox module (e.g., import auxiliary packages, define constants, etc). +- `test_worker`: Optional function that takes a test name and returns a specific worker. + When returning `nothing`, the test will be assigned to any available default worker. ## Command Line Options @@ -272,7 +303,8 @@ Workers are automatically recycled when they exceed memory limits to prevent out issues during long test runs. The memory limit is set based on system architecture. """ function runtests(ARGS; testfilter = Returns(true), RecordType = TestRecord, - custom_tests::Dict{String, Expr}=Dict{String, Expr}(), init_code = :()) + custom_tests::Dict{String, Expr}=Dict{String, Expr}(), init_code = :(), + test_worker = Returns(nothing)) do_help, _ = extract_flag!(ARGS, "--help") if do_help println( @@ -372,29 +404,7 @@ function runtests(ARGS; testfilter = Returns(true), RecordType = TestRecord, @info "Running $jobs tests in parallel. If this is too many, specify the `--jobs=N` argument to the tests, or set the `JULIA_CPU_THREADS` environment variable." # add workers - test_exeflags = Base.julia_cmd() - filter!(test_exeflags.exec) do c - return !(startswith(c, "--depwarn") || startswith(c, "--check-bounds")) - end - push!(test_exeflags.exec, "--check-bounds=yes") - push!(test_exeflags.exec, "--startup-file=no") - push!(test_exeflags.exec, "--depwarn=yes") - push!(test_exeflags.exec, "--project=$(Base.active_project())") - test_exename = popfirst!(test_exeflags.exec) - function addworker(X; kwargs...) - exename = test_exename - - return withenv("JULIA_NUM_THREADS" => 1, "OPENBLAS_NUM_THREADS" => 1) do - procs = addprocs(X; exename = exename, exeflags = test_exeflags, kwargs...) - Distributed.remotecall_eval( - Main, procs, quote - import ParallelTestRunner - end - ) - procs - end - end - addworker(min(jobs, length(tests))) + addworkers(min(jobs, length(tests))) # pretty print information about gc and mem usage testgroupheader = "Test" @@ -492,21 +502,21 @@ function runtests(ARGS; testfilter = Returns(true), RecordType = TestRecord, while length(tests) > 0 test = popfirst!(tests) - # sometimes a worker failed, and we need to spawn a new one + # if a worker failed, spawn a new one if p === nothing - p = addworker(1)[1] + p = addworkers(1)[1] end - wrkr = p - local resp + # some tests may need a special worker + wrkr = something(test_worker(test), p) # run the test running_tests[test] = now() - try - resp = remotecall_fetch(runtest, wrkr, RecordType, test_runners[test], test, init_code) + resp = try + remotecall_fetch(runtest, wrkr, RecordType, test_runners[test], test, init_code) catch e isa(e, InterruptException) && return - resp = Any[e] + Any[e] end delete!(running_tests, test) push!(results, (test, resp)) @@ -529,6 +539,11 @@ function runtests(ARGS; testfilter = Returns(true), RecordType = TestRecord, # so future tests get a fresh environment p = recycle_worker(p) end + + # get rid of the custom worker + if wrkr != p + recycle_worker(wrkr) + end end if p !== nothing diff --git a/test/runtests.jl b/test/runtests.jl index 71f9757..f1a7fb5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,3 +20,20 @@ custom_tests = Dict( end ) runtests(ARGS; init_code, custom_tests) + +# custom worker +function test_worker(name) + if name == "needs env var" + return addworker(env=["SPECIAL_ENV_VAR"=>"42"]) + end + return nothing +end +custom_tests = Dict( + "needs env var" => quote + @test ENV["SPECIAL_ENV_VAR"] == "42" + end, + "doesn't need env var" => quote + @test !haskey(ENV, "SPECIAL_ENV_VAR") + end +) +runtests(ARGS; test_worker, custom_tests)