Skip to content

Commit 76790d2

Browse files
Allow Azure Batch tasks to be submitted to different pools (#5766) [ci fast]
Signed-off-by: adamrtalbot <[email protected]> Signed-off-by: Paolo Di Tommaso <[email protected]> Co-authored-by: Paolo Di Tommaso <[email protected]> Signed-off-by: Paolo Di Tommaso <[email protected]>
1 parent b163da9 commit 76790d2

File tree

4 files changed

+145
-5
lines changed

4 files changed

+145
-5
lines changed

plugins/nf-azure/src/main/nextflow/cloud/azure/batch/AzBatchService.groovy

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ import nextflow.cloud.types.CloudMachineInfo
8888
import nextflow.cloud.types.PriceModel
8989
import nextflow.fusion.FusionHelper
9090
import nextflow.fusion.FusionScriptLauncher
91-
import nextflow.processor.TaskProcessor
9291
import nextflow.processor.TaskRun
9392
import nextflow.util.CacheHelper
9493
import nextflow.util.MemoryUnit
@@ -111,7 +110,7 @@ class AzBatchService implements Closeable {
111110

112111
AzConfig config
113112

114-
Map<TaskProcessor,String> allJobIds = new HashMap<>(50)
113+
Map<AzJobKey,String> allJobIds = new HashMap<>(50)
115114

116115
AzBatchService(AzBatchExecutor executor) {
117116
assert executor
@@ -355,16 +354,26 @@ class AzBatchService implements Closeable {
355354
}
356355

357356
synchronized String getOrCreateJob(String poolId, TaskRun task) {
358-
final mapKey = task.processor
357+
// Use the same job Id for the same Process,PoolId pair
358+
// The Pool is added to allow using different queue names (corresponding
359+
// a pool id) for the same process. See also
360+
// https://github.com/nextflow-io/nextflow/pull/5766
361+
final mapKey = new AzJobKey(task.processor, poolId)
359362
if( allJobIds.containsKey(mapKey)) {
360363
return allJobIds[mapKey]
361364
}
365+
final jobId = createJob0(poolId,task)
366+
// add to the map
367+
allJobIds[mapKey] = jobId
368+
return jobId
369+
}
370+
371+
protected String createJob0(String poolId, TaskRun task) {
372+
log.debug "[AZURE BATCH] created job for ${task.processor.name} with pool ${poolId}"
362373
// create a batch job
363374
final jobId = makeJobId(task)
364375
final content = new BatchJobCreateContent(jobId, new BatchPoolInfo(poolId: poolId))
365376
apply(() -> client.createJob(content))
366-
// add to the map
367-
allJobIds[mapKey] = jobId
368377
return jobId
369378
}
370379

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright 2013-2024, Seqera Labs
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
*/
17+
18+
package nextflow.cloud.azure.batch
19+
20+
import groovy.transform.Canonical
21+
import groovy.transform.CompileStatic
22+
import nextflow.processor.TaskProcessor
23+
/**
24+
* Model a Batch job key for caching purposes
25+
*
26+
* @author Paolo Di Tommaso <[email protected]>
27+
*/
28+
@Canonical
29+
@CompileStatic
30+
class AzJobKey {
31+
final TaskProcessor processor
32+
final String poolId
33+
}

plugins/nf-azure/src/test/nextflow/cloud/azure/batch/AzBatchServiceTest.groovy

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,4 +739,54 @@ class AzBatchServiceTest extends Specification {
739739
[managedIdentity: [clientId: 'client-123']] | 'client-123'
740740
}
741741

742+
def 'should cache job id' () {
743+
given:
744+
def exec = Mock(AzBatchExecutor)
745+
def service = Spy(new AzBatchService(exec))
746+
and:
747+
def p1 = Mock(TaskProcessor)
748+
def p2 = Mock(TaskProcessor)
749+
def t1 = Mock(TaskRun) { getProcessor()>>p1 }
750+
def t2 = Mock(TaskRun) { getProcessor()>>p2 }
751+
def t3 = Mock(TaskRun) { getProcessor()>>p2 }
752+
753+
when:
754+
def result = service.getOrCreateJob('foo',t1)
755+
then:
756+
1 * service.createJob0('foo',t1) >> 'job1'
757+
and:
758+
result == 'job1'
759+
760+
// second time is cached
761+
when:
762+
result = service.getOrCreateJob('foo',t1)
763+
then:
764+
0 * service.createJob0('foo',t1) >> null
765+
and:
766+
result == 'job1'
767+
768+
// changing pool id returns a new job id
769+
when:
770+
result = service.getOrCreateJob('bar',t1)
771+
then:
772+
1 * service.createJob0('bar',t1) >> 'job2'
773+
and:
774+
result == 'job2'
775+
776+
// changing process returns a new job id
777+
when:
778+
result = service.getOrCreateJob('bar',t2)
779+
then:
780+
1 * service.createJob0('bar',t2) >> 'job3'
781+
and:
782+
result == 'job3'
783+
784+
// change task with the same process, return cached job id
785+
when:
786+
result = service.getOrCreateJob('bar',t3)
787+
then:
788+
0 * service.createJob0('bar',t3) >> null
789+
and:
790+
result == 'job3'
791+
}
742792
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright 2013-2024, Seqera Labs
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
*/
17+
18+
package nextflow.cloud.azure.batch
19+
20+
import nextflow.processor.TaskProcessor
21+
import spock.lang.Specification
22+
23+
/**
24+
*
25+
* @author Paolo Di Tommaso <[email protected]>
26+
*/
27+
class AzJobKeyTest extends Specification {
28+
29+
def 'should validate equals and hashcode' () {
30+
given:
31+
def p1 = Mock(TaskProcessor)
32+
def p2 = Mock(TaskProcessor)
33+
def k1 = new AzJobKey(p1, 'foo')
34+
def k2 = new AzJobKey(p1, 'foo')
35+
def k3 = new AzJobKey(p2, 'foo')
36+
def k4 = new AzJobKey(p1, 'bar')
37+
38+
expect:
39+
k1 == k2
40+
k1 != k3
41+
k1 != k4
42+
and:
43+
k1.hashCode() == k2.hashCode()
44+
k1.hashCode() != k3.hashCode()
45+
k1.hashCode() != k4.hashCode()
46+
}
47+
48+
}

0 commit comments

Comments
 (0)