diff --git a/Project.toml b/Project.toml index 30ff4d5..007646b 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Dates = "1" IOCapture = "0.2.5" -Malt = "1.2.1" +Malt = "1.3.0" Printf = "1" Random = "1" Scratch = "1.3.0" diff --git a/src/ParallelTestRunner.jl b/src/ParallelTestRunner.jl index 9d5e078..ae91180 100644 --- a/src/ParallelTestRunner.jl +++ b/src/ParallelTestRunner.jl @@ -414,15 +414,36 @@ worker_id(wrkr) = WORKER_IDS[wrkr.proc_pid] Add `X` worker processes. """ addworkers(X; kwargs...) = [addworker(; kwargs...) for _ in 1:X] -function addworker(; env=Vector{Pair{String, String}}()) + +""" + addworker(; env=Vector{Pair{String, String}}(), exename=nothing, exeflags=nothing) + +Add a single worker process. + +## Arguments +- `env`: Vector of environment variable pairs to set for the worker process. +- `exename`: Custom executable to use for the worker process. +- `exeflags`: Custom flags to pass to the worker process. +""" +function addworker(; + env = Vector{Pair{String, String}}(), + exename = nothing, exeflags = nothing + ) exe = test_exe() - exeflags = exe[2:end] + if exename === nothing + exename = exe[1] + end + if exeflags !== nothing + exeflags = vcat(exe[2:end], exeflags) + else + exeflags = exe[2:end] + end push!(env, "JULIA_NUM_THREADS" => "1") # Malt already sets OPENBLAS_NUM_THREADS to 1 push!(env, "OPENBLAS_NUM_THREADS" => "1") - wrkr = Malt.Worker(;exeflags, env) + wrkr = Malt.Worker(; exename, exeflags, env) WORKER_IDS[wrkr.proc_pid] = length(WORKER_IDS) + 1 return wrkr end diff --git a/test/runtests.jl b/test/runtests.jl index 5971bfe..6c83431 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -68,6 +68,8 @@ end function test_worker(name) if name == "needs env var" return addworker(env = ["SPECIAL_ENV_VAR" => "42"]) + elseif name == "threads/2" + return addworker(exeflags = ["--threads=2"]) end return nothing end @@ -77,6 +79,12 @@ end end, "doesn't need env var" => quote @test !haskey(ENV, "SPECIAL_ENV_VAR") + end, + "threads/1" => quote + @test Base.Threads.nthreads() == 1 + end, + "threads/2" => quote + @test Base.Threads.nthreads() == 2 end ) @@ -86,6 +94,8 @@ end str = String(take!(io)) @test contains(str, r"needs env var .+ started at") @test contains(str, r"doesn't need env var .+ started at") + @test contains(str, r"threads/1 .+ started at") + @test contains(str, r"threads/2 .+ started at") @test contains(str, "SUCCESS") end