diff --git a/plugins/nf-azure/src/main/nextflow/cloud/azure/batch/AzBatchService.groovy b/plugins/nf-azure/src/main/nextflow/cloud/azure/batch/AzBatchService.groovy index 5d9f9f1596..9066ab1587 100644 --- a/plugins/nf-azure/src/main/nextflow/cloud/azure/batch/AzBatchService.groovy +++ b/plugins/nf-azure/src/main/nextflow/cloud/azure/batch/AzBatchService.groovy @@ -35,8 +35,10 @@ import com.azure.compute.batch.models.BatchJobCreateContent import com.azure.compute.batch.models.BatchJobConstraints import com.azure.compute.batch.models.BatchJobUpdateContent import com.azure.compute.batch.models.BatchNodeFillType +import com.azure.compute.batch.models.BatchNodeIdentityReference import com.azure.compute.batch.models.BatchPool import com.azure.compute.batch.models.BatchPoolCreateContent +import com.azure.compute.batch.models.BatchPoolIdentity import com.azure.compute.batch.models.BatchPoolInfo import com.azure.compute.batch.models.BatchPoolState import com.azure.compute.batch.models.BatchStartTask @@ -60,6 +62,7 @@ import com.azure.compute.batch.models.OutputFileDestination import com.azure.compute.batch.models.OutputFileUploadCondition import com.azure.compute.batch.models.OutputFileUploadConfig import com.azure.compute.batch.models.ResourceFile +import com.azure.compute.batch.models.UserAssignedIdentity import com.azure.compute.batch.models.UserIdentity import com.azure.compute.batch.models.VirtualMachineConfiguration import com.azure.core.credential.AzureNamedKeyCredential @@ -113,6 +116,8 @@ class AzBatchService implements Closeable { static private final long _1GB = 1 << 30 static final private Map allPools = new HashMap<>(50) + + static final private Map poolManagedIdentityIds = new HashMap<>(50) AzConfig config @@ -474,6 +479,7 @@ class AzBatchService implements Closeable { assert jobId, 'Missing Azure Batch jobId argument' assert task, 'Missing Azure Batch task argument' + // SAS token is always required for output file uploads via azcopy final sas = config.storage().sasToken if( !sas ) throw new IllegalArgumentException("Missing Azure Blob storage SAS token") @@ -545,11 +551,20 @@ class AzBatchService implements Closeable { final constraints = taskConstraints(task) log.trace "[AZURE BATCH] Submitting task: $taskId, cpus=${task.config.getCpus()}, mem=${task.config.getMemory()?:'-'}, slots: $slots" + + // Check if we should use managed identity and identify the resource ID (ARM) + final poolIdentityClientId = config.batch().poolIdentityClientId + final poolManagedIdentityResourceId = poolIdentityClientId ? getPoolManagedIdentityResourceId(poolId, poolIdentityClientId) : null + if( poolIdentityClientId && !poolManagedIdentityResourceId ) { + // Throw a warning if we are trying to use managed identity and can't locate it on the pool + log.warn "[AZURE BATCH] No managed identity found for pool '$poolId' with client ID '${poolIdentityClientId}'. Falling back to SAS token authentication." + } + return new BatchTaskCreateContent(taskId, cmd) .setUserIdentity(userIdentity(pool.opts.privileged, pool.opts.runAs, AutoUserScope.TASK)) .setContainerSettings(containerOpts) - .setResourceFiles(resourceFileUrls(task, sas)) - .setOutputFiles(outputFileUrls(task, sas)) + .setResourceFiles(resourceFileUrls(task, poolManagedIdentityResourceId, sas)) + .setOutputFiles(outputFileUrls(task, poolManagedIdentityResourceId, sas)) .setRequiredSlots(slots) .setConstraints(constraints) } @@ -591,7 +606,7 @@ class AzBatchService implements Closeable { return result } - protected List resourceFileUrls(TaskRun task, String sas) { + protected List resourceFileUrls(TaskRun task, String poolManagedIdentityResourceId, String sas) { final cmdRun = (AzPath) task.workDir.resolve(TaskRun.CMD_RUN) final cmdScript = (AzPath) task.workDir.resolve(TaskRun.CMD_SCRIPT) @@ -604,44 +619,94 @@ class AzBatchService implements Closeable { .setFilePath('.nextflow-bin/azcopy') } - resFiles << new ResourceFile() - .setHttpUrl(AzHelper.toHttpUrl(cmdRun, sas)) - .setFilePath(TaskRun.CMD_RUN) + // Create resource files with or without managed identity + if( poolManagedIdentityResourceId ) { + // When using managed identity, create BatchNodeIdentityReference + // For pool-level managed identity, we create an empty reference which will use the pool's identity + // The poolIdentityClientId configuration ensures the pool has been configured with a managed identity + // Azure Batch will automatically use that identity when downloading these resource files + // Create identity reference with the resource ID from the pool + final identityRef = new BatchNodeIdentityReference() + .setResourceId(poolManagedIdentityResourceId) + log.debug "[AZURE BATCH] Using managed identity with resource ID: ${poolManagedIdentityResourceId}" + + resFiles << new ResourceFile() + .setHttpUrl(AzHelper.toHttpUrl(cmdRun, null)) + .setFilePath(TaskRun.CMD_RUN) + .setIdentityReference(identityRef) - resFiles << new ResourceFile() - .setHttpUrl(AzHelper.toHttpUrl(cmdScript, sas)) - .setFilePath(TaskRun.CMD_SCRIPT) + resFiles << new ResourceFile() + .setHttpUrl(AzHelper.toHttpUrl(cmdScript, null)) + .setFilePath(TaskRun.CMD_SCRIPT) + .setIdentityReference(identityRef) + + if( task.stdin ) { + resFiles << new ResourceFile() + .setHttpUrl(AzHelper.toHttpUrl(cmdScript, null)) + .setFilePath(TaskRun.CMD_INFILE) + .setIdentityReference(identityRef) + } + } + else { + // Use traditional SAS token approach + resFiles << new ResourceFile() + .setHttpUrl(AzHelper.toHttpUrl(cmdRun, sas)) + .setFilePath(TaskRun.CMD_RUN) - if( task.stdin ) { resFiles << new ResourceFile() .setHttpUrl(AzHelper.toHttpUrl(cmdScript, sas)) - .setFilePath(TaskRun.CMD_INFILE) + .setFilePath(TaskRun.CMD_SCRIPT) + + if( task.stdin ) { + resFiles << new ResourceFile() + .setHttpUrl(AzHelper.toHttpUrl(cmdScript, sas)) + .setFilePath(TaskRun.CMD_INFILE) + } } return resFiles } - protected List outputFileUrls(TaskRun task, String sas) { + protected List outputFileUrls(TaskRun task, String poolManagedIdentityResourceId, String sas) { List result = new ArrayList<>(20) - result << destFile(TaskRun.CMD_EXIT, task.workDir, sas) - result << destFile(TaskRun.CMD_LOG, task.workDir, sas) - result << destFile(TaskRun.CMD_OUTFILE, task.workDir, sas) - result << destFile(TaskRun.CMD_ERRFILE, task.workDir, sas) - result << destFile(TaskRun.CMD_SCRIPT, task.workDir, sas) - result << destFile(TaskRun.CMD_RUN, task.workDir, sas) - result << destFile(TaskRun.CMD_STAGE, task.workDir, sas) - result << destFile(TaskRun.CMD_TRACE, task.workDir, sas) - result << destFile(TaskRun.CMD_ENV, task.workDir, sas) + result << destFile(TaskRun.CMD_EXIT, task.workDir, poolManagedIdentityResourceId, sas) + result << destFile(TaskRun.CMD_LOG, task.workDir, poolManagedIdentityResourceId, sas) + result << destFile(TaskRun.CMD_OUTFILE, task.workDir, poolManagedIdentityResourceId, sas) + result << destFile(TaskRun.CMD_ERRFILE, task.workDir, poolManagedIdentityResourceId, sas) + result << destFile(TaskRun.CMD_SCRIPT, task.workDir, poolManagedIdentityResourceId, sas) + result << destFile(TaskRun.CMD_RUN, task.workDir, poolManagedIdentityResourceId, sas) + result << destFile(TaskRun.CMD_STAGE, task.workDir, poolManagedIdentityResourceId, sas) + result << destFile(TaskRun.CMD_TRACE, task.workDir, poolManagedIdentityResourceId, sas) + result << destFile(TaskRun.CMD_ENV, task.workDir, poolManagedIdentityResourceId, sas) return result } - protected OutputFile destFile(String localPath, Path targetDir, String sas) { + protected OutputFile destFile(String localPath, Path targetDir, String poolManagedIdentityResourceId, String sas) { log.debug "Task output path: $localPath -> ${targetDir.toUriString()}" - def target = targetDir.resolve(localPath) - final dest = new OutputFileBlobContainerDestination(AzHelper.toContainerUrl(targetDir,sas)) - .setPath(target.subpath(1,target.nameCount).toString()) - - return new OutputFile(localPath, new OutputFileDestination().setContainer(dest), new OutputFileUploadConfig(OutputFileUploadCondition.TASK_COMPLETION)) + + // Calculate the target blob path + def targetPath = targetDir.resolve(localPath) + def blobPath = targetPath.subpath(1, targetPath.nameCount).toString() + + // Create the destination with appropriate authentication + def containerUrl = AzHelper.toContainerUrl(targetDir, poolManagedIdentityResourceId ? null : sas) + def destination = new OutputFileBlobContainerDestination(containerUrl) + .setPath(blobPath) + + // Add identity reference if using managed identity + if( poolManagedIdentityResourceId ) { + log.debug "[AZURE BATCH] Setting identity reference for $localPath with resource ID: $poolManagedIdentityResourceId" + destination.setIdentityReference( + new BatchNodeIdentityReference().setResourceId(poolManagedIdentityResourceId) + ) + } + + // Create and return the output file configuration + return new OutputFile( + localPath, + new OutputFileDestination().setContainer(destination), + new OutputFileUploadConfig(OutputFileUploadCondition.TASK_COMPLETION) + ) } protected BatchSupportedImage getImage(AzPoolOpts opts) { @@ -744,6 +809,63 @@ class AzBatchService implements Closeable { } + /** + * Get the managed identity resource ID from the pool if available + * @param poolId The pool ID to check + * @param poolIdentityClientId Can be 'auto', a specific client ID, null or false + * @return The resource ID of the managed identity, or null if not found/configured + */ + protected String getPoolManagedIdentityResourceId(String poolId, poolIdentityClientId) { + // If poolIdentityClientId is null or false, return null + if( !poolIdentityClientId ) { + return null + } + + + // TODO: Should we throw an error if we can't find the identity attached to the pool? + + try { + def pool = getPool(poolId) + if( !pool ) { + return null + } + + def poolIdentity = pool.getIdentity() + if( !poolIdentity ) { + return null + } + List identities = poolIdentity?.getUserAssignedIdentities() + + if( !identities || identities.isEmpty() ) { + return null + } + + // Handle 'auto' - use the first available identity + if( poolIdentityClientId == 'auto' || poolIdentityClientId == true ) { + def firstIdentity = identities.first() + log.debug "[AZURE BATCH] Using managed identity for pool '$poolId'" + return firstIdentity.getResourceId() + } + + // Handle specific client ID + if( poolIdentityClientId instanceof String ) { + def matchingIdentity = identities.find { it.getClientId() == poolIdentityClientId } + if( matchingIdentity ) { + log.debug "[AZURE BATCH] Found managed identity for pool '$poolId'" + return matchingIdentity.getResourceId() + } + return null + } + + // Unsupported type + return null + } + catch( Exception e ) { + log.debug "[AZURE BATCH] Error getting managed identity for pool '$poolId': ${e.message}" + return null + } + } + synchronized String getOrCreatePool(TaskRun task) { final spec = specForTask(task) @@ -1108,3 +1230,4 @@ class AzBatchService implements Closeable { return Failsafe.with(policy).get(action) } } + diff --git a/plugins/nf-azure/src/test/nextflow/cloud/azure/batch/AzBatchServiceTest.groovy b/plugins/nf-azure/src/test/nextflow/cloud/azure/batch/AzBatchServiceTest.groovy index 353383dcd4..632000e25d 100644 --- a/plugins/nf-azure/src/test/nextflow/cloud/azure/batch/AzBatchServiceTest.groovy +++ b/plugins/nf-azure/src/test/nextflow/cloud/azure/batch/AzBatchServiceTest.groovy @@ -2,16 +2,27 @@ package nextflow.cloud.azure.batch import java.nio.ByteBuffer import java.nio.charset.StandardCharsets +import java.nio.file.Path +import java.nio.file.Paths +import java.nio.file.attribute.BasicFileAttributes +import java.nio.file.spi.FileSystemProvider import java.time.Instant import java.time.temporal.ChronoUnit import java.util.function.Predicate +import com.azure.compute.batch.models.BatchNodeIdentityReference import com.azure.compute.batch.models.BatchPool +import com.azure.compute.batch.models.BatchPoolIdentity import com.azure.compute.batch.models.ElevationLevel +import com.azure.compute.batch.models.UserAssignedIdentity import com.azure.compute.batch.models.EnvironmentSetting +import com.azure.compute.batch.models.OutputFile +import com.azure.compute.batch.models.OutputFileUploadCondition import com.azure.core.exception.HttpResponseException import com.azure.core.http.HttpResponse import com.azure.identity.ManagedIdentityCredential +import com.azure.storage.blob.BlobClient +import com.azure.storage.blob.BlobContainerClient import com.google.common.hash.HashCode import nextflow.Global import nextflow.Session @@ -21,6 +32,9 @@ import nextflow.cloud.azure.config.AzConfig import nextflow.cloud.azure.config.AzManagedIdentityOpts import nextflow.cloud.azure.config.AzPoolOpts import nextflow.cloud.azure.config.AzStartTaskOpts +import nextflow.cloud.azure.nio.AzFileSystem +import nextflow.cloud.azure.nio.AzFileSystemProvider +import nextflow.cloud.azure.nio.AzPath import nextflow.file.FileSystemPathFactory import nextflow.processor.TaskBean import nextflow.processor.TaskConfig @@ -31,6 +45,7 @@ import nextflow.util.MemoryUnit import reactor.core.publisher.Flux import spock.lang.Specification import spock.lang.Unroll +import spock.lang.Ignore /** * * @author Paolo Di Tommaso @@ -39,6 +54,59 @@ class AzBatchServiceTest extends Specification { static long _1GB = 1024 * 1024 * 1024 + protected Path mockAzPath(String path, boolean isDir=false) { + assert path.startsWith('az://') + + def tokens = path.tokenize('/') + def bucket = tokens[1] + def file = '/' + tokens[2..-1].join('/') + + def attr = Mock(BasicFileAttributes) + attr.isDirectory() >> isDir + attr.isRegularFile() >> !isDir + attr.isSymbolicLink() >> false + + def provider = Mock(AzFileSystemProvider) + provider.getScheme() >> 'az' + provider.readAttributes(_, _, _) >> attr + + def fs = Mock(AzFileSystem) + fs.provider() >> provider + fs.toString() >> ('az://' + bucket) + def uri = GroovyMock(URI) + uri.toString() >> path + + def BLOB_CLIENT = GroovyMock(BlobClient) { + getBlobUrl() >> { path.replace('az://', 'http://account.blob.core.windows.net/') } + } + + def CONTAINER_CLIENT = GroovyMock(BlobContainerClient) { + getBlobContainerUrl() >> { "http://account.blob.core.windows.net/${bucket}" } + } + + def result = GroovyMock(AzPath) + result.toUriString() >> path + result.toString() >> file + result.getFileSystem() >> fs + result.toUri() >> uri + result.resolve(_) >> { mockAzPath("$path/${it[0]}") } + result.toAbsolutePath() >> result + result.asBoolean() >> true + result.getParent() >> { def p=path.lastIndexOf('/'); p!=-1 ? mockAzPath("${path.substring(0,p)}", true) : null } + result.getFileName() >> { Paths.get(tokens[-1]) } + result.getName() >> tokens[1] + result.getContainerName() >> bucket + result.blobName() >> file.substring(1) // Remove leading slash + result.blobClient() >> BLOB_CLIENT + result.containerClient() >> CONTAINER_CLIENT + result.getNameCount() >> tokens.size() - 2 // Exclude 'az://' and bucket + result.subpath(_, _) >> { int start, int end -> + def pathTokens = file.substring(1).split('/') + Paths.get(pathTokens[start..> HashCode.fromInt(1) getContainer() >> 'ubuntu:latest' getConfig() >> Mock(TaskConfig) + getWorkDir() >> mockAzPath('az://container/work/dir') } and: def SPEC = new AzVmPoolSpec(poolId: POOL_ID, vmType: Mock(AzVmType), opts: new AzPoolOpts([:])) @@ -668,8 +737,8 @@ class AzBatchServiceTest extends Specification { then: 1 * azure.getPoolSpec(POOL_ID) >> SPEC 1 * azure.computeSlots(TASK, SPEC) >> 4 - 1 * azure.resourceFileUrls(TASK, SAS) >> [] - 1 * azure.outputFileUrls(TASK, SAS) >> [] + 1 * azure.resourceFileUrls(TASK, null, SAS) >> [] + 1 * azure.outputFileUrls(TASK, null, SAS) >> [] and: result.id == 'nf-01000000' result.requiredSlots == 4 @@ -702,7 +771,7 @@ class AzBatchServiceTest extends Specification { getCpus() >> 4 getMemory() >> MemoryUnit.of('8 GB') } - + getWorkDir() >> mockAzPath('az://container/work/dir') } and: def SPEC = new AzVmPoolSpec(poolId: POOL_ID, vmType: Mock(AzVmType), opts: new AzPoolOpts([:])) @@ -712,8 +781,8 @@ class AzBatchServiceTest extends Specification { then: 1 * azure.getPoolSpec(POOL_ID) >> SPEC 1 * azure.computeSlots(TASK, SPEC) >> 4 - 1 * azure.resourceFileUrls(TASK, SAS) >> [] - 1 * azure.outputFileUrls(TASK, SAS) >> [] + 1 * azure.resourceFileUrls(TASK, null, SAS) >> [] + 1 * azure.outputFileUrls(TASK, null, SAS) >> [] and: result.id == 'nf-02000000' result.requiredSlots == 4 @@ -747,7 +816,7 @@ class AzBatchServiceTest extends Specification { getContainerOptions() >> '-v /foo:/foo' getTime() >> Duration.of('24 h') } - + getWorkDir() >> mockAzPath('az://container/work/dir') } and: def SPEC = new AzVmPoolSpec(poolId: POOL_ID, vmType: Mock(AzVmType), opts: new AzPoolOpts([:])) @@ -757,8 +826,8 @@ class AzBatchServiceTest extends Specification { then: 1 * azure.getPoolSpec(POOL_ID) >> SPEC 1 * azure.computeSlots(TASK, SPEC) >> 4 - 1 * azure.resourceFileUrls(TASK, SAS) >> [] - 1 * azure.outputFileUrls(TASK, SAS) >> [] + 1 * azure.resourceFileUrls(TASK, null, SAS) >> [] + 1 * azure.outputFileUrls(TASK, null, SAS) >> [] and: result.id == 'nf-02000000' result.requiredSlots == 4 @@ -800,8 +869,8 @@ class AzBatchServiceTest extends Specification { then: 1 * azure.getPoolSpec(POOL_ID) >> SPEC 1 * azure.computeSlots(TASK, SPEC) >> 1 - 1 * azure.resourceFileUrls(TASK, SAS) >> [] - 1 * azure.outputFileUrls(TASK, SAS) >> [] + 1 * azure.resourceFileUrls(TASK, null, SAS) >> [] + 1 * azure.outputFileUrls(TASK, null, SAS) >> [] and: result.id == 'nf-01000000' result.requiredSlots == 1 @@ -850,13 +919,229 @@ class AzBatchServiceTest extends Specification { when: def env = [:] as Map - if( service.config.batch().poolIdentityClientId && true ) { // fusionEnabled = true + // This simulates the logic in AzFusionEnv - only set FUSION_AZ_MSI_CLIENT_ID if it's a specific client ID + if( service.config.batch().poolIdentityClientId && service.config.batch().poolIdentityClientId != true ) { env.put('FUSION_AZ_MSI_CLIENT_ID', service.config.batch().poolIdentityClientId) } then: env['FUSION_AZ_MSI_CLIENT_ID'] == POOL_IDENTITY_CLIENT_ID } + + def 'should auto-discover pool identity when poolIdentityClientId is true' () { + given: + def CONFIG = new AzConfig([ + batch: [poolIdentityClientId: true], + storage: [sasToken: 'test-sas-token', accountName: 'testaccount'] + ]) + + and: + Global.session = Mock(Session) { + getConfig() >> [fusion: [enabled: true]] + } + + when: + // Test by simulating the actual behavior + def isFusionEnabled = Global.session?.config?.fusion?.enabled + def env = [:] as Map + + // When poolIdentityClientId is true (auto-discovery), don't set FUSION_AZ_MSI_CLIENT_ID + def poolIdentityClientId = CONFIG.batch().poolIdentityClientId + + // Only set the env var if poolIdentityClientId is a specific client ID + // true (boolean or string), 'auto' means auto-discovery, so don't set the env var + if( poolIdentityClientId && + poolIdentityClientId != true && + poolIdentityClientId != 'true' && + poolIdentityClientId != 'auto' ) { + env['FUSION_AZ_MSI_CLIENT_ID'] = poolIdentityClientId.toString() + } + + then: + // FUSION_AZ_MSI_CLIENT_ID should not be set when poolIdentityClientId is true/auto + !env.containsKey('FUSION_AZ_MSI_CLIENT_ID') + } + + def 'should create task with managed identity for resource files and outputs' () { + given: + Global.session = Mock(Session) { getConfig()>>[:] } + and: + def POOL_ID = 'my-pool' + def POOL_IDENTITY_CLIENT_ID = true // Use auto-discovery + def SAS = 'test-sas-token' + def CONFIG = new AzConfig([ + batch: [poolIdentityClientId: POOL_IDENTITY_CLIENT_ID], + storage: [accountName: 'myaccount', sasToken: SAS] + ]) + def exec = createExecutor(CONFIG) + def azure = Mock(AzBatchService) + and: + def TASK = Mock(TaskRun) { + getHash() >> HashCode.fromInt(1) + getContainer() >> 'ubuntu:latest' + getConfig() >> Mock(TaskConfig) + getWorkDir() >> Paths.get('/work/dir') + } + and: + def SPEC = new AzVmPoolSpec(poolId: POOL_ID, vmType: Mock(AzVmType), opts: new AzPoolOpts([:])) + + when: + // Test the behavior by verifying if managed identity would be used + def identityResourceId = azure.getPoolManagedIdentityResourceId(POOL_ID, POOL_IDENTITY_CLIENT_ID) + + then: + // When poolIdentityClientId is true, this should attempt to get managed identity + 1 * azure.getPoolManagedIdentityResourceId(POOL_ID, true) >> '/subscriptions/123/resourceGroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity' + identityResourceId == '/subscriptions/123/resourceGroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity' + } + + def 'should require SAS token when not using managed identity' () { + given: + Global.session = Mock(Session) { getConfig()>>[:] } + and: + def POOL_ID = 'my-pool' + def CONFIG = [ + storage: [accountName: 'myaccount'] // No SAS token and no managed identity + ] + def exec = createExecutor(CONFIG) + AzBatchService azure = Spy(new AzBatchService(exec)) + and: + def TASK = Mock(TaskRun) { + getHash() >> HashCode.fromInt(1) + getContainer() >> 'ubuntu:latest' + getConfig() >> Mock(TaskConfig) + getWorkDir() >> mockAzPath('az://container/work/dir') + } + and: + def SPEC = new AzVmPoolSpec(poolId: POOL_ID, vmType: Mock(AzVmType), opts: new AzPoolOpts([:])) + + when: + azure.createTask(POOL_ID, 'salmon', TASK) + + then: + def e = thrown(IllegalArgumentException) + e.message == "Missing Azure Blob storage SAS token" + } + + def 'should generate resource file URLs with identity reference for managed identity' () { + given: + def CONFIG = new AzConfig([ + batch: [poolIdentityClientId: true, copyToolInstallMode: 'node'], + storage: [accountName: 'myaccount', sasToken: 'test-sas'] + ]) + def exec = createExecutor(CONFIG) + def service = Spy(new AzBatchService(exec)) + and: + def cmdRun = mockAzPath('az://container/work/dir/.command.run') + def cmdScript = mockAzPath('az://container/work/dir/.command.sh') + def workDir = mockAzPath('az://container/work/dir') + and: + def TASK = Mock(TaskRun) { + getWorkDir() >> workDir + getStdin() >> null + } + // Mock the workDir.resolve() calls + workDir.resolve(TaskRun.CMD_RUN) >> cmdRun + workDir.resolve(TaskRun.CMD_SCRIPT) >> cmdScript + and: + def RESOURCE_ID = '/subscriptions/6186e1c7-3402-4fba-ba02-4567f2aeeb94/resourceGroups/rg-joint-jackass/providers/Microsoft.ManagedIdentity/userAssignedIdentities/nextflow-id' + + when: + def result = service.resourceFileUrls(TASK, RESOURCE_ID, 'test-sas') // task, poolManagedIdentityResourceId, SAS + + then: + result.size() == 2 + result[0].httpUrl == 'http://account.blob.core.windows.net/container/work/dir/.command.run' + result[0].filePath == '.command.run' + result[0].identityReference != null // Should have identity reference + result[0].identityReference.resourceId == RESOURCE_ID + + result[1].httpUrl == 'http://account.blob.core.windows.net/container/work/dir/.command.sh' + result[1].filePath == '.command.sh' + result[1].identityReference != null // Should have identity reference + result[1].identityReference.resourceId == RESOURCE_ID + + // URLs should not contain SAS tokens + !result[0].httpUrl.contains('?') + !result[1].httpUrl.contains('?') + } + + def 'should generate resource file URLs with SAS token when not using managed identity' () { + given: + def SAS = 'test-sas-token' + def CONFIG = new AzConfig([ + batch: [copyToolInstallMode: 'node'], + storage: [sasToken: SAS, accountName: 'myaccount'] + ]) + def exec = createExecutor(CONFIG) + def service = new AzBatchService(exec) + and: + def cmdRun = mockAzPath('az://container/work/dir/.command.run') + def cmdScript = mockAzPath('az://container/work/dir/.command.sh') + def workDir = mockAzPath('az://container/work/dir') + and: + def TASK = Mock(TaskRun) { + getWorkDir() >> workDir + getStdin() >> null + } + // Mock the workDir.resolve() calls + workDir.resolve(TaskRun.CMD_RUN) >> cmdRun + workDir.resolve(TaskRun.CMD_SCRIPT) >> cmdScript + + when: + def result = service.resourceFileUrls(TASK, null, SAS) // task, no managed identity, use SAS + + then: + result.size() == 2 + result[0].httpUrl == 'http://account.blob.core.windows.net/container/work/dir/.command.run?test-sas-token' + result[0].filePath == '.command.run' + result[0].identityReference == null // No identity reference when using SAS + + result[1].httpUrl == 'http://account.blob.core.windows.net/container/work/dir/.command.sh?test-sas-token' + result[1].filePath == '.command.sh' + result[1].identityReference == null // No identity reference when using SAS + + // URLs should contain SAS tokens + result[0].httpUrl.contains('?test-sas-token') + result[1].httpUrl.contains('?test-sas-token') + } + + def 'should generate resource file URLs with SAS when using managedIdentity config but no pool identity' () { + given: + def CONFIG = new AzConfig([ + batch: [copyToolInstallMode: 'node'], + managedIdentity: [clientId: 'managed-identity-123'], + storage: [accountName: 'myaccount', sasToken: 'test-sas'] + ]) + def exec = createExecutor(CONFIG) + def service = new AzBatchService(exec) + and: + def cmdRun = mockAzPath('az://container/work/dir/.command.run') + def cmdScript = mockAzPath('az://container/work/dir/.command.sh') + def workDir = mockAzPath('az://container/work/dir') + and: + def TASK = Mock(TaskRun) { + getWorkDir() >> workDir + getStdin() >> null + } + // Mock the workDir.resolve() calls + workDir.resolve(TaskRun.CMD_RUN) >> cmdRun + workDir.resolve(TaskRun.CMD_SCRIPT) >> cmdScript + + when: + def result = service.resourceFileUrls(TASK, null, 'test-sas') // task, no managed identity, use SAS + + then: + result.size() == 2 + // Without pool identity, falls back to SAS token auth + result[0].httpUrl == 'http://account.blob.core.windows.net/container/work/dir/.command.run?test-sas' + result[0].filePath == '.command.run' + result[0].identityReference == null + + result[1].httpUrl == 'http://account.blob.core.windows.net/container/work/dir/.command.sh?test-sas' + result[1].filePath == '.command.sh' + result[1].identityReference == null + } def 'should cache job id' () { @@ -1022,4 +1307,225 @@ class AzBatchServiceTest extends Specification { ] } + def 'should get pool managed identity resource ID with null or false'() { + given: + def exec = createExecutor() + def service = Spy(new AzBatchService(exec)) + def poolId = 'test-pool' + + when: 'poolIdentityClientId is null' + def result = service.getPoolManagedIdentityResourceId(poolId, null) + then: + result == null + 0 * service.getPool(poolId) // Should not even call getPool + + when: 'poolIdentityClientId is false' + result = service.getPoolManagedIdentityResourceId(poolId, false) + then: + result == null + 0 * service.getPool(poolId) // Should not even call getPool + } + + @Ignore("Mocking Azure SDK classes (BatchPool, BatchPoolIdentity, UserAssignedIdentity) is problematic - needs integration test") + def 'should get pool managed identity resource ID with auto'() { + given: + def exec = createExecutor() + def service = Spy(new AzBatchService(exec)) + def poolId = 'test-pool' + def resourceId = '/subscriptions/123/resourceGroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/identity1' + + // Create groovy mocks for the identity objects + def identity1 = GroovyMock(UserAssignedIdentity) { + getResourceId() >> resourceId + getClientId() >> 'client-123' + } + def identity2 = GroovyMock(UserAssignedIdentity) { + getResourceId() >> '/subscriptions/123/resourceGroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/identity2' + getClientId() >> 'client-456' + } + + def poolIdentity = GroovyMock(BatchPoolIdentity) { + getUserAssignedIdentities() >> [identity1, identity2] + } + + def pool = GroovyMock(BatchPool) { + getIdentity() >> poolIdentity + } + + when: 'poolIdentityClientId is "auto"' + service.getPool(poolId) >> pool + def result = service.getPoolManagedIdentityResourceId(poolId, 'auto') + + then: + result == resourceId // Should return first identity + + when: 'poolIdentityClientId is true (backward compatibility)' + service.getPool(poolId) >> pool + result = service.getPoolManagedIdentityResourceId(poolId, true) + + then: + result == resourceId // Should return first identity + } + + @Ignore("Mocking Azure SDK classes (BatchPool, BatchPoolIdentity, UserAssignedIdentity) is problematic - needs integration test") + def 'should get pool managed identity resource ID with specific client ID'() { + given: + def exec = createExecutor() + def service = Spy(new AzBatchService(exec)) + def poolId = 'test-pool' + def targetClientId = 'client-456' + def targetResourceId = '/subscriptions/123/resourceGroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/identity2' + + // Create groovy mocks for the identity objects + def identity1 = GroovyMock(UserAssignedIdentity) { + getResourceId() >> '/subscriptions/123/resourceGroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/identity1' + getClientId() >> 'client-123' + } + def identity2 = GroovyMock(UserAssignedIdentity) { + getResourceId() >> targetResourceId + getClientId() >> targetClientId + } + + def poolIdentity = GroovyMock(BatchPoolIdentity) { + getUserAssignedIdentities() >> [identity1, identity2] + } + + def pool = GroovyMock(BatchPool) { + getIdentity() >> poolIdentity + } + + when: 'specific client ID exists' + service.getPool(poolId) >> pool + def result = service.getPoolManagedIdentityResourceId(poolId, targetClientId) + + then: + result == targetResourceId + + when: 'specific client ID does not exist' + service.getPool(poolId) >> pool + result = service.getPoolManagedIdentityResourceId(poolId, 'non-existent-client') + + then: + result == null // Should return null, not fallback + } + + def 'should handle pool with no identity configuration'() { + given: + def exec = createExecutor() + def service = Spy(new AzBatchService(exec)) + def poolId = 'test-pool' + + when: 'pool has no identity' + def pool = GroovyMock(BatchPool) { + getIdentity() >> null + } + service.getPool(poolId) >> pool + def result = service.getPoolManagedIdentityResourceId(poolId, 'auto') + + then: + result == null + + when: 'pool is null' + service.getPool(poolId) >> null + result = service.getPoolManagedIdentityResourceId(poolId, 'auto') + + then: + result == null + } + + def 'should handle pool with empty identities list'() { + given: + def exec = createExecutor() + def service = Spy(new AzBatchService(exec)) + def poolId = 'test-pool' + + def poolIdentity = GroovyMock(BatchPoolIdentity) { + getUserAssignedIdentities() >> [] + } + def pool = GroovyMock(BatchPool) { + getIdentity() >> { + println "getIdentity called, returning poolIdentity: $poolIdentity" + return poolIdentity + } + } + + when: + def result = service.getPoolManagedIdentityResourceId(poolId, 'auto') + then: + 1 * service.getPool(poolId) >> pool + result == null + } + + def 'should handle exceptions when getting pool managed identity'() { + given: + def exec = createExecutor() + def service = Spy(new AzBatchService(exec)) + def poolId = 'test-pool' + + when: + def result = service.getPoolManagedIdentityResourceId(poolId, 'auto') + then: + 1 * service.getPool(poolId) >> { throw new RuntimeException("Test exception") } + result == null // Should return null on error + } + + def 'should handle unsupported poolIdentityClientId type'() { + given: + def exec = createExecutor() + def service = Spy(new AzBatchService(exec)) + def poolId = 'test-pool' + + def poolIdentity = GroovyMock(BatchPoolIdentity) { + getUserAssignedIdentities() >> [GroovyMock(UserAssignedIdentity)] + } + def pool = GroovyMock(BatchPool) { + getIdentity() >> { + println "getIdentity called, returning poolIdentity: $poolIdentity" + return poolIdentity + } + } + + when: 'poolIdentityClientId is an unsupported type' + def result = service.getPoolManagedIdentityResourceId(poolId, 123) // number instead of string + then: + 1 * service.getPool(poolId) >> pool + result == null + } + + def 'should pass managed identity to destFile when available'() { + given: + def CONFIG = new AzConfig([ + batch: [poolIdentityClientId: true], + storage: [accountName: 'myaccount', sasToken: 'test-sas'] + ]) + def exec = createExecutor(CONFIG) + def service = Spy(new AzBatchService(exec)) + def workDir = mockAzPath('az://container/work/dir') + def managedIdentityResourceId = '/subscriptions/123/resourceGroups/rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity' + + when: 'outputFileUrls is called with managed identity' + def result = service.outputFileUrls(Mock(TaskRun) { getWorkDir() >> workDir }, managedIdentityResourceId, 'test-sas') + + then: 'destFile is called with managed identity resource ID for each output file' + 9 * service.destFile(_, workDir, managedIdentityResourceId, 'test-sas') >> GroovyMock(OutputFile) + result.size() == 9 + } + + def 'should pass null to destFile when no managed identity'() { + given: + def CONFIG = new AzConfig([ + storage: [accountName: 'myaccount', sasToken: 'test-sas'] + ]) + def exec = createExecutor(CONFIG) + def service = Spy(new AzBatchService(exec)) + def workDir = mockAzPath('az://container/work/dir') + + when: 'outputFileUrls is called without managed identity' + def result = service.outputFileUrls(Mock(TaskRun) { getWorkDir() >> workDir }, null, 'test-sas') + + then: 'destFile is called with null for managed identity' + 9 * service.destFile(_, workDir, null, 'test-sas') >> GroovyMock(OutputFile) + result.size() == 9 + } + }