From 8bba9c4b8020a38a2872b0831fbc09059696515e Mon Sep 17 00:00:00 2001 From: Lilian Kasem Date: Fri, 18 Jul 2025 11:32:57 -0700 Subject: [PATCH] draft cancellation flows to PendingItem --- .../Channel/GrpcWorkerChannel.cs | 67 ++++++++++++++----- .../Host/WorkerFunctionMetadataProvider.cs | 13 ++-- .../RpcFunctionInvocationDispatcher.cs | 4 +- .../Rpc/IWebHostRpcWorkerChannelManager.cs | 3 +- .../Rpc/WebHostRpcWorkerChannelManager.cs | 13 ++-- .../WorkerFunctionMetadataProviderTests.cs | 13 +++- .../Rpc/RpcInitializationServiceTests.cs | 20 +++--- .../Rpc/TestRpcWorkerChannelManager.cs | 5 +- 8 files changed, 94 insertions(+), 44 deletions(-) diff --git a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs index 7b897bfea0..8d6ef017b9 100644 --- a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs +++ b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs @@ -267,7 +267,13 @@ private void ProcessRegisteredGrpcCallbacks(InboundGrpcEvent message) next.SetResult(message); } - private void RegisterCallbackForNextGrpcMessage(MsgType messageType, TimeSpan timeout, int count, Action callback, Action faultHandler) + private void RegisterCallbackForNextGrpcMessage( + MsgType messageType, + TimeSpan timeout, + int count, + Action callback, + Action faultHandler, + CancellationToken cancellationToken = default) { Queue queue; lock (_pendingActions) @@ -289,8 +295,8 @@ private void RegisterCallbackForNextGrpcMessage(MsgType messageType, TimeSpan ti for (int i = 0; i < count; i++) { var newItem = (i == count - 1) && (timeout != TimeSpan.Zero) - ? new PendingItem(callback, faultHandler, timeout) - : new PendingItem(callback, faultHandler); + ? new PendingItem(callback, faultHandler, timeout, cancellationToken) + : new PendingItem(callback, faultHandler, cancellationToken); queue.Enqueue(newItem); } } @@ -371,8 +377,16 @@ public bool IsChannelReadyForInvocations() public async Task StartWorkerProcessAsync(CancellationToken cancellationToken) { - RegisterCallbackForNextGrpcMessage(MsgType.StartStream, _workerConfig.CountOptions.ProcessStartupTimeout, 1, SendWorkerInitRequest, HandleWorkerStartStreamError); - // note: it is important that the ^^^ StartStream is in place *before* we start process the loop, otherwise we get a race condition + cancellationToken.ThrowIfCancellationRequested(); + + RegisterCallbackForNextGrpcMessage( + MsgType.StartStream, + _workerConfig.CountOptions.ProcessStartupTimeout, + 1, + grpcEvent => SendWorkerInitRequest(grpcEvent, cancellationToken), + HandleWorkerStartStreamError, + cancellationToken); + // Note: it is important that the ^^^ StartStream is in place *before* we start process the loop, otherwise we get a race condition _ = ProcessInbound(); _workerChannelLogger.LogDebug("Initiating Worker Process start up"); @@ -418,10 +432,12 @@ public async Task GetWorkerStatusAsync() } // send capabilities to worker, wait for WorkerInitResponse - internal void SendWorkerInitRequest(GrpcEvent startEvent) + internal void SendWorkerInitRequest(GrpcEvent startEvent, CancellationToken cancellationToken = default) { + cancellationToken.ThrowIfCancellationRequested(); + _workerChannelLogger.LogDebug("Worker Process started. Received StartStream message"); - RegisterCallbackForNextGrpcMessage(MsgType.WorkerInitResponse, _workerConfig.CountOptions.InitializationTimeout, 1, WorkerInitResponse, HandleWorkerInitError); + RegisterCallbackForNextGrpcMessage(MsgType.WorkerInitResponse, _workerConfig.CountOptions.InitializationTimeout, 1, WorkerInitResponse, HandleWorkerInitError, cancellationToken); WorkerInitRequest initRequest = GetWorkerInitRequest(); @@ -949,7 +965,7 @@ internal Task> SendFunctionMetadataRequest() if (!_functionMetadataRequestSent) { RegisterCallbackForNextGrpcMessage(MsgType.FunctionMetadataResponse, _functionLoadTimeout, 1, - msg => ProcessFunctionMetadataResponses(msg.Message.FunctionMetadataResponse), HandleWorkerMetadataRequestError); + msg => ProcessFunctionMetadataResponses(msg.Message.FunctionMetadataResponse), HandleWorkerMetadataRequestError); _workerChannelLogger.LogDebug("Sending WorkerMetadataRequest to {language} worker with worker ID {workerID}", _runtime, _workerId); @@ -1749,21 +1765,25 @@ private sealed class PendingItem { private readonly Action _callback; private readonly Action _faultHandler; - private CancellationTokenRegistration _ctr; + private CancellationTokenRegistration _timeoutRegistration; + private CancellationTokenRegistration _cancellationRegistration; private int _state; - public PendingItem(Action callback, Action faultHandler) + public PendingItem(Action callback, Action faultHandler, CancellationToken cancellationToken = default) { _callback = callback; _faultHandler = faultHandler; + + // Register for host shutdown + _cancellationRegistration = cancellationToken.Register(static state => ((PendingItem)state).OnCanceled(), this); } - public PendingItem(Action callback, Action faultHandler, TimeSpan timeout) - : this(callback, faultHandler) + public PendingItem(Action callback, Action faultHandler, TimeSpan timeout, CancellationToken cancellationToken = default) + : this(callback, faultHandler, cancellationToken) { var cts = new CancellationTokenSource(); cts.CancelAfter(timeout); - _ctr = cts.Token.Register(static state => ((PendingItem)state).OnTimeout(), this); + _timeoutRegistration = cts.Token.Register(static state => ((PendingItem)state).OnTimeout(), this); } public bool IsComplete => Volatile.Read(ref _state) != 0; @@ -1772,8 +1792,11 @@ public PendingItem(Action callback, Action faultHan public void SetResult(InboundGrpcEvent message) { - _ctr.Dispose(); - _ctr = default; + _timeoutRegistration.Dispose(); + _cancellationRegistration.Dispose(); + _timeoutRegistration = default; + _cancellationRegistration = default; + if (MakeComplete() && _callback != null) { try @@ -1813,6 +1836,20 @@ private void OnTimeout() } } } + + private void OnCanceled() + { + if (MakeComplete() && _faultHandler != null) + { + try + { + _faultHandler(new OperationCanceledException()); + } + catch + { + } + } + } } } } \ No newline at end of file diff --git a/src/WebJobs.Script/Host/WorkerFunctionMetadataProvider.cs b/src/WebJobs.Script/Host/WorkerFunctionMetadataProvider.cs index 22c2c7a0b3..b6a582fc67 100644 --- a/src/WebJobs.Script/Host/WorkerFunctionMetadataProvider.cs +++ b/src/WebJobs.Script/Host/WorkerFunctionMetadataProvider.cs @@ -29,6 +29,7 @@ internal class WorkerFunctionMetadataProvider : IWorkerFunctionMetadataProvider, private readonly IEnvironment _environment; private readonly IWebHostRpcWorkerChannelManager _channelManager; private readonly IScriptHostManager _scriptHostManager; + private readonly IHostApplicationLifetime _applicationLifetime; private string _workerRuntime; private ImmutableArray _functions; private IHost _currentJobHost = null; @@ -38,7 +39,8 @@ public WorkerFunctionMetadataProvider( ILogger logger, IEnvironment environment, IWebHostRpcWorkerChannelManager webHostRpcWorkerChannelManager, - IScriptHostManager scriptHostManager) + IScriptHostManager scriptHostManager, + IHostApplicationLifetime applicationLifetime) { _scriptOptions = scriptOptions; _logger = logger; @@ -46,6 +48,7 @@ public WorkerFunctionMetadataProvider( _channelManager = webHostRpcWorkerChannelManager; _scriptHostManager = scriptHostManager; _workerRuntime = _environment.GetEnvironmentVariable(EnvironmentSettingNames.FunctionWorkerRuntime); + _applicationLifetime = applicationLifetime; _scriptHostManager.ActiveHostChanged += OnHostChanged; } @@ -89,7 +92,7 @@ public async Task GetFunctionMetadataAsync(IEnumerable(); - - if (lifetime is not null && - !lifetime.ApplicationStarted.IsCancellationRequested) + if (_applicationLifetime is not null && + !_applicationLifetime.ApplicationStarted.IsCancellationRequested) { return true; } diff --git a/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs b/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs index 5025c58591..fc804f68ce 100644 --- a/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs +++ b/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs @@ -158,13 +158,13 @@ internal async Task InitializeJobhostLanguageWorkerChannelAsync(IEnumerable await InitializeJobhostLanguageWorkerChannelAsync(attemptCount, new[] { language }); - internal async Task InitializeJobhostLanguageWorkerChannelAsync(int attemptCount, IEnumerable languages) + internal async Task InitializeJobhostLanguageWorkerChannelAsync(int attemptCount, IEnumerable languages, CancellationToken cancellationToken = default) { foreach (string language in languages) { var rpcWorkerChannel = _rpcWorkerChannelFactory.Create(_scriptOptions.RootScriptPath, language, _metricsLogger, attemptCount, _workerConfigs); _jobHostLanguageWorkerChannelManager.AddChannel(rpcWorkerChannel, language); - await rpcWorkerChannel.StartWorkerProcessAsync(); + await rpcWorkerChannel.StartWorkerProcessAsync(cancellationToken); _logger.LogDebug("Adding jobhost language worker channel for runtime: {language}. workerId:{id}", language, rpcWorkerChannel.Id); // if the worker is indexing, we will not have function metadata yet. So, we cannot set up invocation buffers or send load requests diff --git a/src/WebJobs.Script/Workers/Rpc/IWebHostRpcWorkerChannelManager.cs b/src/WebJobs.Script/Workers/Rpc/IWebHostRpcWorkerChannelManager.cs index 9df04d3d62..2b27d46c0d 100644 --- a/src/WebJobs.Script/Workers/Rpc/IWebHostRpcWorkerChannelManager.cs +++ b/src/WebJobs.Script/Workers/Rpc/IWebHostRpcWorkerChannelManager.cs @@ -3,13 +3,14 @@ using System; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; namespace Microsoft.Azure.WebJobs.Script.Workers.Rpc { public interface IWebHostRpcWorkerChannelManager { - Task InitializeChannelAsync(IEnumerable workerConfigs, string language); + Task InitializeChannelAsync(IEnumerable workerConfigs, string language, CancellationToken cancellationToken = default); IDictionary> GetChannels(string language); diff --git a/src/WebJobs.Script/Workers/Rpc/WebHostRpcWorkerChannelManager.cs b/src/WebJobs.Script/Workers/Rpc/WebHostRpcWorkerChannelManager.cs index 3dd214d343..0154b0d918 100644 --- a/src/WebJobs.Script/Workers/Rpc/WebHostRpcWorkerChannelManager.cs +++ b/src/WebJobs.Script/Workers/Rpc/WebHostRpcWorkerChannelManager.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Reactive.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.AppService.Proxy.Common.Infra; using Microsoft.Azure.WebJobs.Script.Config; @@ -72,22 +73,24 @@ public WebHostRpcWorkerChannelManager(IScriptEventManager eventManager, _shutdownStandbyWorkerChannels = _shutdownStandbyWorkerChannels.Debounce(milliseconds: 5000); } - public Task InitializeChannelAsync(IEnumerable workerConfigs, string runtime) + public Task InitializeChannelAsync(IEnumerable workerConfigs, string runtime, CancellationToken cancellationToken = default) { _logger?.LogDebug("Initializing language worker channel for runtime:{runtime}", runtime); - return InitializeLanguageWorkerChannel(workerConfigs, runtime, _applicationHostOptions.CurrentValue.ScriptPath); + return InitializeLanguageWorkerChannel(workerConfigs, runtime, _applicationHostOptions.CurrentValue.ScriptPath, cancellationToken); } - internal async Task InitializeLanguageWorkerChannel(IEnumerable workerConfigs, string runtime, string scriptRootPath) + internal async Task InitializeLanguageWorkerChannel(IEnumerable workerConfigs, string runtime, string scriptRootPath, CancellationToken cancellationToken = default) { + cancellationToken.ThrowIfCancellationRequested(); + IRpcWorkerChannel rpcWorkerChannel = null; string workerId = Guid.NewGuid().ToString(); - _logger.LogDebug("Creating language worker channel for runtime:{runtime}", runtime); + _logger.LogWarning("Creating language worker channel for runtime:{runtime}", runtime); try { rpcWorkerChannel = _rpcWorkerChannelFactory.Create(scriptRootPath, runtime, _metricsLogger, 0, workerConfigs); AddOrUpdateWorkerChannels(runtime, rpcWorkerChannel); - await rpcWorkerChannel.StartWorkerProcessAsync().ContinueWith(processStartTask => + await rpcWorkerChannel.StartWorkerProcessAsync(cancellationToken).ContinueWith(processStartTask => { if (processStartTask.Status == TaskStatus.RanToCompletion) { diff --git a/test/WebJobs.Script.Tests/WorkerFunctionMetadataProviderTests.cs b/test/WebJobs.Script.Tests/WorkerFunctionMetadataProviderTests.cs index 7fc0737a8e..29dfc09e38 100644 --- a/test/WebJobs.Script.Tests/WorkerFunctionMetadataProviderTests.cs +++ b/test/WebJobs.Script.Tests/WorkerFunctionMetadataProviderTests.cs @@ -9,6 +9,7 @@ using System.Threading.Tasks; using Microsoft.Azure.WebJobs.Script.Description; using Microsoft.Azure.WebJobs.Script.Workers.Rpc; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; @@ -28,13 +29,15 @@ public WorkerFunctionMetadataProviderTests() var mockEnvironment = new Mock(); var mockChannelManager = new Mock(); var mockScriptHostManager = new Mock(); + var mockLifetime = new Mock(); _workerFunctionMetadataProvider = new WorkerFunctionMetadataProvider( mockScriptOptions.Object, mockLogger.Object, mockEnvironment.Object, mockChannelManager.Object, - mockScriptHostManager.Object); + mockScriptHostManager.Object, + mockLifetime.Object); } [Fact] @@ -188,6 +191,8 @@ public async void ValidateFunctionMetadata_Logging() var mockScriptHostManager = new Mock(); mockScriptHostManager.Setup(m => m.State).Returns(ScriptHostState.Running); + var mockLifetime = new Mock(); + var mockWebHostRpcWorkerChannelManager = new Mock(); mockWebHostRpcWorkerChannelManager.Setup(m => m.GetChannels(It.IsAny())).Returns(() => new Dictionary> { @@ -197,7 +202,7 @@ public async void ValidateFunctionMetadata_Logging() environment.SetEnvironmentVariable(EnvironmentSettingNames.FunctionWorkerRuntime, "node"); var workerFunctionMetadataProvider = new WorkerFunctionMetadataProvider(optionsMonitor, logger, SystemEnvironment.Instance, - mockWebHostRpcWorkerChannelManager.Object, mockScriptHostManager.Object); + mockWebHostRpcWorkerChannelManager.Object, mockScriptHostManager.Object, mockLifetime.Object); await workerFunctionMetadataProvider.GetFunctionMetadataAsync(workerConfigs, false); var traces = logger.GetLogMessages(); @@ -230,6 +235,7 @@ public async Task GetFunctionMetadataAsync_Idempotent() var mockChannelManager = new Mock(MockBehavior.Strict); var mockScriptHostManager = new Mock(MockBehavior.Strict); var mockOptionsMonitor = new Mock>(MockBehavior.Strict); + var mockLifetime = new Mock(MockBehavior.Strict); var scriptOptions = new ScriptApplicationHostOptions { IsFileSystemReadOnly = true @@ -265,7 +271,8 @@ public async Task GetFunctionMetadataAsync_Idempotent() NullLogger.Instance, testEnvironment, mockChannelManager.Object, - mockScriptHostManager.Object); + mockScriptHostManager.Object, + mockLifetime.Object); var workerConfigs = new List(); diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/RpcInitializationServiceTests.cs b/test/WebJobs.Script.Tests/Workers/Rpc/RpcInitializationServiceTests.cs index da26a178c0..7adcbf7534 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/RpcInitializationServiceTests.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/RpcInitializationServiceTests.cs @@ -38,7 +38,7 @@ public RpcInitializationServiceTests() _workerOptionsMonitor = TestHelpers.CreateOptionsMonitor(TestHelpers.GetTestLanguageWorkerOptions()); IRpcWorkerChannel testLanguageWorkerChannel = new TestRpcWorkerChannel(Guid.NewGuid().ToString(), RpcWorkerConstants.NodeLanguageWorkerName); - _mockLanguageWorkerChannelManager.Setup(m => m.InitializeChannelAsync(It.IsAny>(), It.IsAny())) + _mockLanguageWorkerChannelManager.Setup(m => m.InitializeChannelAsync(It.IsAny>(), It.IsAny(), It.IsAny())) .Returns(Task.FromResult(testLanguageWorkerChannel)); } @@ -53,7 +53,7 @@ public async Task RpcInitializationService_AppOffline() offlineFilePath = TestHelpers.CreateOfflineFile(); _rpcInitializationService = new RpcInitializationService(_optionsMonitor, mockEnvironment.Object, testRpcServer, _mockLanguageWorkerChannelManager.Object, _logger, _workerOptionsMonitor); await _rpcInitializationService.StartAsync(CancellationToken.None); - _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.JavaLanguageWorkerName), Times.Never); + _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.JavaLanguageWorkerName, It.IsAny()), Times.Never); Assert.DoesNotContain("testserver", testRpcServer.Uri.ToString()); await testRpcServer.ShutdownAsync(); } @@ -71,7 +71,7 @@ public async Task RpcInitializationService_Initializes_RpcServerAndChannels_Work mockEnvironment.Setup(p => p.GetEnvironmentVariable(RpcWorkerConstants.FunctionWorkerRuntimeSettingName)).Returns(RpcWorkerConstants.NodeLanguageWorkerName); _rpcInitializationService = new RpcInitializationService(_optionsMonitor, mockEnvironment.Object, testRpcServer, _mockLanguageWorkerChannelManager.Object, _logger, _workerOptionsMonitor); await _rpcInitializationService.StartAsync(CancellationToken.None); - _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.NodeLanguageWorkerName), Times.Never); + _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.NodeLanguageWorkerName, It.IsAny()), Times.Never); Assert.Contains("testserver", testRpcServer.Uri.ToString()); await testRpcServer.ShutdownAsync(); } @@ -87,7 +87,7 @@ public async Task RpcInitializationService_Initializes_RpcServerAndChannels_WebH _rpcInitializationService = new RpcInitializationService(_optionsMonitor, mockEnvironment.Object, testRpcServer, _mockLanguageWorkerChannelManager.Object, _logger, _workerOptionsMonitor); await _rpcInitializationService.StartAsync(CancellationToken.None); - _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.NodeLanguageWorkerName), Times.Once); + _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.NodeLanguageWorkerName, It.IsAny()), Times.Once); Assert.Contains("testserver", testRpcServer.Uri.ToString()); await testRpcServer.ShutdownAsync(); } @@ -112,9 +112,9 @@ public async Task RpcInitializationService_Initializes_RpcServer_RpcChannels_Pla _rpcInitializationService = new RpcInitializationService(_optionsMonitor, mockEnvironment.Object, testRpcServer, _mockLanguageWorkerChannelManager.Object, _logger, _workerOptionsMonitor); await _rpcInitializationService.StartAsync(CancellationToken.None); - _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.NodeLanguageWorkerName), Times.Once); - _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.JavaLanguageWorkerName), Times.Once); - _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.PythonLanguageWorkerName), Times.Once); + _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.NodeLanguageWorkerName, It.IsAny()), Times.Once); + _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.JavaLanguageWorkerName, It.IsAny()), Times.Once); + _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.PythonLanguageWorkerName, It.IsAny()), Times.Once); Assert.Contains("testserver", testRpcServer.Uri.ToString()); await testRpcServer.ShutdownAsync(); } @@ -129,9 +129,9 @@ public async Task RpcInitializationService_Initializes_RpcServer_DoesNot_Initial _rpcInitializationService = new RpcInitializationService(_optionsMonitor, mockEnvironment.Object, testRpcServer, _mockLanguageWorkerChannelManager.Object, _logger, _workerOptionsMonitor); await _rpcInitializationService.StartAsync(CancellationToken.None); - _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.NodeLanguageWorkerName), Times.Never); - _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.JavaLanguageWorkerName), Times.Never); - _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.PythonLanguageWorkerName), Times.Never); + _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.NodeLanguageWorkerName, It.IsAny()), Times.Never); + _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.JavaLanguageWorkerName, It.IsAny()), Times.Never); + _mockLanguageWorkerChannelManager.Verify(m => m.InitializeChannelAsync(It.IsAny>(), RpcWorkerConstants.PythonLanguageWorkerName, It.IsAny()), Times.Never); Assert.Contains("testserver", testRpcServer.Uri.ToString()); await testRpcServer.ShutdownAsync(); } diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannelManager.cs b/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannelManager.cs index fe3a3187e9..1b6e030c4c 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannelManager.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannelManager.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.WebJobs.Script.Description; using Microsoft.Azure.WebJobs.Script.Diagnostics; @@ -44,7 +45,7 @@ public IDictionary> GetChannels( return null; } - public async Task InitializeChannelAsync(IEnumerable workerConfigs, string language) + public async Task InitializeChannelAsync(IEnumerable workerConfigs, string language, CancellationToken cancellationToken = default) { var metricsLogger = new Mock(); IRpcWorkerChannel workerChannel = _testLanguageWorkerChannelFactory.Create(_scriptRootPath, language, metricsLogger.Object, 0, TestHelpers.GetTestWorkerConfigs()); @@ -58,7 +59,7 @@ public async Task InitializeChannelAsync(IEnumerable()); } - await workerChannel.StartWorkerProcessAsync().ContinueWith(processStartTask => + await workerChannel.StartWorkerProcessAsync(cancellationToken).ContinueWith(processStartTask => { if (processStartTask.Status == TaskStatus.RanToCompletion) {