Skip to content

[WIP] Flow cancellation token to PendingItem #11196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 52 additions & 15 deletions src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,13 @@ private void ProcessRegisteredGrpcCallbacks(InboundGrpcEvent message)
next.SetResult(message);
}

private void RegisterCallbackForNextGrpcMessage(MsgType messageType, TimeSpan timeout, int count, Action<InboundGrpcEvent> callback, Action<Exception> faultHandler)
private void RegisterCallbackForNextGrpcMessage(
MsgType messageType,
TimeSpan timeout,
int count,
Action<InboundGrpcEvent> callback,
Action<Exception> faultHandler,
CancellationToken cancellationToken = default)
{
Queue<PendingItem> queue;
lock (_pendingActions)
Expand All @@ -289,8 +295,8 @@ private void RegisterCallbackForNextGrpcMessage(MsgType messageType, TimeSpan ti
for (int i = 0; i < count; i++)
{
var newItem = (i == count - 1) && (timeout != TimeSpan.Zero)
? new PendingItem(callback, faultHandler, timeout)
: new PendingItem(callback, faultHandler);
? new PendingItem(callback, faultHandler, timeout, cancellationToken)
: new PendingItem(callback, faultHandler, cancellationToken);
queue.Enqueue(newItem);
}
}
Expand Down Expand Up @@ -371,8 +377,16 @@ public bool IsChannelReadyForInvocations()

public async Task StartWorkerProcessAsync(CancellationToken cancellationToken)
{
RegisterCallbackForNextGrpcMessage(MsgType.StartStream, _workerConfig.CountOptions.ProcessStartupTimeout, 1, SendWorkerInitRequest, HandleWorkerStartStreamError);
// note: it is important that the ^^^ StartStream is in place *before* we start process the loop, otherwise we get a race condition
cancellationToken.ThrowIfCancellationRequested();

RegisterCallbackForNextGrpcMessage(
MsgType.StartStream,
_workerConfig.CountOptions.ProcessStartupTimeout,
1,
grpcEvent => SendWorkerInitRequest(grpcEvent, cancellationToken),
HandleWorkerStartStreamError,
cancellationToken);
// Note: it is important that the ^^^ StartStream callback is in place *before* we start the inbound processing loop; otherwise we get a race condition
_ = ProcessInbound();

_workerChannelLogger.LogDebug("Initiating Worker Process start up");
Expand Down Expand Up @@ -418,10 +432,12 @@ public async Task<WorkerStatus> GetWorkerStatusAsync()
}

// send capabilities to worker, wait for WorkerInitResponse
internal void SendWorkerInitRequest(GrpcEvent startEvent)
internal void SendWorkerInitRequest(GrpcEvent startEvent, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();

_workerChannelLogger.LogDebug("Worker Process started. Received StartStream message");
RegisterCallbackForNextGrpcMessage(MsgType.WorkerInitResponse, _workerConfig.CountOptions.InitializationTimeout, 1, WorkerInitResponse, HandleWorkerInitError);
RegisterCallbackForNextGrpcMessage(MsgType.WorkerInitResponse, _workerConfig.CountOptions.InitializationTimeout, 1, WorkerInitResponse, HandleWorkerInitError, cancellationToken);

WorkerInitRequest initRequest = GetWorkerInitRequest();

Expand Down Expand Up @@ -949,7 +965,7 @@ internal Task<List<RawFunctionMetadata>> SendFunctionMetadataRequest()
if (!_functionMetadataRequestSent)
{
RegisterCallbackForNextGrpcMessage(MsgType.FunctionMetadataResponse, _functionLoadTimeout, 1,
msg => ProcessFunctionMetadataResponses(msg.Message.FunctionMetadataResponse), HandleWorkerMetadataRequestError);
msg => ProcessFunctionMetadataResponses(msg.Message.FunctionMetadataResponse), HandleWorkerMetadataRequestError);

_workerChannelLogger.LogDebug("Sending WorkerMetadataRequest to {language} worker with worker ID {workerID}", _runtime, _workerId);

Expand Down Expand Up @@ -1749,21 +1765,25 @@ private sealed class PendingItem
{
private readonly Action<InboundGrpcEvent> _callback;
private readonly Action<Exception> _faultHandler;
private CancellationTokenRegistration _ctr;
private CancellationTokenRegistration _timeoutRegistration;
private CancellationTokenRegistration _cancellationRegistration;
private int _state;

public PendingItem(Action<InboundGrpcEvent> callback, Action<Exception> faultHandler)
public PendingItem(Action<InboundGrpcEvent> callback, Action<Exception> faultHandler, CancellationToken cancellationToken = default)
{
_callback = callback;
_faultHandler = faultHandler;

// Register for host shutdown
_cancellationRegistration = cancellationToken.Register(static state => ((PendingItem)state).OnCanceled(), this);
}

public PendingItem(Action<InboundGrpcEvent> callback, Action<Exception> faultHandler, TimeSpan timeout)
: this(callback, faultHandler)
public PendingItem(Action<InboundGrpcEvent> callback, Action<Exception> faultHandler, TimeSpan timeout, CancellationToken cancellationToken = default)
: this(callback, faultHandler, cancellationToken)
{
var cts = new CancellationTokenSource();
cts.CancelAfter(timeout);
_ctr = cts.Token.Register(static state => ((PendingItem)state).OnTimeout(), this);
_timeoutRegistration = cts.Token.Register(static state => ((PendingItem)state).OnTimeout(), this);
}

// True once _state holds any non-zero value; the field is read with Volatile.Read.
public bool IsComplete
{
    get
    {
        int current = Volatile.Read(ref _state);
        return current != 0;
    }
}
Expand All @@ -1772,8 +1792,11 @@ public PendingItem(Action<InboundGrpcEvent> callback, Action<Exception> faultHan

public void SetResult(InboundGrpcEvent message)
{
_ctr.Dispose();
_ctr = default;
_timeoutRegistration.Dispose();
_cancellationRegistration.Dispose();
_timeoutRegistration = default;
_cancellationRegistration = default;

if (MakeComplete() && _callback != null)
{
try
Expand Down Expand Up @@ -1813,6 +1836,20 @@ private void OnTimeout()
}
}
}

// Callback registered on the cancellation token supplied to the constructor.
// When MakeComplete() succeeds and a fault handler is present, the handler is
// invoked with an OperationCanceledException.
private void OnCanceled()
{
    // MakeComplete() is always evaluated first, exactly as in the short-circuit && form.
    if (!MakeComplete() || _faultHandler == null)
    {
        return;
    }

    try
    {
        _faultHandler(new OperationCanceledException());
    }
    catch
    {
        // Anything thrown by the fault handler is deliberately ignored here.
    }
}
}
}
}
13 changes: 7 additions & 6 deletions src/WebJobs.Script/Host/WorkerFunctionMetadataProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ internal class WorkerFunctionMetadataProvider : IWorkerFunctionMetadataProvider,
private readonly IEnvironment _environment;
private readonly IWebHostRpcWorkerChannelManager _channelManager;
private readonly IScriptHostManager _scriptHostManager;
private readonly IHostApplicationLifetime _applicationLifetime;
private string _workerRuntime;
private ImmutableArray<FunctionMetadata> _functions;
private IHost _currentJobHost = null;
Expand All @@ -38,14 +39,16 @@ public WorkerFunctionMetadataProvider(
ILogger<WorkerFunctionMetadataProvider> logger,
IEnvironment environment,
IWebHostRpcWorkerChannelManager webHostRpcWorkerChannelManager,
IScriptHostManager scriptHostManager)
IScriptHostManager scriptHostManager,
IHostApplicationLifetime applicationLifetime)
{
_scriptOptions = scriptOptions;
_logger = logger;
_environment = environment;
_channelManager = webHostRpcWorkerChannelManager;
_scriptHostManager = scriptHostManager;
_workerRuntime = _environment.GetEnvironmentVariable(EnvironmentSettingNames.FunctionWorkerRuntime);
_applicationLifetime = applicationLifetime;

_scriptHostManager.ActiveHostChanged += OnHostChanged;
}
Expand Down Expand Up @@ -89,7 +92,7 @@ public async Task<FunctionMetadataResult> GetFunctionMetadataAsync(IEnumerable<R
if (IsJobHostStarting())
{
_logger.LogDebug("JobHost is starting with state '{State}'. Initializing worker channel.", _scriptHostManager.State);
await _channelManager.InitializeChannelAsync(workerConfigs, _workerRuntime);
await _channelManager.InitializeChannelAsync(workerConfigs, _workerRuntime, _applicationLifetime.ApplicationStopping);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So my idea was to pass the _applicationLifetime.ApplicationStopping cancellation token (CT) through StartWorkerProcessAsync to PendingItem. (It doesn't have to come through WorkerFunctionMetadataProvider, so let me know if you have better ideas.)

The problem I am seeing with this is that when I cancel (exit the host), we don't see the CT triggered until after the timeout period of PendingItem has run its course. Something is blocking in an unexpected way, and we need to get around that block for the CT to flow properly.

cc: @jviau @fabiocav

}
else
{
Expand Down Expand Up @@ -167,10 +170,8 @@ private bool IsJobHostStarting()
// the host has not completely started, it means that it is still in the process of starting.
if (_currentJobHost is not null && _scriptHostManager.State == ScriptHostState.Error)
{
var lifetime = _currentJobHost.Services?.GetService<IHostApplicationLifetime>();

if (lifetime is not null &&
!lifetime.ApplicationStarted.IsCancellationRequested)
if (_applicationLifetime is not null &&
!_applicationLifetime.ApplicationStarted.IsCancellationRequested)
{
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,13 @@ internal async Task InitializeJobhostLanguageWorkerChannelAsync(IEnumerable<stri
// Single-language convenience overload: wraps the language in a one-element
// array and awaits the overload that accepts a sequence of languages.
internal async Task InitializeJobhostLanguageWorkerChannelAsync(int attemptCount, string language)
{
    string[] languages = new[] { language };
    await InitializeJobhostLanguageWorkerChannelAsync(attemptCount, languages);
}

internal async Task InitializeJobhostLanguageWorkerChannelAsync(int attemptCount, IEnumerable<string> languages)
internal async Task InitializeJobhostLanguageWorkerChannelAsync(int attemptCount, IEnumerable<string> languages, CancellationToken cancellationToken = default)
{
foreach (string language in languages)
{
var rpcWorkerChannel = _rpcWorkerChannelFactory.Create(_scriptOptions.RootScriptPath, language, _metricsLogger, attemptCount, _workerConfigs);
_jobHostLanguageWorkerChannelManager.AddChannel(rpcWorkerChannel, language);
await rpcWorkerChannel.StartWorkerProcessAsync();
await rpcWorkerChannel.StartWorkerProcessAsync(cancellationToken);
_logger.LogDebug("Adding jobhost language worker channel for runtime: {language}. workerId:{id}", language, rpcWorkerChannel.Id);

// if the worker is indexing, we will not have function metadata yet. So, we cannot set up invocation buffers or send load requests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Azure.WebJobs.Script.Workers.Rpc
{
public interface IWebHostRpcWorkerChannelManager
{
Task<IRpcWorkerChannel> InitializeChannelAsync(IEnumerable<RpcWorkerConfig> workerConfigs, string language);
Task<IRpcWorkerChannel> InitializeChannelAsync(IEnumerable<RpcWorkerConfig> workerConfigs, string language, CancellationToken cancellationToken = default);

IDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> GetChannels(string language);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.AppService.Proxy.Common.Infra;
using Microsoft.Azure.WebJobs.Script.Config;
Expand Down Expand Up @@ -72,22 +73,24 @@ public WebHostRpcWorkerChannelManager(IScriptEventManager eventManager,
_shutdownStandbyWorkerChannels = _shutdownStandbyWorkerChannels.Debounce(milliseconds: 5000);
}

public Task<IRpcWorkerChannel> InitializeChannelAsync(IEnumerable<RpcWorkerConfig> workerConfigs, string runtime)
public Task<IRpcWorkerChannel> InitializeChannelAsync(IEnumerable<RpcWorkerConfig> workerConfigs, string runtime, CancellationToken cancellationToken = default)
{
_logger?.LogDebug("Initializing language worker channel for runtime:{runtime}", runtime);
return InitializeLanguageWorkerChannel(workerConfigs, runtime, _applicationHostOptions.CurrentValue.ScriptPath);
return InitializeLanguageWorkerChannel(workerConfigs, runtime, _applicationHostOptions.CurrentValue.ScriptPath, cancellationToken);
}

internal async Task<IRpcWorkerChannel> InitializeLanguageWorkerChannel(IEnumerable<RpcWorkerConfig> workerConfigs, string runtime, string scriptRootPath)
internal async Task<IRpcWorkerChannel> InitializeLanguageWorkerChannel(IEnumerable<RpcWorkerConfig> workerConfigs, string runtime, string scriptRootPath, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();

IRpcWorkerChannel rpcWorkerChannel = null;
string workerId = Guid.NewGuid().ToString();
_logger.LogDebug("Creating language worker channel for runtime:{runtime}", runtime);
_logger.LogWarning("Creating language worker channel for runtime:{runtime}", runtime);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert

try
{
rpcWorkerChannel = _rpcWorkerChannelFactory.Create(scriptRootPath, runtime, _metricsLogger, 0, workerConfigs);
AddOrUpdateWorkerChannels(runtime, rpcWorkerChannel);
await rpcWorkerChannel.StartWorkerProcessAsync().ContinueWith(processStartTask =>
await rpcWorkerChannel.StartWorkerProcessAsync(cancellationToken).ContinueWith(processStartTask =>
{
if (processStartTask.Status == TaskStatus.RanToCompletion)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Threading.Tasks;
using Microsoft.Azure.WebJobs.Script.Description;
using Microsoft.Azure.WebJobs.Script.Workers.Rpc;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
Expand All @@ -28,13 +29,15 @@ public WorkerFunctionMetadataProviderTests()
var mockEnvironment = new Mock<IEnvironment>();
var mockChannelManager = new Mock<IWebHostRpcWorkerChannelManager>();
var mockScriptHostManager = new Mock<IScriptHostManager>();
var mockLifetime = new Mock<IHostApplicationLifetime>();

_workerFunctionMetadataProvider = new WorkerFunctionMetadataProvider(
mockScriptOptions.Object,
mockLogger.Object,
mockEnvironment.Object,
mockChannelManager.Object,
mockScriptHostManager.Object);
mockScriptHostManager.Object,
mockLifetime.Object);
}

[Fact]
Expand Down Expand Up @@ -188,6 +191,8 @@ public async void ValidateFunctionMetadata_Logging()
var mockScriptHostManager = new Mock<IScriptHostManager>();
mockScriptHostManager.Setup(m => m.State).Returns(ScriptHostState.Running);

var mockLifetime = new Mock<IHostApplicationLifetime>();

var mockWebHostRpcWorkerChannelManager = new Mock<IWebHostRpcWorkerChannelManager>();
mockWebHostRpcWorkerChannelManager.Setup(m => m.GetChannels(It.IsAny<string>())).Returns(() => new Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>>
{
Expand All @@ -197,7 +202,7 @@ public async void ValidateFunctionMetadata_Logging()
environment.SetEnvironmentVariable(EnvironmentSettingNames.FunctionWorkerRuntime, "node");

var workerFunctionMetadataProvider = new WorkerFunctionMetadataProvider(optionsMonitor, logger, SystemEnvironment.Instance,
mockWebHostRpcWorkerChannelManager.Object, mockScriptHostManager.Object);
mockWebHostRpcWorkerChannelManager.Object, mockScriptHostManager.Object, mockLifetime.Object);
await workerFunctionMetadataProvider.GetFunctionMetadataAsync(workerConfigs, false);

var traces = logger.GetLogMessages();
Expand Down Expand Up @@ -230,6 +235,7 @@ public async Task GetFunctionMetadataAsync_Idempotent()
var mockChannelManager = new Mock<IWebHostRpcWorkerChannelManager>(MockBehavior.Strict);
var mockScriptHostManager = new Mock<IScriptHostManager>(MockBehavior.Strict);
var mockOptionsMonitor = new Mock<IOptionsMonitor<ScriptApplicationHostOptions>>(MockBehavior.Strict);
var mockLifetime = new Mock<IHostApplicationLifetime>(MockBehavior.Strict);
var scriptOptions = new ScriptApplicationHostOptions
{
IsFileSystemReadOnly = true
Expand Down Expand Up @@ -265,7 +271,8 @@ public async Task GetFunctionMetadataAsync_Idempotent()
NullLogger<WorkerFunctionMetadataProvider>.Instance,
testEnvironment,
mockChannelManager.Object,
mockScriptHostManager.Object);
mockScriptHostManager.Object,
mockLifetime.Object);

var workerConfigs = new List<RpcWorkerConfig>();

Expand Down
Loading
Loading