Skip to content

Commit d1f1a8a

Browse files
committed
Manage NVIDIA GPU slots in local executor
Signed-off-by: Ben Sherman <bentshermann@gmail.com>
1 parent 1c8e4d3 commit d1f1a8a

File tree

5 files changed

+140
-8
lines changed

5 files changed

+140
-8
lines changed

docs/reference/config.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,12 @@ The following settings are available:
661661
: *Used only by grid executors.*
662662
: Determines how long to wait before returning an error status when a process is terminated but the `.exitcode` file does not exist or is empty (default: `270 sec`).
663663

664+
`executor.gpus`
665+
: :::{versionadded} 25.10.0
666+
:::
667+
: *Used only by the `local` executor.*
668+
: The maximum number of NVIDIA GPUs made available by the underlying system. When this setting is enabled, each local task is assigned GPUs based on their `accelerator` request, using the `CUDA_VISIBLE_DEVICES` environment variable.
669+
664670
`executor.jobName`
665671
: Determines the name of jobs submitted to the underlying cluster executor e.g. `executor.jobName = { "$task.name - $task.hash" }`. Make sure the resulting job name matches the validation constraints of the underlying batch scheduler.
666672
: This setting is supported by the following executors: Bridge, Condor, Flux, HyperQueue, Lsf, Moab, Nqsii, Oar, PBS, PBS Pro, SGE, SLURM and Google Batch.

modules/nextflow/src/main/groovy/nextflow/executor/local/LocalTaskHandler.groovy

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {
8080

8181
private volatile TaskResult result
8282

83+
List<Integer> gpuSlots
84+
8385
LocalTaskHandler(TaskRun task, LocalExecutor executor) {
8486
super(task)
8587
// create the task handler
@@ -142,11 +144,13 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {
142144
final workDir = task.workDir.toFile()
143145
final logFile = new File(workDir, TaskRun.CMD_LOG)
144146

145-
return new ProcessBuilder()
147+
final pb = new ProcessBuilder()
146148
.redirectErrorStream(true)
147149
.redirectOutput(logFile)
148150
.directory(workDir)
149151
.command(cmd)
152+
applyGpuSlots(pb)
153+
return pb
150154
}
151155

152156
protected ProcessBuilder fusionProcessBuilder() {
@@ -162,10 +166,18 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {
162166

163167
final logPath = Files.createTempFile('nf-task','.log')
164168

165-
return new ProcessBuilder()
169+
final pb = new ProcessBuilder()
166170
.redirectErrorStream(true)
167171
.redirectOutput(logPath.toFile())
168172
.command(List.of('sh','-c', cmd))
173+
applyGpuSlots(pb)
174+
return pb
175+
}
176+
177+
protected void applyGpuSlots(ProcessBuilder pb) {
178+
if( !gpuSlots )
179+
return
180+
pb.environment().put('CUDA_VISIBLE_DEVICES', gpuSlots.join(','))
169181
}
170182

171183
protected ProcessBuilder createLaunchProcessBuilder() {

modules/nextflow/src/main/groovy/nextflow/processor/LocalPollingMonitor.groovy

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ import groovy.transform.PackageScope
2323
import groovy.util.logging.Slf4j
2424
import nextflow.Session
2525
import nextflow.exception.ProcessUnrecoverableException
26+
import nextflow.executor.local.LocalTaskHandler
2627
import nextflow.util.Duration
2728
import nextflow.util.MemoryUnit
29+
import nextflow.util.TrackingSemaphore
2830

2931
/**
3032
* Task polling monitor specialized for local execution. It manages tasks scheduling
@@ -58,6 +60,16 @@ class LocalPollingMonitor extends TaskPollingMonitor {
5860
*/
5961
private final long maxMemory
6062

63+
/**
64+
* Number of `free` GPUs available to execute pending tasks
65+
*/
66+
private TrackingSemaphore availGpus
67+
68+
/**
69+
* Total number of CPUs available in the system
70+
*/
71+
private final int maxGpus
72+
6173
/**
6274
* Create the task polling monitor with the provided named parameters object.
6375
* <p>
@@ -74,6 +86,8 @@ class LocalPollingMonitor extends TaskPollingMonitor {
7486
super(params)
7587
this.availCpus = maxCpus = params.cpus as int
7688
this.availMemory = maxMemory = params.memory as long
89+
this.maxGpus = params.gpus as int
90+
this.availGpus = new TrackingSemaphore(maxGpus)
7791
assert availCpus>0, "Local avail `cpus` attribute cannot be zero"
7892
assert availMemory>0, "Local avail `memory` attribute cannot zero"
7993
}
@@ -98,14 +112,16 @@ class LocalPollingMonitor extends TaskPollingMonitor {
98112

99113
final int cpus = configCpus(session,name)
100114
final long memory = configMem(session,name)
115+
final int gpus = configGpus(session,name)
101116
final int size = session.getQueueSize(name, OS.getAvailableProcessors())
102117

103-
log.debug "Creating local task monitor for executor '$name' > cpus=$cpus; memory=${new MemoryUnit(memory)}; capacity=$size; pollInterval=$pollInterval; dumpInterval=$dumpInterval"
118+
log.debug "Creating local task monitor for executor '$name' > cpus=$cpus; memory=${new MemoryUnit(memory)}; gpus=$gpus; capacity=$size; pollInterval=$pollInterval; dumpInterval=$dumpInterval"
104119

105120
new LocalPollingMonitor(
106121
name: name,
107122
cpus: cpus,
108123
memory: memory,
124+
gpus: gpus,
109125
session: session,
110126
capacity: size,
111127
pollInterval: pollInterval,
@@ -128,6 +144,11 @@ class LocalPollingMonitor extends TaskPollingMonitor {
128144
(session.getExecConfigProp(name, 'memory', OS.getTotalPhysicalMemorySize()) as MemoryUnit).toBytes()
129145
}
130146

147+
@PackageScope
148+
static int configGpus(Session session, String name) {
149+
return session.getExecConfigProp(name, 'gpus', 0) as int
150+
}
151+
131152
/**
132153
* @param handler
133154
* A {@link TaskHandler} instance
@@ -149,6 +170,16 @@ class LocalPollingMonitor extends TaskPollingMonitor {
149170
handler.task.getConfig()?.getMemory()?.toBytes() ?: 1L
150171
}
151172

173+
/**
174+
* @param handler
175+
* A {@link TaskHandler} instance
176+
* @return
177+
* The number of gpus requested to execute the specified task
178+
*/
179+
private static int gpus(TaskHandler handler) {
180+
handler.task.getConfig()?.getAccelerator()?.getRequest() ?: 0
181+
}
182+
152183
/**
153184
* Determines if a task can be submitted for execution checking if the resources required
154185
* (cpus and memory) match the amount of avail resource
@@ -174,9 +205,14 @@ class LocalPollingMonitor extends TaskPollingMonitor {
174205
if( taskMemory>maxMemory)
175206
throw new ProcessUnrecoverableException("Process requirement exceeds available memory -- req: ${new MemoryUnit(taskMemory)}; avail: ${new MemoryUnit(maxMemory)}")
176207

177-
final result = super.canSubmit(handler) && taskCpus <= availCpus && taskMemory <= availMemory
208+
final taskGpus = gpus(handler)
209+
if( taskGpus>maxGpus )
210+
throw new ProcessUnrecoverableException("Process requirement exceeds available GPUs -- req: $taskGpus; avail: $maxGpus")
211+
212+
final availGpus0 = availGpus.availablePermits()
213+
final result = super.canSubmit(handler) && taskCpus <= availCpus && taskMemory <= availMemory && taskGpus <= availGpus0
178214
if( !result && log.isTraceEnabled( ) ) {
179-
log.trace "Task `${handler.task.name}` cannot be scheduled -- taskCpus: $taskCpus <= availCpus: $availCpus && taskMemory: ${new MemoryUnit(taskMemory)} <= availMemory: ${new MemoryUnit(availMemory)}"
215+
log.trace "Task `${handler.task.name}` cannot be scheduled -- taskCpus: $taskCpus <= availCpus: $availCpus && taskMemory: ${new MemoryUnit(taskMemory)} <= availMemory: ${new MemoryUnit(availMemory)} && taskGpus: $taskGpus <= availGpus: ${availGpus0}"
180216
}
181217
return result
182218
}
@@ -189,6 +225,11 @@ class LocalPollingMonitor extends TaskPollingMonitor {
189225
*/
190226
@Override
191227
protected void submit(TaskHandler handler) {
228+
// GPU slots must be assigned before the task is submitted
229+
final taskGpus = gpus(handler)
230+
if ( taskGpus > 0 )
231+
((LocalTaskHandler) handler).gpuSlots = availGpus.acquire(taskGpus)
232+
192233
super.submit(handler)
193234
availCpus -= cpus(handler)
194235
availMemory -= mem(handler)
@@ -204,11 +245,13 @@ class LocalPollingMonitor extends TaskPollingMonitor {
204245
* {@code true} when the task is successfully removed from polling queue,
205246
* {@code false} otherwise
206247
*/
248+
@Override
207249
protected boolean remove(TaskHandler handler) {
208250
final result = super.remove(handler)
209251
if( result ) {
210252
availCpus += cpus(handler)
211253
availMemory += mem(handler)
254+
availGpus.release(((LocalTaskHandler) handler).gpuSlots ?: Collections.<Integer>emptyList())
212255
}
213256
return result
214257
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright 2013-2024, Seqera Labs
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package nextflow.util
18+
19+
import java.util.concurrent.Semaphore
20+
21+
import groovy.transform.CompileStatic
22+
23+
/**
24+
* Specialized semaphore that keeps track of which slots
25+
* are being used.
26+
*
27+
* @author Ben Sherman <bentshermann@gmail.com>
28+
*/
29+
@CompileStatic
30+
class TrackingSemaphore {
31+
private final Semaphore semaphore
32+
private final Map<Integer,Boolean> availIds
33+
34+
TrackingSemaphore(int permits) {
35+
semaphore = new Semaphore(permits)
36+
availIds = new HashMap<>(permits)
37+
for( int i=0; i<permits; i++ )
38+
availIds.put(i, true)
39+
}
40+
41+
int availablePermits() {
42+
return semaphore.availablePermits()
43+
}
44+
45+
List<Integer> acquire(int permits) {
46+
semaphore.acquire(permits)
47+
final result = new ArrayList<Integer>(permits)
48+
for( final entry : availIds.entrySet() ) {
49+
if( entry.getValue() ) {
50+
entry.setValue(false)
51+
result.add(entry.getKey())
52+
}
53+
if( result.size() == permits )
54+
break
55+
}
56+
return result
57+
}
58+
59+
void release(List<Integer> ids) {
60+
semaphore.release(ids.size())
61+
for( id in ids )
62+
availIds.put(id, true)
63+
}
64+
65+
}

modules/nextflow/src/test/groovy/nextflow/processor/LocalPollingMonitorTest.groovy

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.lang.management.ManagementFactory
2121
import com.sun.management.OperatingSystemMXBean
2222
import nextflow.Session
2323
import nextflow.exception.ProcessUnrecoverableException
24+
import nextflow.executor.local.LocalTaskHandler
2425
import nextflow.util.MemoryUnit
2526
import spock.lang.Specification
2627
/**
@@ -38,14 +39,15 @@ class LocalPollingMonitorTest extends Specification {
3839
cpus: 10,
3940
capacity: 20,
4041
memory: _20_GB,
42+
gpus: 0,
4143
session: session,
4244
name: 'local',
4345
pollInterval: 100
4446
)
4547

4648
def task = new TaskRun()
4749
task.config = new TaskConfig(cpus: 3, memory: MemoryUnit.of('2GB'))
48-
def handler = Mock(TaskHandler)
50+
def handler = Mock(LocalTaskHandler)
4951
handler.getTask() >> { task }
5052

5153
expect:
@@ -86,14 +88,15 @@ class LocalPollingMonitorTest extends Specification {
8688
cpus: 10,
8789
capacity: 10,
8890
memory: _20_GB,
91+
gpus: 0,
8992
session: session,
9093
name: 'local',
9194
pollInterval: 100
9295
)
9396

9497
def task = new TaskRun()
9598
task.config = new TaskConfig(cpus: 4, memory: MemoryUnit.of('8GB'))
96-
def handler = Mock(TaskHandler)
99+
def handler = Mock(LocalTaskHandler)
97100
handler.getTask() >> { task }
98101
handler.canForkProcess() >> true
99102
handler.isReady() >> true
@@ -132,14 +135,15 @@ class LocalPollingMonitorTest extends Specification {
132135
cpus: 1,
133136
capacity: 1,
134137
memory: _20_GB,
138+
gpus: 0,
135139
session: session,
136140
name: 'local',
137141
pollInterval: 100
138142
)
139143

140144
def task = new TaskRun()
141145
task.config = new TaskConfig(cpus: 1, memory: MemoryUnit.of('8GB'))
142-
def handler = Mock(TaskHandler)
146+
def handler = Mock(LocalTaskHandler)
143147
handler.getTask() >> { task }
144148
handler.canForkProcess() >> true
145149
handler.isReady() >> true
@@ -167,6 +171,7 @@ class LocalPollingMonitorTest extends Specification {
167171
cpus: 10,
168172
capacity: 20,
169173
memory: _20_GB,
174+
gpus: 0,
170175
session: session,
171176
name: 'local',
172177
pollInterval: 100
@@ -195,6 +200,7 @@ class LocalPollingMonitorTest extends Specification {
195200
cpus: 10,
196201
capacity: 20,
197202
memory: _20_GB,
203+
gpus: 0,
198204
session: session,
199205
name: 'local',
200206
pollInterval: 100

0 commit comments

Comments
 (0)