diff --git a/.github/workflows/validate-build-scale.yml b/.github/workflows/validate-build-scale.yml index a56cde80b..239072275 100644 --- a/.github/workflows/validate-build-scale.yml +++ b/.github/workflows/validate-build-scale.yml @@ -1,4 +1,4 @@ -name: Functions Scale Tests Azure Storage +name: Functions Scale Tests permissions: contents: read @@ -18,7 +18,8 @@ env: AzureWebJobsStorage: UseDevelopmentStorage=true jobs: - build: + azure-storage: + name: Functions Scale Tests - Azure Storage runs-on: ubuntu-latest steps: @@ -79,4 +80,56 @@ jobs: done # Run tests - dotnet test ./test/ScaleTests/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.Tests.csproj -c $config --no-build --verbosity normal \ No newline at end of file + dotnet test ./test/ScaleTests/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.Tests.csproj -c $config --no-build --verbosity normal --filter "FullyQualifiedName!~AzureManaged" + + dts: + name: Functions Scale Tests - DTS + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + submodules: true + + - name: Setup .NET + uses: actions/setup-dotnet@v3 + with: + global-json-file: global.json + + - name: Set up Node.js (needed for Azurite) + uses: actions/setup-node@v3 + with: + node-version: '18.x' + + - name: Install Azurite + run: npm install -g azurite + + - name: Start Azurite + run: azurite --silent --blobPort 10000 --queuePort 10001 --tablePort 10002 & + + - name: Pull DTS Emulator Docker Image + run: docker pull mcr.microsoft.com/dts/dts-emulator:latest + + - name: Start DTS Container + run: | + docker run -i \ + -p 8080:8080 \ + -p 8082:8082 \ + -d mcr.microsoft.com/dts/dts-emulator:latest + + - name: Wait for DTS to be ready + run: | + echo "Waiting for DTS to be ready..." + sleep 30 + docker ps + + - name: Restore and Build Scale Tests + working-directory: test/ScaleTests + run: dotnet build --configuration Release + + - name: Run Azure Managed Scale Tests + working-directory: test/ScaleTests + env: + AzureWebJobsStorage: UseDevelopmentStorage=true + DURABLE_TASK_SCHEDULER_CONNECTION_STRING: "Endpoint=http://localhost:8080;Authentication=None" + run: dotnet test --configuration Release --no-build --verbosity normal --filter "FullyQualifiedName~AzureManaged" \ No newline at end of file diff --git a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProvider.cs b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProvider.cs new file mode 100644 index 000000000..fee65a31d --- /dev/null +++ b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProvider.cs @@ -0,0 +1,80 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +using System; +using Microsoft.Azure.WebJobs.Host.Scale; +using Microsoft.DurableTask.AzureManagedBackend; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged +{ + /// + /// The AzureManaged backend implementation of the scalability provider for Durable Functions. + /// + public class AzureManagedScalabilityProvider : ScalabilityProvider + { + private readonly AzureManagedOrchestrationService orchestrationService; + private readonly string connectionName; + private readonly ILogger logger; + + /// + /// Initializes a new instance of the class. + /// + /// + /// The instance that provides access to backend service for scaling operations. + /// + /// + /// The logical name of the storage or service connection associated with this provider. + /// + /// + /// The instance used for logging provider activities and diagnostics. + /// + /// + /// Thrown if is . + /// + public AzureManagedScalabilityProvider( + AzureManagedOrchestrationService orchestrationService, + string connectionName, + ILogger logger) + : base("AzureManaged", connectionName) + { + this.orchestrationService = orchestrationService ?? throw new ArgumentNullException(nameof(orchestrationService)); + this.connectionName = connectionName; + this.logger = logger; + } + + /// + /// Gets the app setting containing the Azure Managed connection string. + /// + public override string ConnectionName => this.connectionName; + + /// + /// This is not used. + public override bool TryGetScaleMonitor( + string functionId, + string functionName, + string hubName, + string targetConnectionName, + out IScaleMonitor scaleMonitor) + { + // Azure Managed backend does not support the legacy scale monitor infrastructure. + // Return false so that the scaling utilities can provide a no-op monitor. + scaleMonitor = null; + return false; + } + + /// + public override bool TryGetTargetScaler( + string functionId, + string functionName, + string hubName, + string targetConnectionName, + out ITargetScaler targetScaler) + { + // Create a target scaler that uses the orchestration service's metrics endpoint. + // All target scalers share the same AzureManagedOrchestrationService in the same task hub. + targetScaler = new AzureManagedTargetScaler(this.orchestrationService, functionId, this.logger); + return true; + } + } +} diff --git a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProviderFactory.cs b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProviderFactory.cs new file mode 100644 index 000000000..c32a6bc48 --- /dev/null +++ b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedScalabilityProviderFactory.cs @@ -0,0 +1,241 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +using System; +using System.Collections.Generic; +using Azure.Core; +using Azure.Identity; +using Microsoft.Azure.WebJobs.Host.Scale; +using Microsoft.DurableTask.AzureManagedBackend; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; + +#nullable enable +namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged +{ + /// + /// Factory class responsible for creating and managing instances of . + /// + public class AzureManagedScalabilityProviderFactory : IScalabilityProviderFactory + { + private const string LoggerName = "Triggers.DurableTask.AzureManaged"; + private const string ProviderName = "AzureManaged"; + + private readonly Dictionary<(string, string?, string?), AzureManagedScalabilityProvider> cachedProviders = new Dictionary<(string, string?, string?), AzureManagedScalabilityProvider>(); + private readonly IConfiguration configuration; + private readonly ILoggerFactory loggerFactory; + private readonly ILogger logger; + + /// + /// Initializes a new instance of the class. + /// + /// + /// The interface used to resolve connection strings and application settings. + /// + /// + /// The used to create loggers for diagnostics. + /// + /// + /// Thrown if any required argument is . + /// + public AzureManagedScalabilityProviderFactory( + IConfiguration configuration, + ILoggerFactory loggerFactory) + { + this.configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); + this.loggerFactory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory)); + this.logger = this.loggerFactory.CreateLogger(LoggerName); + + this.DefaultConnectionName = "DURABLE_TASK_SCHEDULER_CONNECTION_STRING"; + } + + /// + /// Gets the logical name of this scalability provider type. + /// + public virtual string Name => ProviderName; + + /// + /// Gets the default connection name configured for this factory. + /// + public string DefaultConnectionName { get; } + + /// + /// Creates or retrieves an instance based on the provided pre-deserialized metadata. + /// + /// The pre-deserialized Durable Task metadata. + /// Trigger metadata used to access Properties like token credentials. + /// + /// An instance configured using + /// the specified metadata and resolved connection information. + /// + /// + /// Thrown if no valid connection string could be resolved for the given connection name. + /// + public ScalabilityProvider GetScalabilityProvider(DurableTaskMetadata metadata, TriggerMetadata? triggerMetadata) + { + if (metadata != null) + { + this.ValidateMetadata(metadata); + } + + // Get connection name from metadata, fallback to default + string? rawConnectionName = TriggerMetadataExtensions.ResolveConnectionName(metadata?.StorageProvider); + string connectionName = rawConnectionName ?? this.DefaultConnectionName; + this.logger.LogInformation("Using connection name '{ConnectionName}'", connectionName); + + // Look up connection string from configuration + string? connectionString = + this.configuration.GetConnectionString(connectionName) ?? + this.configuration[connectionName] ?? + Environment.GetEnvironmentVariable(connectionName); + + if (string.IsNullOrEmpty(connectionString)) + { + throw new InvalidOperationException( + $"No valid connection string found for '{connectionName}'. " + + $"Please ensure it is defined in app settings, connection strings, or environment variables."); + } + + AzureManagedConnectionString azureManagedConnectionString = new AzureManagedConnectionString(connectionString); + + // Extract task hub name from metadata + string? taskHubName = metadata?.TaskHubName ?? azureManagedConnectionString.TaskHubName; + + // Include client ID in cache key to handle managed identity changes + // Use the original connection name (rawConnectionName or default) for the cache key, not the connection string value + (string, string?, string?) cacheKey = (connectionName, taskHubName, azureManagedConnectionString.ClientId); + + this.logger.LogDebug( + "Getting durability provider for connection '{Connection}', task hub '{TaskHub}', and client ID '{ClientId}'...", + cacheKey.Item1, + cacheKey.Item2 ?? "null", + cacheKey.Item3 ?? "null"); + + lock (this.cachedProviders) + { + // If a provider has already been created for this connection name, task hub, and client ID, return it. + if (this.cachedProviders.TryGetValue(cacheKey, out AzureManagedScalabilityProvider? cachedProvider)) + { + this.logger.LogDebug( + "Returning cached durability provider for connection '{Connection}', task hub '{TaskHub}', and client ID '{ClientId}'", + cacheKey.Item1, + cacheKey.Item2, + cacheKey.Item3 ?? "null"); + return cachedProvider; + } + + // Create options from the connection string. + // For runtime-driven scaling, token credentials are loaded directly from the host. + AzureManagedOrchestrationServiceOptions options = + AzureManagedOrchestrationServiceOptions.FromConnectionString(connectionString); + + // If triggerMetadata is provided (from functions Scale Controller), try to get token credential from it. + if (triggerMetadata != null && triggerMetadata.Properties != null && + triggerMetadata.Properties.TryGetValue("GetAzureManagedTokenCredential", out object? tokenCredentialFunc)) + { + if (tokenCredentialFunc is Func getTokenCredential) + { + try + { + TokenCredential tokenCredential = getTokenCredential(connectionName); + + if (tokenCredential == null) + { + this.logger.LogWarning( + "Token credential retrieved from trigger metadata is null for connection '{Connection}'.", + connectionName); + } + else + { + // Override the credential from connection string + options.TokenCredential = tokenCredential; + this.logger.LogInformation("Retrieved token credential from trigger metadata for connection '{Connection}'", connectionName); + } + } + catch (OperationCanceledException ex) + { + // Expected scenario when the operation is cancelled; + // log and fall back to the connection string credential. + this.logger.LogWarning( + ex, + "Getting token credential from trigger metadata was canceled for connection '{Connection}'", + connectionName); + } + catch (AuthenticationFailedException ex) + { + // Authentication failures are expected in some environments; + // log and fall back to the connection string credential. + this.logger.LogWarning( + ex, + "Authentication failed while getting token credential from trigger metadata for connection '{Connection}'", + connectionName); + } + catch (Exception ex) + { + // Unexpected exception types. Fall back to use connection string. + this.logger.LogWarning( + ex, + "Unexpected error while getting token credential from trigger metadata for connection '{Connection}'", + connectionName); + } + } + else + { + this.logger.LogWarning( + "Token credential function pointer in trigger metadata is not of expected type for connection '{Connection}'", + connectionName); + } + } + else + { + this.logger.LogInformation( + "No trigger metadata provided or trigger metadata does not contain 'GetAzureManagedTokenCredential', " + + "using the token credential built from connection string for connection '{Connection}'.", connectionName); + } + + // Set task hub name if configured + if (!string.IsNullOrEmpty(taskHubName)) + { + options.TaskHubName = taskHubName; + } + + int defaultConcurrency = 10; + + // Extract max concurrent values from trigger metadata. + // If nothing is provided from TriggerMetadata, we use default value which is 10. + options.MaxConcurrentOrchestrationWorkItems = metadata?.MaxConcurrentOrchestratorFunctions ?? defaultConcurrency; + options.MaxConcurrentActivityWorkItems = metadata?.MaxConcurrentActivityFunctions ?? defaultConcurrency; + + this.logger.LogInformation( + "Creating durability provider for connection '{Connection}', task hub '{TaskHub}', and client ID '{ClientId}'...", + cacheKey.Item1, + cacheKey.Item2, + cacheKey.Item3 ?? "null"); + + AzureManagedOrchestrationService service = new AzureManagedOrchestrationService(options, this.loggerFactory); + AzureManagedScalabilityProvider provider = new AzureManagedScalabilityProvider(service, connectionName, this.logger); + + // Set max concurrent values from trigger metadata for provider. + // If nothing is provided, use default value which is 10. + provider.MaxConcurrentTaskOrchestrationWorkItems = metadata?.MaxConcurrentOrchestratorFunctions ?? defaultConcurrency; + provider.MaxConcurrentTaskActivityWorkItems = metadata?.MaxConcurrentActivityFunctions ?? defaultConcurrency; + + this.cachedProviders.Add(cacheKey, provider); + return provider; + } + } + + private void ValidateMetadata(DurableTaskMetadata metadata) + { + if (metadata.MaxConcurrentOrchestratorFunctions.HasValue && metadata.MaxConcurrentOrchestratorFunctions.Value <= 0) + { + throw new InvalidOperationException($"{nameof(metadata.MaxConcurrentOrchestratorFunctions)} must be a positive integer."); + } + + if (metadata.MaxConcurrentActivityFunctions.HasValue && metadata.MaxConcurrentActivityFunctions.Value <= 0) + { + throw new InvalidOperationException($"{nameof(metadata.MaxConcurrentActivityFunctions)} must be a positive integer."); + } + } + } +} diff --git a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedTargetScaler.cs b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedTargetScaler.cs new file mode 100644 index 000000000..7ce9a4d63 --- /dev/null +++ b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/AzureManaged/AzureManagedTargetScaler.cs @@ -0,0 +1,68 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Host.Scale; +using Microsoft.DurableTask.AzureManagedBackend; +using Microsoft.DurableTask.AzureManagedBackend.Metrics; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged +{ + internal class AzureManagedTargetScaler : ITargetScaler + { + private readonly AzureManagedOrchestrationService service; + private readonly TargetScalerDescriptor descriptor; + private readonly ILogger logger; + + public AzureManagedTargetScaler(AzureManagedOrchestrationService service, string functionId, ILogger logger) + { + this.service = service; + this.descriptor = new TargetScalerDescriptor(functionId); + this.logger = logger; + } + + public TargetScalerDescriptor TargetScalerDescriptor => this.descriptor; + + public async Task GetScaleResultAsync(TargetScalerContext context) + { + TaskHubMetrics metrics = await this.service.GetTaskHubMetricsAsync(default); + if (metrics is null) + { + this.logger?.LogWarning("Task hub metrics returned null from Azure Managed backend. This may indicate the DTS emulator is being used which may not support metrics. Returning 0 worker count."); + return new TargetScalerResult { TargetWorkerCount = 0 }; + } + + static int GetTargetWorkerCount(WorkItemMetrics workItemMetrics, int workItemCapacity) + { + if (workItemCapacity == 0) + { + return 0; + } + + int totalWorkItemCount = workItemMetrics.PendingCount + workItemMetrics.ActiveCount; + return (int)Math.Ceiling((double)totalWorkItemCount / workItemCapacity); + } + + int targetForOrchestratorWorkItems = GetTargetWorkerCount( + workItemMetrics: metrics.OrchestratorWorkItems, + workItemCapacity: this.service.MaxConcurrentTaskOrchestrationWorkItems); + + int targetForActivityWorkItems = GetTargetWorkerCount( + workItemMetrics: metrics.ActivityWorkItems, + workItemCapacity: this.service.MaxConcurrentTaskActivityWorkItems); + + int targetForEntityWorkItems = GetTargetWorkerCount( + workItemMetrics: metrics.EntityWorkItems, + workItemCapacity: this.service.MaxConcurrentEntityWorkItems); + + // Scale out to the maximum of the above target values. + int maxTarget = Math.Max( + targetForOrchestratorWorkItems, + Math.Max(targetForActivityWorkItems, targetForEntityWorkItems)); + + return new TargetScalerResult { TargetWorkerCount = maxTarget }; + } + } +} diff --git a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/DurableTaskScaleConfigurationExtensions.cs b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/DurableTaskScaleConfigurationExtensions.cs index 4a94dd834..1150194a5 100644 --- a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/DurableTaskScaleConfigurationExtensions.cs +++ b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/DurableTaskScaleConfigurationExtensions.cs @@ -3,6 +3,7 @@ using System; using System.Linq; +using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged; using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureStorage; using Microsoft.Azure.WebJobs.Host.Scale; using Microsoft.Extensions.Configuration; @@ -43,6 +44,8 @@ public static Microsoft.Azure.WebJobs.IWebJobsBuilder AddDurableTask(this IWebJo serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + return builder; } diff --git a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.csproj b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.csproj index bb7f4f201..6b11b1d72 100644 --- a/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.csproj +++ b/src/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale/Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.csproj @@ -51,6 +51,7 @@ + diff --git a/test/ScaleTests/AzureManaged/AzureManagedScalabilityProviderFactoryTests.cs b/test/ScaleTests/AzureManaged/AzureManagedScalabilityProviderFactoryTests.cs new file mode 100644 index 000000000..e2b556ca0 --- /dev/null +++ b/test/ScaleTests/AzureManaged/AzureManagedScalabilityProviderFactoryTests.cs @@ -0,0 +1,183 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged; +using Microsoft.Azure.WebJobs.Host.Scale; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json.Linq; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.Tests +{ + public class AzureManagedScalabilityProviderFactoryTests + { + private readonly ITestOutputHelper output; + private readonly ILoggerFactory loggerFactory; + private readonly IConfiguration configuration; + + public AzureManagedScalabilityProviderFactoryTests(ITestOutputHelper output) + { + this.output = output; + this.loggerFactory = new LoggerFactory(); + this.loggerFactory.AddProvider(new TestLoggerProvider(output)); + + this.configuration = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary + { + { "v3-dtsConnectionMI", "Endpoint=https://test.westus.durabletask.io;Authentication=DefaultAzure" }, + { "DURABLE_TASK_SCHEDULER_CONNECTION_STRING", "Endpoint=https://default.westus.durabletask.io;Authentication=DefaultAzure" }, + }) + .Build(); + } + + /// + /// Validates that the factory can be instantiated with valid parameters, + /// reports the correct provider name, and exposes the expected default connection name. + /// + [Fact] + public void Constructor_ValidParameters_CreatesInstance() + { + var factory = new AzureManagedScalabilityProviderFactory( + this.configuration, + this.loggerFactory); + + Assert.NotNull(factory); + Assert.Equal("AzureManaged", factory.Name); + Assert.Equal("DURABLE_TASK_SCHEDULER_CONNECTION_STRING", factory.DefaultConnectionName); + } + + /// + /// KEY SCENARIO: Scale Controller sends trigger metadata with storageProvider.type = "azureManaged". + /// Validates that the factory returns an AzureManagedScalabilityProvider with the correct + /// connection name and concurrency limits taken from the trigger metadata. + /// + [Fact] + public void GetScalabilityProvider_WithTriggerMetadata_ReturnsAzureManagedProvider() + { + var factory = new AzureManagedScalabilityProviderFactory( + this.configuration, + this.loggerFactory); + + var triggerMetadata = TestHelpers.CreateTriggerMetadata("testHub", 15, 25, "v3-dtsConnectionMI", "azureManaged"); + var metadata = triggerMetadata.ExtractDurableTaskMetadata(); + + var provider = factory.GetScalabilityProvider(metadata, triggerMetadata); + + Assert.NotNull(provider); + Assert.IsType(provider); + var azureProvider = (AzureManagedScalabilityProvider)provider; + + // The provider connection name should be same as what we set at metadata. + Assert.Equal("v3-dtsConnectionMI", azureProvider.ConnectionName); + Assert.Equal(15, azureProvider.MaxConcurrentTaskOrchestrationWorkItems); + Assert.Equal(25, azureProvider.MaxConcurrentTaskActivityWorkItems); + } + + /// + /// Validates that when trigger metadata does not include a connectionName in storageProvider, + /// the factory falls back to the default connection name DURABLE_TASK_SCHEDULER_CONNECTION_STRING. + /// + [Fact] + public void GetScalabilityProvider_WithNoConnectionNameInMetadata_UsesDefaultConnectionName() + { + var factory = new AzureManagedScalabilityProviderFactory( + this.configuration, + this.loggerFactory); + + var jobj = new JObject + { + { "functionName", "TestFunction" }, + { "taskHubName", "testHub" }, + { "storageProvider", new JObject { { "type", "azureManaged" } } }, + }; + var triggerMetadata = new TriggerMetadata(jobj); + var metadata = triggerMetadata.ExtractDurableTaskMetadata(); + + var provider = factory.GetScalabilityProvider(metadata, triggerMetadata); + + Assert.NotNull(provider); + Assert.Equal("DURABLE_TASK_SCHEDULER_CONNECTION_STRING", provider.ConnectionName); + } + + /// + /// Validates that when the specified connection string is absent from configuration, + /// the factory throws an InvalidOperationException rather than silently continuing. + /// + [Fact] + public void GetScalabilityProvider_MissingConnectionString_ThrowsInvalidOperationException() + { + var emptyConfig = new ConfigurationBuilder().Build(); + var factory = new AzureManagedScalabilityProviderFactory( + emptyConfig, + this.loggerFactory); + + var triggerMetadata = TestHelpers.CreateTriggerMetadata("testHub", 5, 10, "MISSING_CONNECTION", "azureManaged"); + var metadata = triggerMetadata.ExtractDurableTaskMetadata(); + + Assert.Throws(() => factory.GetScalabilityProvider(metadata, triggerMetadata)); + } + + /// + /// Validates that calling GetScalabilityProvider twice with the same connection name and task hub + /// returns the same cached provider instance, avoiding redundant connections per scale decision. + /// + [Fact] + public void GetScalabilityProvider_SameParameters_ReturnsCachedInstance() + { + var factory = new AzureManagedScalabilityProviderFactory( + this.configuration, + this.loggerFactory); + + var triggerMetadata = TestHelpers.CreateTriggerMetadata("testHub", 5, 10, "v3-dtsConnectionMI", "azureManaged"); + var metadata = triggerMetadata.ExtractDurableTaskMetadata(); + + var provider1 = factory.GetScalabilityProvider(metadata, triggerMetadata); + var provider2 = factory.GetScalabilityProvider(metadata, triggerMetadata); + + Assert.Same(provider1, provider2); + } + + /// + /// KEY SCENARIO: When multiple provider factories are registered and trigger metadata specifies + /// storageProvider.type = "azureManaged", DurableTaskScaleExtension.GetScalabilityProviderFactory + /// must select the AzureManagedScalabilityProviderFactory. + /// + [Fact] + public void GetScalabilityProviderFactory_WhenMetadataTypeIsAzureManaged_SelectsAzureManagedFactory() + { + var azureManagedFactory = new AzureManagedScalabilityProviderFactory( + this.configuration, + this.loggerFactory); + + // Create a list of mock Azure Storage provider factories to simulate multiple provider factories. + IScalabilityProviderFactory[] factories = new IScalabilityProviderFactory[] + { + new StubAzureStorageFactory(), + azureManagedFactory, + }; + + var triggerMetadata = TestHelpers.CreateTriggerMetadata("testHub", 5, 10, "DURABLE_TASK_SCHEDULER_CONNECTION_STRING", "azureManaged"); + var metadata = triggerMetadata.ExtractDurableTaskMetadata(); + + var logger = this.loggerFactory.CreateLogger("test"); + var selectedFactory = DurableTaskScaleExtension.GetScalabilityProviderFactory(metadata, logger, factories); + + Assert.IsType(selectedFactory); + } + + // Stub used to test factory selection without a real storage emulator. + private class StubAzureStorageFactory : IScalabilityProviderFactory + { + public string Name => "AzureStorage"; + + public string DefaultConnectionName => "AzureWebJobsStorage"; + + public ScalabilityProvider GetScalabilityProvider(DurableTaskMetadata metadata, TriggerMetadata triggerMetadata) + => throw new NotImplementedException(); + } + } +} diff --git a/test/ScaleTests/AzureManaged/AzureManagedTargetScalerTests.cs b/test/ScaleTests/AzureManaged/AzureManagedTargetScalerTests.cs new file mode 100644 index 000000000..89443bf86 --- /dev/null +++ b/test/ScaleTests/AzureManaged/AzureManagedTargetScalerTests.cs @@ -0,0 +1,133 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using DurableTask.Core; +using DurableTask.Core.History; +using DurableTask.Core.Query; +using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged; +using Microsoft.Azure.WebJobs.Host.Scale; +using Microsoft.DurableTask.AzureManagedBackend; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.Tests +{ + /// + /// Validates that the target-based autoscaling mechanism produces correct worker counts + /// based on pending/active work item metrics from the Azure Managed backend. + /// + public class AzureManagedTargetScalerTests + { + private readonly ITestOutputHelper output; + + public AzureManagedTargetScalerTests(ITestOutputHelper output) + { + this.output = output; + } + + /// + /// Target scaler calculates correct worker count based on pending orchestrations. + /// Validates that with 20 pending orchestrations and MaxConcurrentOrchestrators=2, + /// the scaler returns 10 workers (20/2 = 10). + /// + [Fact] + public async Task TargetBasedScaling_WithPendingOrchestrations_ReturnsExpectedWorkerCount() + { + var taskHubName = "default"; + var connectionString = TestHelpers.GetAzureManagedConnectionString(); + var options = AzureManagedOrchestrationServiceOptions.FromConnectionString(connectionString); + options.TaskHubName = taskHubName; + options.MaxConcurrentOrchestrationWorkItems = 2; + options.MaxConcurrentActivityWorkItems = 2; + + this.output.WriteLine($"Creating connection to the test DTS TaskHub: {taskHubName}"); + + var loggerFactory = new LoggerFactory(); + using var service = new AzureManagedOrchestrationService(options, loggerFactory); + + var status = new List + { + OrchestrationStatus.Pending, + OrchestrationStatus.Running, + OrchestrationStatus.Suspended, + }; + + var query = new OrchestrationQuery { RuntimeStatus = status }; + var result = await service.GetOrchestrationWithQueryAsync(query, CancellationToken.None); + + int existingCount = result.OrchestrationState?.Count ?? 0; + int orchestrationsToCreate = Math.Max(0, 20 - existingCount); + + this.output.WriteLine($"Found {existingCount} existing orchestrations. Creating {orchestrationsToCreate} new ones."); + + for (int i = 0; i < orchestrationsToCreate; i++) + { + var instance = new OrchestrationInstance + { + InstanceId = $"TestOrchestration_{Guid.NewGuid():N}", + ExecutionId = Guid.NewGuid().ToString(), + }; + + await service.CreateTaskOrchestrationAsync( + new TaskMessage + { + OrchestrationInstance = instance, + Event = new ExecutionStartedEvent(-1, "TestInput") + { + OrchestrationInstance = instance, + Name = "TestOrchestration", + Version = "1.0", + Input = "TestInput", + }, + }); + } + + await Task.Delay(2000); + + var configuration = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary + { + { "DURABLE_TASK_SCHEDULER_CONNECTION_STRING", connectionString }, + }) + .Build(); + + var factory = new AzureManagedScalabilityProviderFactory( + configuration, + loggerFactory); + + var triggerMetadata = TestHelpers.CreateTriggerMetadata(taskHubName, 2, 2, "DURABLE_TASK_SCHEDULER_CONNECTION_STRING", "azureManaged"); + var metadata = triggerMetadata.ExtractDurableTaskMetadata(); + + var provider = factory.GetScalabilityProvider(metadata, triggerMetadata); + Assert.True(provider is AzureManagedScalabilityProvider, "Expected AzureManagedScalabilityProvider from factory."); + + bool targetScalerCreated = provider.TryGetTargetScaler( + "functionId", + "TestFunction", + taskHubName, + "DURABLE_TASK_SCHEDULER_CONNECTION_STRING", + out ITargetScaler targetScaler); + + Assert.True(targetScalerCreated); + Assert.NotNull(targetScaler); + Assert.IsType(targetScaler); + + var verifyResult = await service.GetOrchestrationWithQueryAsync(new OrchestrationQuery { RuntimeStatus = status }, CancellationToken.None); + this.output.WriteLine($"Found {verifyResult.OrchestrationState?.Count ?? 0} orchestrations via query"); + + await Task.Delay(3000); + + TargetScalerResult scalerResult = await targetScaler.GetScaleResultAsync(new TargetScalerContext()); + + Assert.NotNull(scalerResult); + this.output.WriteLine($"Target worker count: {scalerResult.TargetWorkerCount}"); + Assert.Equal(10, scalerResult.TargetWorkerCount); + } + } +} diff --git a/test/ScaleTests/DurableTaskScaleConfigurationExtensionsTests.cs b/test/ScaleTests/DurableTaskScaleConfigurationExtensionsTests.cs index 1e8aa38fd..cf3b3d570 100644 --- a/test/ScaleTests/DurableTaskScaleConfigurationExtensionsTests.cs +++ b/test/ScaleTests/DurableTaskScaleConfigurationExtensionsTests.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Linq; +using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureManaged; using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.AzureStorage; using Microsoft.Azure.WebJobs.Host.Config; using Microsoft.Extensions.Configuration; @@ -20,10 +21,8 @@ namespace Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale.Tests public class DurableTaskScaleConfigurationExtensionsTests { /// - /// Scenario: Core service registration in DI container. /// Validates that AddDurableTask() registers IStorageServiceClientProviderFactory. /// Validates that AddDurableTask() registers IScalabilityProviderFactory implementations. - /// Tests the foundational setup required by Scale Controller integration. /// Ensures Scale Controller can resolve storage clients and scalability providers. /// [Fact] @@ -52,6 +51,7 @@ public void AddDurableTask_RegistersRequiredServices() var scalabilityProviderFactories = serviceProvider.GetServices().ToList(); Assert.NotEmpty(scalabilityProviderFactories); Assert.Contains(scalabilityProviderFactories, f => f is AzureStorageScalabilityProviderFactory); + Assert.Contains(scalabilityProviderFactories, f => f is AzureManagedScalabilityProviderFactory); } } diff --git a/test/ScaleTests/TestHelpers.cs b/test/ScaleTests/TestHelpers.cs index 773546c6f..b2fc27a76 100644 --- a/test/ScaleTests/TestHelpers.cs +++ b/test/ScaleTests/TestHelpers.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. See LICENSE in the project root for license information. using System; -using Microsoft.Azure.WebJobs.Extensions.DurableTask.FunctionsScale; using Microsoft.Azure.WebJobs.Host.Scale; using Newtonsoft.Json.Linq; @@ -30,9 +29,6 @@ public static string GetSqlConnectionString() return sqlConnectionString; } - // If no environment variable is set, throw an exception to ensure tests verify that - // the package correctly reads connection strings from configuration/environment variables. - // This prevents tests from silently using a hardcoded default that doesn't match the actual environment. throw new InvalidOperationException( "SQL connection string not found in environment variables."); } @@ -46,9 +42,6 @@ public static string GetAzureManagedConnectionString() return connectionString; } - // If no environment variable is set, throw an exception to ensure tests verify that - // the package correctly reads connection strings from configuration/environment variables. - // This prevents tests from silently using a hardcoded default that doesn't match the actual environment. throw new InvalidOperationException( "Azure Managed connection string not found in environment variables."); }