-
Notifications
You must be signed in to change notification settings - Fork 773
Expand file tree
/
Copy pathSeqeraTaskHandler.groovy
More file actions
321 lines (288 loc) · 11.8 KB
/
SeqeraTaskHandler.groovy
File metadata and controls
321 lines (288 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
/*
* Copyright 2013-2025, Seqera Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package io.seqera.executor
import java.nio.file.Path
import groovy.transform.CompileStatic
import groovy.transform.PackageScope
import groovy.util.logging.Slf4j
import io.seqera.sched.api.schema.v1a1.AcceleratorType
import io.seqera.sched.api.schema.v1a1.GetTaskLogsResponse
import io.seqera.sched.api.schema.v1a1.ResourceRequirement
import io.seqera.sched.api.schema.v1a1.Task
import io.seqera.sched.api.schema.v1a1.TaskState as SchedTaskState
import io.seqera.sched.api.schema.v1a1.TaskStatus as SchedTaskStatus
import io.seqera.sched.client.SchedClient
import io.seqera.util.MapperUtil
import nextflow.cloud.types.CloudMachineInfo
import nextflow.exception.ProcessException
import nextflow.exception.ProcessUnrecoverableException
import nextflow.fusion.FusionAwareTask
import nextflow.processor.TaskHandler
import nextflow.processor.TaskRun
import nextflow.processor.TaskStatus
import nextflow.trace.TraceRecord
/**
 * Task handler for the Seqera executor.
 *
 * Tasks are enqueued on the executor's batch submitter, which assigns the scheduler
 * task ID asynchronously, and their progress is then tracked by polling the scheduler
 * with {@code describeTask}. Task scripts are launched through Fusion.
 *
 * @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
 */
@Slf4j
@CompileStatic
class SeqeraTaskHandler extends TaskHandler implements FusionAwareTask {

    private SchedClient client

    private SeqeraExecutor executor

    private Path exitFile

    private Path outputFile

    private Path errorFile

    /**
     * Scheduler task ID; assigned by the batch submitter through {@link #setBatchTaskId(String)}
     */
    private volatile String taskId

    /**
     * Cached task state from last describeTask call, used for trace record metadata
     */
    private volatile SchedTaskState cachedTaskState

    /**
     * Machine info extracted from the last task attempt; only final once the task has completed
     */
    private volatile CloudMachineInfo machineInfo

    SeqeraTaskHandler(TaskRun task, SeqeraExecutor executor) {
        super(task)
        this.client = executor.getClient()
        this.executor = executor
        // these files are accessed via the NF runtime, keep them based on the (cloud storage) work dir path
        this.outputFile = task.workDir.resolve(TaskRun.CMD_OUTFILE)
        this.errorFile = task.workDir.resolve(TaskRun.CMD_ERRFILE)
        this.exitFile = task.workDir.resolve(TaskRun.CMD_EXIT)
    }

    /**
     * Build the Fusion launcher for the task. Fusion must be enabled.
     */
    @Override
    void prepareLauncher() {
        assert fusionEnabled()
        final launcher = fusionLauncher()
        launcher.build()
    }

    /**
     * Build the scheduler task request and enqueue it for batch submission.
     * The handler status is updated later by the batch submitter through
     * {@link #setBatchTaskId(String)} or {@link #onBatchSubmitFailure(Exception)}.
     *
     * @throws ProcessUnrecoverableException when the process does not define a container image
     */
    @Override
    void submit() {
        // 1024 cpu shares per requested cpu, defaulting to a single cpu
        final int cpuShares = (task.config.getCpus() ?: 1) * 1024
        // requested memory in MiB, defaulting to 1024 MiB when not specified
        final memory = task.config.getMemory()
        final int memoryMiB = memory ? (int) (memory.toBytes() / (1024 * 1024)) : 1024
        final resourceReq = new ResourceRequirement()
            .cpuShares(cpuShares)
            .memoryMiB(memoryMiB)

        // add accelerator settings if defined
        final accelerator = task.config.getAccelerator()
        if( accelerator ) {
            // number of accelerators requested, fallback to limit if request is not specified
            resourceReq.acceleratorCount(accelerator.request ?: accelerator.limit)
            // accelerator type is GPU by default (most common in scientific computing)
            resourceReq.acceleratorType(AcceleratorType.GPU)
            // specific accelerator model name e.g. "nvidia-tesla-v100", "nvidia-a10g"
            if( accelerator.type )
                resourceReq.acceleratorName(accelerator.type)
        }

        // build machine requirement merging config settings with task arch, disk, and snapshot settings
        final machineReq = MapperUtil.toMachineRequirement(
            executor.getSeqeraConfig().machineRequirement,
            task.getContainerPlatform(),
            task.config.getDisk(),
            fusionConfig().snapshotsEnabled()
        )

        // validate container - Seqera executor requires all processes to specify a container image
        final container = task.getContainer()
        if( !container )
            throw new ProcessUnrecoverableException("Process `${task.lazyName()}` failed because the container image was not specified -- the Seqera executor requires all processes define a container image")

        // build the scheduler task with all required attributes
        final schedTask = new Task()
            .name(task.lazyName())                          // process name for identification
            .image(container)                               // container image to run
            .command(fusionSubmitCli())                     // fusion-based command launcher
            .environment(fusionLauncher().fusionEnv())      // fusion environment variables
            .resourceRequirement(resourceReq)               // cpu, memory, accelerators
            .workDir(task.getWorkDirStr())                  // task working directory
            .machineRequirement(machineReq)                 // machine type and disk requirements

        log.debug "[SEQERA] Enqueueing task for batch submission: ${schedTask}"
        // Enqueue for batch submission - status will be set by setBatchTaskId callback
        executor.getBatchSubmitter().submit(this, schedTask)
    }

    /**
     * Called by batch submitter after successful batch submission.
     *
     * @param taskId the scheduler task ID assigned to this task
     */
    void setBatchTaskId(String taskId) {
        this.taskId = taskId
        this.status = TaskStatus.SUBMITTED
        log.debug "[SEQERA] Process `${task.lazyName()}` submitted > taskId=$taskId; work-dir=${task.getWorkDirStr()}"
    }

    /**
     * Called by batch submitter when batch submission fails. The error is recorded on the
     * task and the handler is marked COMPLETED, so that {@link #checkIfCompleted()} reports it.
     *
     * @param cause the submission failure
     */
    void onBatchSubmitFailure(Exception cause) {
        log.debug "[SEQERA] Batch submission failed for task ${task.lazyName()}: ${cause.message}"
        task.error = cause
        this.status = TaskStatus.COMPLETED
    }

    /**
     * Query the scheduler for the current task status, refreshing {@code cachedTaskState}
     * as a side effect.
     *
     * @return the current scheduler task status
     */
    protected SchedTaskStatus schedTaskStatus() {
        cachedTaskState = client.describeTask(taskId).getTaskState()
        return cachedTaskState.getStatus()
    }

    /**
     * Move a submitted task to RUNNING once the scheduler reports it as running or already terminated.
     *
     * @return {@code true} when the task transitioned to RUNNING
     */
    @Override
    boolean checkIfRunning() {
        if (isSubmitted()) {
            final schedStatus = schedTaskStatus()
            log.debug "[SEQERA] checkIfRunning taskId=${taskId}; status=${schedStatus}"
            if (isRunningOrTerminated(schedStatus)) {
                status = TaskStatus.RUNNING
                return true
            }
        }
        return false
    }

    /**
     * Check whether the task has terminated and, if so, finalize it: read the exit status,
     * record an error for failures without an exit code, and set the stdout/stderr sources.
     *
     * @return {@code true} when the task is completed
     */
    @Override
    boolean checkIfCompleted() {
        // Handle batch submission failure - task error was set but never reached RUNNING state
        if (task.error && isCompleted())
            return true
        if (!isRunning())
            return false

        final schedStatus = schedTaskStatus()
        log.debug "[SEQERA] checkIfCompleted status=${schedStatus}"
        if (!isTerminated(schedStatus))
            return false

        log.debug "[SEQERA] Process `${task.lazyName()}` - terminated taskId=$taskId; status=$schedStatus"
        // finalize the task
        task.exitStatus = readExitFile()
        if (isFailed(schedStatus)) {
            // When no exit code available, get the error message from task state
            if (task.exitStatus == Integer.MAX_VALUE) {
                final errorMessage = cachedTaskState?.getErrorMessage() ?: "Task failed for unknown reason"
                task.error = new ProcessException(errorMessage)
            }
            // logs are best-effort: fall back to the work-dir files when they cannot be fetched
            final logs = fetchTaskLogs()
            task.stdout = logs?.stdout ?: outputFile
            task.stderr = logs?.stderr ?: errorFile
        }
        else {
            task.stdout = outputFile
            task.stderr = errorFile
        }
        status = TaskStatus.COMPLETED
        return true
    }

    protected boolean isRunningOrTerminated(SchedTaskStatus status) {
        return status == SchedTaskStatus.RUNNING || isTerminated(status)
    }

    protected boolean isTerminated(SchedTaskStatus status) {
        return status in [SchedTaskStatus.SUCCEEDED, SchedTaskStatus.FAILED, SchedTaskStatus.CANCELLED]
    }

    protected boolean isFailed(SchedTaskStatus status) {
        return status == SchedTaskStatus.FAILED
    }

    protected GetTaskLogsResponse getTaskLogs(String taskId) {
        return client.getTaskLogs(taskId)
    }

    /**
     * Fetch the task logs without letting a retrieval error replace the task outcome
     * (an exception here would otherwise escape before the task is marked COMPLETED).
     *
     * @return the logs response, or {@code null} when the logs cannot be retrieved
     */
    private GetTaskLogsResponse fetchTaskLogs() {
        try {
            return getTaskLogs(taskId)
        }
        catch (Exception e) {
            log.debug "[SEQERA] Unable to fetch logs for taskId=$taskId - ${e.message}"
            return null
        }
    }

    /**
     * Request cancellation of the scheduler task. A no-op when no task ID was assigned yet;
     * cancellation errors are logged and not propagated.
     */
    @Override
    protected void killTask() {
        if( !taskId ) {
            log.trace "[SEQERA] Skip cancel - taskId not yet assigned"
            return
        }
        log.debug "[SEQERA] Cancel taskId=${taskId}"
        try {
            client.cancelTask(taskId)
        }
        catch (Throwable t) {
            log.warn "[SEQERA] Failed to cancel task ${taskId}", t
        }
    }

    /**
     * Read the task exit status from the exit file in the task work directory.
     *
     * @return the exit status, or {@link Integer#MAX_VALUE} when it cannot be read
     */
    @PackageScope
    Integer readExitFile() {
        try {
            final result = exitFile.text as Integer
            log.trace "[SEQERA] Read exit file for taskId $taskId; exit=${result}"
            return result
        }
        catch (Exception e) {
            log.debug "[SEQERA] Cannot read exit status for task: `${task.lazyName()}` - ${e.message}"
            // return MAX_VALUE to signal it was unable to retrieve the exit code
            return Integer.MAX_VALUE
        }
    }

    /**
     * Get machine info for the task execution from the last task attempt.
     *
     * While the task is still running, a new attempt may start on a different machine
     * (e.g. after a spot interruption). The value is therefore re-derived from the latest
     * cached task state until the task has completed; after that, the cached value is returned.
     *
     * @return CloudMachineInfo containing instance type, zone, and price model, or null if not available
     */
    protected CloudMachineInfo getMachineInfo() {
        // once the task has completed its state no longer changes, so the cached value is final
        if (machineInfo && isCompleted())
            return machineInfo
        // read the volatile field once to work on a consistent snapshot
        final state = cachedTaskState
        if (!state)
            return machineInfo

        try {
            final attempts = state.getAttempts()
            if (!attempts)
                return machineInfo
            final lastAttempt = attempts.get(attempts.size() - 1)
            final lastInfo = lastAttempt.getMachineInfo()
            if (!lastInfo)
                return machineInfo

            // Convert Sched API MachineInfo to Nextflow CloudMachineInfo
            machineInfo = new CloudMachineInfo(
                type: lastInfo.getType(),
                zone: lastInfo.getZone(),
                priceModel: MapperUtil.toPriceModel(lastInfo.getPriceModel())
            )
            log.trace "[SEQERA] taskId=$taskId => machineInfo=$machineInfo"
            return machineInfo
        }
        catch (Exception e) {
            log.debug "[SEQERA] Unable to get machine info for taskId=$taskId - ${e.message}"
            return machineInfo
        }
    }

    /**
     * Get the number of spot interruptions for this task.
     * This is calculated server-side from task attempts with spot-related stop reasons.
     *
     * @return the count of spot interruptions, or null if not completed or not available
     */
    protected Integer getNumSpotInterruptions() {
        if (!taskId || !isCompleted())
            return null
        return cachedTaskState?.getNumSpotInterruptions()
    }

    /**
     * Get the native backend ID for this task (ECS task ARN or Docker container ID).
     *
     * @return the ID reported by the cached scheduler task state, or null if not available
     */
    protected String getNativeId() {
        return cachedTaskState?.getId()
    }

    /**
     * Get the trace record for this task, including machine info and spot interruptions metadata.
     *
     * @return the trace record with additional metadata fields
     */
    @Override
    TraceRecord getTraceRecord() {
        final result = super.getTraceRecord()
        result.put('native_id', getNativeId())
        result.machineInfo = getMachineInfo()
        result.numSpotInterruptions = getNumSpotInterruptions()
        // Override executor name to include cloud backend for cost tracking
        // NOTE(review): the backend is hard-coded to "aws" - confirm no other backends are supported
        result.executorName = "${SeqeraExecutor.SEQERA}/aws"
        return result
    }
}