diff --git a/.github/workflows/validate-build-scale.yml b/.github/workflows/validate-build-scale.yml
index a56cde80b..239072275 100644
--- a/.github/workflows/validate-build-scale.yml
+++ b/.github/workflows/validate-build-scale.yml
@@ -1,4 +1,4 @@
-name: Functions Scale Tests Azure Storage
+name: Functions Scale Tests
permissions:
contents: read
@@ -18,7 +18,8 @@ env:
AzureWebJobsStorage: UseDevelopmentStorage=true
jobs:
- build:
+ azure-storage:
+ name: Functions Scale Tests - Azure Storage
runs-on: ubuntu-latest
steps:
@@ -79,4 +80,56 @@ jobs:
done
# Run tests
- dotnet test ./test/ScaleTests/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.Tests.csproj -c $config --no-build --verbosity normal
\ No newline at end of file
+ dotnet test ./test/ScaleTests/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.Tests.csproj -c $config --no-build --verbosity normal --filter "FullyQualifiedName!~AzureManaged"
+
+ dts:
+ name: Functions Scale Tests - DTS
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ submodules: true
+
+ - name: Setup .NET
+ uses: actions/setup-dotnet@v3
+ with:
+ global-json-file: global.json
+
+ - name: Set up Node.js (needed for Azurite)
+ uses: actions/setup-node@v3
+ with:
+ node-version: '18.x'
+
+ - name: Install Azurite
+ run: npm install -g azurite
+
+ - name: Start Azurite
+ run: azurite --silent --blobPort 10000 --queuePort 10001 --tablePort 10002 &
+
+ - name: Pull DTS Emulator Docker Image
+ run: docker pull mcr.microsoft.com/dts/dts-emulator:latest
+
+ - name: Start DTS Container
+ run: |
+ docker run -i \
+ -p 8080:8080 \
+ -p 8082:8082 \
+ -d mcr.microsoft.com/dts/dts-emulator:latest
+
+ - name: Wait for DTS to be ready
+ run: |
+ echo "Waiting for DTS to be ready..."
+ sleep 30
+ docker ps
+
+ - name: Restore and Build Scale Tests
+ working-directory: test/ScaleTests
+ run: dotnet build --configuration Release
+
+ - name: Run Azure Managed Scale Tests
+ working-directory: test/ScaleTests
+ env:
+ AzureWebJobsStorage: UseDevelopmentStorage=true
+ DURABLE_TASK_SCHEDULER_CONNECTION_STRING: "Endpoint=http://localhost:8080;Authentication=None"
+ run: dotnet test --configuration Release --no-build --verbosity normal --filter "FullyQualifiedName~AzureManaged"
\ No newline at end of file
diff --git a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProvider.cs b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProvider.cs
new file mode 100644
index 000000000..fee65a31d
--- /dev/null
+++ b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProvider.cs
@@ -0,0 +1,80 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in the project root for license information.
+
+using System;
+using Microsoft.Azure.WebJobs.Host.Scale;
+using Microsoft.DurableTask.AzureManagedBackend;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged
+{
+ ///
+ /// The AzureManaged backend implementation of the scalability provider for Durable Functions.
+ ///
+ public class AzureManagedScalabilityProvider : ScalabilityProvider
+ {
+ private readonly AzureManagedOrchestrationService orchestrationService;
+ private readonly string connectionName;
+ private readonly ILogger logger;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ /// The instance that provides access to backend service for scaling operations.
+ ///
+ ///
+ /// The logical name of the storage or service connection associated with this provider.
+ ///
+ ///
+ /// The instance used for logging provider activities and diagnostics.
+ ///
+ ///
+ /// Thrown if is .
+ ///
+ public AzureManagedScalabilityProvider(
+ AzureManagedOrchestrationService orchestrationService,
+ string connectionName,
+ ILogger logger)
+ : base("AzureManaged", connectionName)
+ {
+ this.orchestrationService = orchestrationService ?? throw new ArgumentNullException(nameof(orchestrationService));
+ this.connectionName = connectionName;
+ this.logger = logger;
+ }
+
+ ///
+ /// Gets the app setting containing the Azure Managed connection string.
+ ///
+ public override string ConnectionName => this.connectionName;
+
+ ///
+ /// This is not used.
+ public override bool TryGetScaleMonitor(
+ string functionId,
+ string functionName,
+ string hubName,
+ string targetConnectionName,
+ out IScaleMonitor scaleMonitor)
+ {
+ // Azure Managed backend does not support the legacy scale monitor infrastructure.
+ // Return false so that the scaling utilities can provide a no-op monitor.
+ scaleMonitor = null;
+ return false;
+ }
+
+ ///
+ public override bool TryGetTargetScaler(
+ string functionId,
+ string functionName,
+ string hubName,
+ string targetConnectionName,
+ out ITargetScaler targetScaler)
+ {
+ // Create a target scaler that uses the orchestration service's metrics endpoint.
+ // All target scalers share the same AzureManagedOrchestrationService in the same task hub.
+ targetScaler = new AzureManagedTargetScaler(this.orchestrationService, functionId, this.logger);
+ return true;
+ }
+ }
+}
diff --git a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProviderFactory.cs b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProviderFactory.cs
new file mode 100644
index 000000000..c32a6bc48
--- /dev/null
+++ b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProviderFactory.cs
@@ -0,0 +1,241 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using Azure.Core;
+using Azure.Identity;
+using Microsoft.Azure.WebJobs.Host.Scale;
+using Microsoft.DurableTask.AzureManagedBackend;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.Logging;
+
+#nullable enable
+namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged
+{
+ ///
+ /// Factory class responsible for creating and managing instances of .
+ ///
+ public class AzureManagedScalabilityProviderFactory : IScalabilityProviderFactory
+ {
+ private const string LoggerName = "Triggers.DurableTask.AzureManaged";
+ private const string ProviderName = "AzureManaged";
+
+ private readonly Dictionary<(string, string?, string?), AzureManagedScalabilityProvider> cachedProviders = new Dictionary<(string, string?, string?), AzureManagedScalabilityProvider>();
+ private readonly IConfiguration configuration;
+ private readonly ILoggerFactory loggerFactory;
+ private readonly ILogger logger;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ /// The interface used to resolve connection strings and application settings.
+ ///
+ ///
+ /// The used to create loggers for diagnostics.
+ ///
+ ///
+ /// Thrown if any required argument is .
+ ///
+ public AzureManagedScalabilityProviderFactory(
+ IConfiguration configuration,
+ ILoggerFactory loggerFactory)
+ {
+ this.configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
+ this.loggerFactory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory));
+ this.logger = this.loggerFactory.CreateLogger(LoggerName);
+
+ this.DefaultConnectionName = "DURABLE_TASK_SCHEDULER_CONNECTION_STRING";
+ }
+
+ ///
+ /// Gets the logical name of this scalability provider type.
+ ///
+ public virtual string Name => ProviderName;
+
+ ///
+ /// Gets the default connection name configured for this factory.
+ ///
+ public string DefaultConnectionName { get; }
+
+ ///
+ /// Creates or retrieves an instance based on the provided pre-deserialized metadata.
+ ///
+ /// The pre-deserialized Durable Task metadata.
+ /// Trigger metadata used to access Properties like token credentials.
+ ///
+ /// An instance configured using
+ /// the specified metadata and resolved connection information.
+ ///
+ ///
+ /// Thrown if no valid connection string could be resolved for the given connection name.
+ ///
+ public ScalabilityProvider GetScalabilityProvider(DurableTaskMetadata metadata, TriggerMetadata? triggerMetadata)
+ {
+ if (metadata != null)
+ {
+ this.ValidateMetadata(metadata);
+ }
+
+ // Get connection name from metadata, fallback to default
+ string? rawConnectionName = TriggerMetadataExtensions.ResolveConnectionName(metadata?.StorageProvider);
+ string connectionName = rawConnectionName ?? this.DefaultConnectionName;
+ this.logger.LogInformation("Using connection name '{ConnectionName}'", connectionName);
+
+ // Look up connection string from configuration
+ string? connectionString =
+ this.configuration.GetConnectionString(connectionName) ??
+ this.configuration[connectionName] ??
+ Environment.GetEnvironmentVariable(connectionName);
+
+ if (string.IsNullOrEmpty(connectionString))
+ {
+ throw new InvalidOperationException(
+ $"No valid connection string found for '{connectionName}'. " +
+ $"Please ensure it is defined in app settings, connection strings, or environment variables.");
+ }
+
+ AzureManagedConnectionString azureManagedConnectionString = new AzureManagedConnectionString(connectionString);
+
+ // Extract task hub name from metadata
+ string? taskHubName = metadata?.TaskHubName ?? azureManagedConnectionString.TaskHubName;
+
+ // Include client ID in cache key to handle managed identity changes
+ // Use the original connection name (rawConnectionName or default) for the cache key, not the connection string value
+ (string, string?, string?) cacheKey = (connectionName, taskHubName, azureManagedConnectionString.ClientId);
+
+ this.logger.LogDebug(
+ "Getting durability provider for connection '{Connection}', task hub '{TaskHub}', and client ID '{ClientId}'...",
+ cacheKey.Item1,
+ cacheKey.Item2 ?? "null",
+ cacheKey.Item3 ?? "null");
+
+ lock (this.cachedProviders)
+ {
+ // If a provider has already been created for this connection name, task hub, and client ID, return it.
+ if (this.cachedProviders.TryGetValue(cacheKey, out AzureManagedScalabilityProvider? cachedProvider))
+ {
+ this.logger.LogDebug(
+ "Returning cached durability provider for connection '{Connection}', task hub '{TaskHub}', and client ID '{ClientId}'",
+ cacheKey.Item1,
+ cacheKey.Item2,
+ cacheKey.Item3 ?? "null");
+ return cachedProvider;
+ }
+
+ // Create options from the connection string.
+ // For runtime-driven scaling, token credentials are loaded directly from the host.
+ AzureManagedOrchestrationServiceOptions options =
+ AzureManagedOrchestrationServiceOptions.FromConnectionString(connectionString);
+
+ // If triggerMetadata is provided (from functions Scale Controller), try to get token credential from it.
+ if (triggerMetadata != null && triggerMetadata.Properties != null &&
+ triggerMetadata.Properties.TryGetValue("GetAzureManagedTokenCredential", out object? tokenCredentialFunc))
+ {
+ if (tokenCredentialFunc is Func getTokenCredential)
+ {
+ try
+ {
+ TokenCredential tokenCredential = getTokenCredential(connectionName);
+
+ if (tokenCredential == null)
+ {
+ this.logger.LogWarning(
+ "Token credential retrieved from trigger metadata is null for connection '{Connection}'.",
+ connectionName);
+ }
+ else
+ {
+ // Override the credential from connection string
+ options.TokenCredential = tokenCredential;
+ this.logger.LogInformation("Retrieved token credential from trigger metadata for connection '{Connection}'", connectionName);
+ }
+ }
+ catch (OperationCanceledException ex)
+ {
+ // Expected scenario when the operation is cancelled;
+ // log and fall back to the connection string credential.
+ this.logger.LogWarning(
+ ex,
+ "Getting token credential from trigger metadata was canceled for connection '{Connection}'",
+ connectionName);
+ }
+ catch (AuthenticationFailedException ex)
+ {
+ // Authentication failures are expected in some environments;
+ // log and fall back to the connection string credential.
+ this.logger.LogWarning(
+ ex,
+ "Authentication failed while getting token credential from trigger metadata for connection '{Connection}'",
+ connectionName);
+ }
+ catch (Exception ex)
+ {
+ // Unexpected exception types. Fall back to use connection string.
+ this.logger.LogWarning(
+ ex,
+ "Unexpected error while getting token credential from trigger metadata for connection '{Connection}'",
+ connectionName);
+ }
+ }
+ else
+ {
+ this.logger.LogWarning(
+ "Token credential function pointer in trigger metadata is not of expected type for connection '{Connection}'",
+ connectionName);
+ }
+ }
+ else
+ {
+ this.logger.LogInformation(
+ "No trigger metadata provided or trigger metadata does not contain 'GetAzureManagedTokenCredential', " +
+ "using the token credential built from connection string for connection '{Connection}'.", connectionName);
+ }
+
+ // Set task hub name if configured
+ if (!string.IsNullOrEmpty(taskHubName))
+ {
+ options.TaskHubName = taskHubName;
+ }
+
+ int defaultConcurrency = 10;
+
+ // Extract max concurrent values from trigger metadata.
+ // If nothing is provided from TriggerMetadata, we use default value which is 10.
+ options.MaxConcurrentOrchestrationWorkItems = metadata?.MaxConcurrentOrchestratorFunctions ?? defaultConcurrency;
+ options.MaxConcurrentActivityWorkItems = metadata?.MaxConcurrentActivityFunctions ?? defaultConcurrency;
+
+ this.logger.LogInformation(
+ "Creating durability provider for connection '{Connection}', task hub '{TaskHub}', and client ID '{ClientId}'...",
+ cacheKey.Item1,
+ cacheKey.Item2,
+ cacheKey.Item3 ?? "null");
+
+ AzureManagedOrchestrationService service = new AzureManagedOrchestrationService(options, this.loggerFactory);
+ AzureManagedScalabilityProvider provider = new AzureManagedScalabilityProvider(service, connectionName, this.logger);
+
+ // Set max concurrent values from trigger metadata for provider.
+ // If nothing is provided, use default value which is 10.
+ provider.MaxConcurrentTaskOrchestrationWorkItems = metadata?.MaxConcurrentOrchestratorFunctions ?? defaultConcurrency;
+ provider.MaxConcurrentTaskActivityWorkItems = metadata?.MaxConcurrentActivityFunctions ?? defaultConcurrency;
+
+ this.cachedProviders.Add(cacheKey, provider);
+ return provider;
+ }
+ }
+
+ private void ValidateMetadata(DurableTaskMetadata metadata)
+ {
+ if (metadata.MaxConcurrentOrchestratorFunctions.HasValue && metadata.MaxConcurrentOrchestratorFunctions.Value <= 0)
+ {
+ throw new InvalidOperationException($"{nameof(metadata.MaxConcurrentOrchestratorFunctions)} must be a positive integer.");
+ }
+
+ if (metadata.MaxConcurrentActivityFunctions.HasValue && metadata.MaxConcurrentActivityFunctions.Value <= 0)
+ {
+ throw new InvalidOperationException($"{nameof(metadata.MaxConcurrentActivityFunctions)} must be a positive integer.");
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedTargetScaler.cs b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedTargetScaler.cs
new file mode 100644
index 000000000..7ce9a4d63
--- /dev/null
+++ b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedTargetScaler.cs
@@ -0,0 +1,68 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in the project root for license information.
+
+using System;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs.Host.Scale;
+using Microsoft.DurableTask.AzureManagedBackend;
+using Microsoft.DurableTask.AzureManagedBackend.Metrics;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged
+{
+ internal class AzureManagedTargetScaler : ITargetScaler
+ {
+ private readonly AzureManagedOrchestrationService service;
+ private readonly TargetScalerDescriptor descriptor;
+ private readonly ILogger logger;
+
+ public AzureManagedTargetScaler(AzureManagedOrchestrationService service, string functionId, ILogger logger)
+ {
+ this.service = service;
+ this.descriptor = new TargetScalerDescriptor(functionId);
+ this.logger = logger;
+ }
+
+ public TargetScalerDescriptor TargetScalerDescriptor => this.descriptor;
+
+ public async Task GetScaleResultAsync(TargetScalerContext context)
+ {
+ TaskHubMetrics metrics = await this.service.GetTaskHubMetricsAsync(default);
+ if (metrics is null)
+ {
+ this.logger?.LogWarning("Task hub metrics returned null from Azure Managed backend. This may indicate the DTS emulator is being used which may not support metrics. Returning 0 worker count.");
+ return new TargetScalerResult { TargetWorkerCount = 0 };
+ }
+
+ static int GetTargetWorkerCount(WorkItemMetrics workItemMetrics, int workItemCapacity)
+ {
+ if (workItemCapacity == 0)
+ {
+ return 0;
+ }
+
+ int totalWorkItemCount = workItemMetrics.PendingCount + workItemMetrics.ActiveCount;
+ return (int)Math.Ceiling((double)totalWorkItemCount / workItemCapacity);
+ }
+
+ int targetForOrchestratorWorkItems = GetTargetWorkerCount(
+ workItemMetrics: metrics.OrchestratorWorkItems,
+ workItemCapacity: this.service.MaxConcurrentTaskOrchestrationWorkItems);
+
+ int targetForActivityWorkItems = GetTargetWorkerCount(
+ workItemMetrics: metrics.ActivityWorkItems,
+ workItemCapacity: this.service.MaxConcurrentTaskActivityWorkItems);
+
+ int targetForEntityWorkItems = GetTargetWorkerCount(
+ workItemMetrics: metrics.EntityWorkItems,
+ workItemCapacity: this.service.MaxConcurrentEntityWorkItems);
+
+ // Scale out to the maximum of the above target values.
+ int maxTarget = Math.Max(
+ targetForOrchestratorWorkItems,
+ Math.Max(targetForActivityWorkItems, targetForEntityWorkItems));
+
+ return new TargetScalerResult { TargetWorkerCount = maxTarget };
+ }
+ }
+}
diff --git a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/DurableTaskScaleConfigurationExtensions.cs b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/DurableTaskScaleConfigurationExtensions.cs
index 4a94dd834..1150194a5 100644
--- a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/DurableTaskScaleConfigurationExtensions.cs
+++ b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/DurableTaskScaleConfigurationExtensions.cs
@@ -3,6 +3,7 @@
using System;
using System.Linq;
+using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged;
using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureStorage;
using Microsoft.Azure.WebJobs.Host.Scale;
using Microsoft.Extensions.Configuration;
@@ -43,6 +44,8 @@ public static Microsoft.Azure.WebJobs.IWebJobsBuilder AddDurableTask(this IWebJo
serviceCollection.AddSingleton();
+ serviceCollection.AddSingleton();
+
return builder;
}
diff --git a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.csproj b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.csproj
index bb7f4f201..6b11b1d72 100644
--- a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.csproj
+++ b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.csproj
@@ -51,6 +51,7 @@
+
diff --git a/test/ScaleTests/AzureManaged/AzureManagedScalabilityProviderFactoryTests.cs b/test/ScaleTests/AzureManaged/AzureManagedScalabilityProviderFactoryTests.cs
new file mode 100644
index 000000000..e2b556ca0
--- /dev/null
+++ b/test/ScaleTests/AzureManaged/AzureManagedScalabilityProviderFactoryTests.cs
@@ -0,0 +1,183 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged;
+using Microsoft.Azure.WebJobs.Host.Scale;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.Logging;
+using Newtonsoft.Json.Linq;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.Tests
+{
+ public class AzureManagedScalabilityProviderFactoryTests
+ {
+ private readonly ITestOutputHelper output;
+ private readonly ILoggerFactory loggerFactory;
+ private readonly IConfiguration configuration;
+
+ public AzureManagedScalabilityProviderFactoryTests(ITestOutputHelper output)
+ {
+ this.output = output;
+ this.loggerFactory = new LoggerFactory();
+ this.loggerFactory.AddProvider(new TestLoggerProvider(output));
+
+ this.configuration = new ConfigurationBuilder()
+ .AddInMemoryCollection(new Dictionary
+ {
+ { "v3-dtsConnectionMI", "Endpoint=https://test.westus.durabletask.io;Authentication=DefaultAzure" },
+ { "DURABLE_TASK_SCHEDULER_CONNECTION_STRING", "Endpoint=https://default.westus.durabletask.io;Authentication=DefaultAzure" },
+ })
+ .Build();
+ }
+
+ ///
+ /// Validates that the factory can be instantiated with valid parameters,
+ /// reports the correct provider name, and exposes the expected default connection name.
+ ///
+ [Fact]
+ public void Constructor_ValidParameters_CreatesInstance()
+ {
+ var factory = new AzureManagedScalabilityProviderFactory(
+ this.configuration,
+ this.loggerFactory);
+
+ Assert.NotNull(factory);
+ Assert.Equal("AzureManaged", factory.Name);
+ Assert.Equal("DURABLE_TASK_SCHEDULER_CONNECTION_STRING", factory.DefaultConnectionName);
+ }
+
+ ///
+ /// KEY SCENARIO: Scale Controller sends trigger metadata with storageProvider.type = "azureManaged".
+ /// Validates that the factory returns an AzureManagedScalabilityProvider with the correct
+ /// connection name and concurrency limits taken from the trigger metadata.
+ ///
+ [Fact]
+ public void GetScalabilityProvider_WithTriggerMetadata_ReturnsAzureManagedProvider()
+ {
+ var factory = new AzureManagedScalabilityProviderFactory(
+ this.configuration,
+ this.loggerFactory);
+
+ var triggerMetadata = TestHelpers.CreateTriggerMetadata("testHub", 15, 25, "v3-dtsConnectionMI", "azureManaged");
+ var metadata = triggerMetadata.ExtractDurableTaskMetadata();
+
+ var provider = factory.GetScalabilityProvider(metadata, triggerMetadata);
+
+ Assert.NotNull(provider);
+ Assert.IsType(provider);
+ var azureProvider = (AzureManagedScalabilityProvider)provider;
+
+ // The provider connection name should be same as what we set at metadata.
+ Assert.Equal("v3-dtsConnectionMI", azureProvider.ConnectionName);
+ Assert.Equal(15, azureProvider.MaxConcurrentTaskOrchestrationWorkItems);
+ Assert.Equal(25, azureProvider.MaxConcurrentTaskActivityWorkItems);
+ }
+
+ ///
+ /// Validates that when trigger metadata does not include a connectionName in storageProvider,
+ /// the factory falls back to the default connection name DURABLE_TASK_SCHEDULER_CONNECTION_STRING.
+ ///
+ [Fact]
+ public void GetScalabilityProvider_WithNoConnectionNameInMetadata_UsesDefaultConnectionName()
+ {
+ var factory = new AzureManagedScalabilityProviderFactory(
+ this.configuration,
+ this.loggerFactory);
+
+ var jobj = new JObject
+ {
+ { "functionName", "TestFunction" },
+ { "taskHubName", "testHub" },
+ { "storageProvider", new JObject { { "type", "azureManaged" } } },
+ };
+ var triggerMetadata = new TriggerMetadata(jobj);
+ var metadata = triggerMetadata.ExtractDurableTaskMetadata();
+
+ var provider = factory.GetScalabilityProvider(metadata, triggerMetadata);
+
+ Assert.NotNull(provider);
+ Assert.Equal("DURABLE_TASK_SCHEDULER_CONNECTION_STRING", provider.ConnectionName);
+ }
+
+ ///
+ /// Validates that when the specified connection string is absent from configuration,
+ /// the factory throws an InvalidOperationException rather than silently continuing.
+ ///
+ [Fact]
+ public void GetScalabilityProvider_MissingConnectionString_ThrowsInvalidOperationException()
+ {
+ var emptyConfig = new ConfigurationBuilder().Build();
+ var factory = new AzureManagedScalabilityProviderFactory(
+ emptyConfig,
+ this.loggerFactory);
+
+ var triggerMetadata = TestHelpers.CreateTriggerMetadata("testHub", 5, 10, "MISSING_CONNECTION", "azureManaged");
+ var metadata = triggerMetadata.ExtractDurableTaskMetadata();
+
+ Assert.Throws(() => factory.GetScalabilityProvider(metadata, triggerMetadata));
+ }
+
+ ///
+ /// Validates that calling GetScalabilityProvider twice with the same connection name and task hub
+ /// returns the same cached provider instance, avoiding redundant connections per scale decision.
+ ///
+ [Fact]
+ public void GetScalabilityProvider_SameParameters_ReturnsCachedInstance()
+ {
+ var factory = new AzureManagedScalabilityProviderFactory(
+ this.configuration,
+ this.loggerFactory);
+
+ var triggerMetadata = TestHelpers.CreateTriggerMetadata("testHub", 5, 10, "v3-dtsConnectionMI", "azureManaged");
+ var metadata = triggerMetadata.ExtractDurableTaskMetadata();
+
+ var provider1 = factory.GetScalabilityProvider(metadata, triggerMetadata);
+ var provider2 = factory.GetScalabilityProvider(metadata, triggerMetadata);
+
+ Assert.Same(provider1, provider2);
+ }
+
+ ///
+ /// KEY SCENARIO: When multiple provider factories are registered and trigger metadata specifies
+ /// storageProvider.type = "azureManaged", DurableTaskScaleExtension.GetScalabilityProviderFactory
+ /// must select the AzureManagedScalabilityProviderFactory.
+ ///
+ [Fact]
+ public void GetScalabilityProviderFactory_WhenMetadataTypeIsAzureManaged_SelectsAzureManagedFactory()
+ {
+ var azureManagedFactory = new AzureManagedScalabilityProviderFactory(
+ this.configuration,
+ this.loggerFactory);
+
+ // Create a list of mock Azure Storage provider factories to simulate multiple provider factories.
+ IScalabilityProviderFactory[] factories = new IScalabilityProviderFactory[]
+ {
+ new StubAzureStorageFactory(),
+ azureManagedFactory,
+ };
+
+ var triggerMetadata = TestHelpers.CreateTriggerMetadata("testHub", 5, 10, "DURABLE_TASK_SCHEDULER_CONNECTION_STRING", "azureManaged");
+ var metadata = triggerMetadata.ExtractDurableTaskMetadata();
+
+ var logger = this.loggerFactory.CreateLogger("test");
+ var selectedFactory = DurableTaskScaleExtension.GetScalabilityProviderFactory(metadata, logger, factories);
+
+ Assert.IsType(selectedFactory);
+ }
+
+ // Stub used to test factory selection without a real storage emulator.
+ private class StubAzureStorageFactory : IScalabilityProviderFactory
+ {
+ public string Name => "AzureStorage";
+
+ public string DefaultConnectionName => "AzureWebJobsStorage";
+
+ public ScalabilityProvider GetScalabilityProvider(DurableTaskMetadata metadata, TriggerMetadata triggerMetadata)
+ => throw new NotImplementedException();
+ }
+ }
+}
diff --git a/test/ScaleTests/AzureManaged/AzureManagedTargetScalerTests.cs b/test/ScaleTests/AzureManaged/AzureManagedTargetScalerTests.cs
new file mode 100644
index 000000000..89443bf86
--- /dev/null
+++ b/test/ScaleTests/AzureManaged/AzureManagedTargetScalerTests.cs
@@ -0,0 +1,133 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the MIT License. See LICENSE in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
+using DurableTask.Core;
+using DurableTask.Core.History;
+using DurableTask.Core.Query;
+using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged;
+using Microsoft.Azure.WebJobs.Host.Scale;
+using Microsoft.DurableTask.AzureManagedBackend;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.Logging;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.Tests
+{
+ ///
+ /// Validates that the target-based autoscaling mechanism produces correct worker counts
+ /// based on pending/active work item metrics from the Azure Managed backend.
+ ///
+ public class AzureManagedTargetScalerTests
+ {
+ private readonly ITestOutputHelper output;
+
+ public AzureManagedTargetScalerTests(ITestOutputHelper output)
+ {
+ this.output = output;
+ }
+
+ ///
+ /// Target scaler calculates correct worker count based on pending orchestrations.
+ /// Validates that with 20 pending orchestrations and MaxConcurrentOrchestrators=2,
+ /// the scaler returns 10 workers (20/2 = 10).
+ ///
+ [Fact]
+ public async Task TargetBasedScaling_WithPendingOrchestrations_ReturnsExpectedWorkerCount()
+ {
+ var taskHubName = "default";
+ var connectionString = TestHelpers.GetAzureManagedConnectionString();
+ var options = AzureManagedOrchestrationServiceOptions.FromConnectionString(connectionString);
+ options.TaskHubName = taskHubName;
+ options.MaxConcurrentOrchestrationWorkItems = 2;
+ options.MaxConcurrentActivityWorkItems = 2;
+
+ this.output.WriteLine($"Creating connection to the test DTS TaskHub: {taskHubName}");
+
+ var loggerFactory = new LoggerFactory();
+ using var service = new AzureManagedOrchestrationService(options, loggerFactory);
+
+ var status = new List
+ {
+ OrchestrationStatus.Pending,
+ OrchestrationStatus.Running,
+ OrchestrationStatus.Suspended,
+ };
+
+ var query = new OrchestrationQuery { RuntimeStatus = status };
+ var result = await service.GetOrchestrationWithQueryAsync(query, CancellationToken.None);
+
+ int existingCount = result.OrchestrationState?.Count ?? 0;
+ int orchestrationsToCreate = Math.Max(0, 20 - existingCount);
+
+ this.output.WriteLine($"Found {existingCount} existing orchestrations. Creating {orchestrationsToCreate} new ones.");
+
+ for (int i = 0; i < orchestrationsToCreate; i++)
+ {
+ var instance = new OrchestrationInstance
+ {
+ InstanceId = $"TestOrchestration_{Guid.NewGuid():N}",
+ ExecutionId = Guid.NewGuid().ToString(),
+ };
+
+ await service.CreateTaskOrchestrationAsync(
+ new TaskMessage
+ {
+ OrchestrationInstance = instance,
+ Event = new ExecutionStartedEvent(-1, "TestInput")
+ {
+ OrchestrationInstance = instance,
+ Name = "TestOrchestration",
+ Version = "1.0",
+ Input = "TestInput",
+ },
+ });
+ }
+
+ await Task.Delay(2000);
+
+ var configuration = new ConfigurationBuilder()
+ .AddInMemoryCollection(new Dictionary
+ {
+ { "DURABLE_TASK_SCHEDULER_CONNECTION_STRING", connectionString },
+ })
+ .Build();
+
+ var factory = new AzureManagedScalabilityProviderFactory(
+ configuration,
+ loggerFactory);
+
+ var triggerMetadata = TestHelpers.CreateTriggerMetadata(taskHubName, 2, 2, "DURABLE_TASK_SCHEDULER_CONNECTION_STRING", "azureManaged");
+ var metadata = triggerMetadata.ExtractDurableTaskMetadata();
+
+ var provider = factory.GetScalabilityProvider(metadata, triggerMetadata);
+ Assert.True(provider is AzureManagedScalabilityProvider, "Expected AzureManagedScalabilityProvider from factory.");
+
+ bool targetScalerCreated = provider.TryGetTargetScaler(
+ "functionId",
+ "TestFunction",
+ taskHubName,
+ "DURABLE_TASK_SCHEDULER_CONNECTION_STRING",
+ out ITargetScaler targetScaler);
+
+ Assert.True(targetScalerCreated);
+ Assert.NotNull(targetScaler);
+ Assert.IsType(targetScaler);
+
+ var verifyResult = await service.GetOrchestrationWithQueryAsync(new OrchestrationQuery { RuntimeStatus = status }, CancellationToken.None);
+ this.output.WriteLine($"Found {verifyResult.OrchestrationState?.Count ?? 0} orchestrations via query");
+
+ await Task.Delay(3000);
+
+ TargetScalerResult scalerResult = await targetScaler.GetScaleResultAsync(new TargetScalerContext());
+
+ Assert.NotNull(scalerResult);
+ this.output.WriteLine($"Target worker count: {scalerResult.TargetWorkerCount}");
+ Assert.Equal(10, scalerResult.TargetWorkerCount);
+ }
+ }
+}
diff --git a/test/ScaleTests/DurableTaskScaleConfigurationExtensionsTests.cs b/test/ScaleTests/DurableTaskScaleConfigurationExtensionsTests.cs
index 1e8aa38fd..cf3b3d570 100644
--- a/test/ScaleTests/DurableTaskScaleConfigurationExtensionsTests.cs
+++ b/test/ScaleTests/DurableTaskScaleConfigurationExtensionsTests.cs
@@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Linq;
+using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged;
using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureStorage;
using Microsoft.Azure.WebJobs.Host.Config;
using Microsoft.Extensions.Configuration;
@@ -20,10 +21,8 @@ namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.Tests
public class DurableTaskScaleConfigurationExtensionsTests
{
///
- /// Scenario: Core service registration in DI container.
/// Validates that AddDurableTask() registers IStorageServiceClientProviderFactory.
/// Validates that AddDurableTask() registers IScalabilityProviderFactory implementations.
- /// Tests the foundational setup required by Scale Controller integration.
/// Ensures Scale Controller can resolve storage clients and scalability providers.
///
[Fact]
@@ -52,6 +51,7 @@ public void AddDurableTask_RegistersRequiredServices()
var scalabilityProviderFactories = serviceProvider.GetServices().ToList();
Assert.NotEmpty(scalabilityProviderFactories);
Assert.Contains(scalabilityProviderFactories, f => f is AzureStorageScalabilityProviderFactory);
+ Assert.Contains(scalabilityProviderFactories, f => f is AzureManagedScalabilityProviderFactory);
}
}
diff --git a/test/ScaleTests/TestHelpers.cs b/test/ScaleTests/TestHelpers.cs
index 773546c6f..b2fc27a76 100644
--- a/test/ScaleTests/TestHelpers.cs
+++ b/test/ScaleTests/TestHelpers.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License. See LICENSE in the project root for license information.
using System;
-using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale;
using Microsoft.Azure.WebJobs.Host.Scale;
using Newtonsoft.Json.Linq;
@@ -30,9 +29,6 @@ public static string GetSqlConnectionString()
return sqlConnectionString;
}
- // If no environment variable is set, throw an exception to ensure tests verify that
- // the package correctly reads connection strings from configuration/environment variables.
- // This prevents tests from silently using a hardcoded default that doesn't match the actual environment.
throw new InvalidOperationException(
"SQL connection string not found in environment variables.");
}
@@ -46,9 +42,6 @@ public static string GetAzureManagedConnectionString()
return connectionString;
}
- // If no environment variable is set, throw an exception to ensure tests verify that
- // the package correctly reads connection strings from configuration/environment variables.
- // This prevents tests from silently using a hardcoded default that doesn't match the actual environment.
throw new InvalidOperationException(
"Azure Managed connection string not found in environment variables.");
}