Skip to content

Commit b9807eb

Browse files
authored
Fix race condition during language worker specialization (#4794)
1 parent 11d5e52 commit b9807eb

13 files changed

+184
-276
lines changed

src/WebJobs.Script/Description/Rpc/WorkerLanguageInvoker.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ protected override async Task<object> InvokeCore(object[] parameters, FunctionIn
7575

7676
ScriptInvocationResult result;
7777
_logger.LogDebug($"Sending invocation id:{invocationId}");
78-
_functionDispatcher.Invoke(invocationContext);
78+
await _functionDispatcher.InvokeAsync(invocationContext);
7979
result = await invocationContext.ResultSource.Task;
8080

8181
await BindOutputsAsync(triggerValue, context.Binder, result);

src/WebJobs.Script/Rpc/FunctionRegistration/FunctionDispatcher.cs

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ internal async void InitializeWebhostLanguageWorkerChannel()
132132
workerChannel.SendFunctionLoadRequests();
133133
}
134134

135-
internal void ShutdownWebhostLanguageWorkerChannels()
135+
internal async void ShutdownWebhostLanguageWorkerChannels()
136136
{
137137
_logger.LogDebug("{workerRuntimeConstant}={value}. Will shutdown all the worker channels that started in placeholder mode", LanguageWorkerConstants.FunctionWorkerRuntimeSettingName, _workerRuntime);
138-
_webHostLanguageWorkerChannelManager.ShutdownChannels();
138+
await _webHostLanguageWorkerChannelManager?.ShutdownChannelsAsync();
139139
}
140140

141141
private void StartWorkerProcesses(int startIndex, Action startAction)
@@ -187,16 +187,20 @@ public async Task InitializeAsync(IEnumerable<FunctionMetadata> functions)
187187
if (Utility.IsSupportedRuntime(_workerRuntime, _workerConfigs))
188188
{
189189
State = FunctionDispatcherState.Initializing;
190-
IEnumerable<ILanguageWorkerChannel> initializedChannels = _webHostLanguageWorkerChannelManager.GetChannels(_workerRuntime);
191-
if (initializedChannels != null)
190+
Dictionary<string, TaskCompletionSource<ILanguageWorkerChannel>> webhostLanguageWorkerChannels = _webHostLanguageWorkerChannelManager.GetChannels(_workerRuntime);
191+
if (webhostLanguageWorkerChannels != null)
192192
{
193-
foreach (var initializedChannel in initializedChannels)
193+
foreach (string workerId in webhostLanguageWorkerChannels.Keys)
194194
{
195-
_logger.LogDebug("Found initialized language worker channel for runtime: {workerRuntime} workerId:{workerId}", _workerRuntime, initializedChannel.Id);
196-
initializedChannel.SetupFunctionInvocationBuffers(_functions);
197-
initializedChannel.SendFunctionLoadRequests();
195+
if (webhostLanguageWorkerChannels.TryGetValue(workerId, out TaskCompletionSource<ILanguageWorkerChannel> initializedLanguageWorkerChannelTask))
196+
{
197+
_logger.LogDebug("Found initialized language worker channel for runtime: {workerRuntime} workerId:{workerId}", _workerRuntime, workerId);
198+
ILanguageWorkerChannel initializedLanguageWorkerChannel = await initializedLanguageWorkerChannelTask.Task;
199+
initializedLanguageWorkerChannel.SetupFunctionInvocationBuffers(_functions);
200+
initializedLanguageWorkerChannel.SendFunctionLoadRequests();
201+
}
198202
}
199-
StartWorkerProcesses(initializedChannels.Count(), InitializeWebhostLanguageWorkerChannel);
203+
StartWorkerProcesses(webhostLanguageWorkerChannels.Count(), InitializeWebhostLanguageWorkerChannel);
200204
State = FunctionDispatcherState.Initialized;
201205
}
202206
else
@@ -207,11 +211,11 @@ public async Task InitializeAsync(IEnumerable<FunctionMetadata> functions)
207211
}
208212
}
209213

210-
public void Invoke(ScriptInvocationContext invocationContext)
214+
public async Task InvokeAsync(ScriptInvocationContext invocationContext)
211215
{
212216
try
213217
{
214-
IEnumerable<ILanguageWorkerChannel> workerChannels = GetInitializedWorkerChannels();
218+
IEnumerable<ILanguageWorkerChannel> workerChannels = await GetInitializedWorkerChannelsAsync();
215219
var languageWorkerChannel = _functionDispatcherLoadBalancer.GetLanguageWorkerChannel(workerChannels, _maxProcessCount);
216220
if (languageWorkerChannel.FunctionInputBuffers.TryGetValue(invocationContext.FunctionMetadata.FunctionId, out BufferBlock<ScriptInvocationContext> bufferBlock))
217221
{
@@ -229,9 +233,21 @@ public void Invoke(ScriptInvocationContext invocationContext)
229233
}
230234
}
231235

232-
internal IEnumerable<ILanguageWorkerChannel> GetInitializedWorkerChannels()
236+
internal async Task<IEnumerable<ILanguageWorkerChannel>> GetInitializedWorkerChannelsAsync()
233237
{
234-
IEnumerable<ILanguageWorkerChannel> webhostChannels = _webHostLanguageWorkerChannelManager.GetChannels(_workerRuntime);
238+
Dictionary<string, TaskCompletionSource<ILanguageWorkerChannel>> webhostChannelDictionary = _webHostLanguageWorkerChannelManager.GetChannels(_workerRuntime);
239+
List<ILanguageWorkerChannel> webhostChannels = null;
240+
if (webhostChannelDictionary != null)
241+
{
242+
webhostChannels = new List<ILanguageWorkerChannel>();
243+
foreach (string workerId in webhostChannelDictionary.Keys)
244+
{
245+
if (webhostChannelDictionary.TryGetValue(workerId, out TaskCompletionSource<ILanguageWorkerChannel> initializedLanguageWorkerChannelTask))
246+
{
247+
webhostChannels.Add(await initializedLanguageWorkerChannelTask.Task);
248+
}
249+
}
250+
}
235251
IEnumerable<ILanguageWorkerChannel> workerChannels = webhostChannels == null ? _jobHostLanguageWorkerChannelManager.GetChannels() : webhostChannels.Union(_jobHostLanguageWorkerChannelManager.GetChannels());
236252
IEnumerable<ILanguageWorkerChannel> initializedWorkers = workerChannels.Where(ch => ch.State == LanguageWorkerChannelState.Initialized);
237253
if (initializedWorkers.Count() > _maxProcessCount)
@@ -247,7 +263,7 @@ public async void WorkerError(WorkerErrorEvent workerError)
247263
{
248264
_logger.LogDebug("Handling WorkerErrorEvent for runtime:{runtime}, workerId:{workerId}", workerError.Language, workerError.WorkerId);
249265
_languageWorkerErrors.Add(workerError.Exception);
250-
bool isPreInitializedChannel = _webHostLanguageWorkerChannelManager.ShutdownChannelIfExists(workerError.Language, workerError.WorkerId);
266+
bool isPreInitializedChannel = await _webHostLanguageWorkerChannelManager.ShutdownChannelIfExistsAsync(workerError.Language, workerError.WorkerId);
251267
if (!isPreInitializedChannel)
252268
{
253269
_logger.LogDebug("Disposing errored channel for workerId: {channelId}, for runtime:{language}", workerError.WorkerId, workerError.Language);
@@ -257,8 +273,11 @@ public async void WorkerError(WorkerErrorEvent workerError)
257273
_jobHostLanguageWorkerChannelManager.DisposeAndRemoveChannel(erroredChannel);
258274
}
259275
}
260-
_logger.LogDebug("Restarting worker channel for runtime:{runtime}", workerError.Language);
261-
await RestartWorkerChannel(workerError.Language, workerError.WorkerId);
276+
if (_workerRuntime.Equals(workerError.Language, StringComparison.InvariantCultureIgnoreCase))
277+
{
278+
_logger.LogDebug("Restarting worker channel for runtime:{runtime}", workerError.Language);
279+
await RestartWorkerChannel(workerError.Language, workerError.WorkerId);
280+
}
262281
}
263282
}
264283

@@ -279,6 +298,7 @@ protected virtual void Dispose(bool disposing)
279298
{
280299
if (!_disposed && disposing)
281300
{
301+
_logger.LogDebug("Disposing FunctionDispatcher");
282302
_workerErrorSubscription.Dispose();
283303
_processStartCancellationToken.Cancel();
284304
_processStartCancellationToken.Dispose();

src/WebJobs.Script/Rpc/FunctionRegistration/IFunctionDispatcher.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public interface IFunctionDispatcher : IDisposable
1515
// Tests if the function metadata is supported by a known language worker
1616
bool IsSupported(FunctionMetadata metadata, string language);
1717

18-
void Invoke(ScriptInvocationContext invocationContext);
18+
Task InvokeAsync(ScriptInvocationContext invocationContext);
1919

2020
Task InitializeAsync(IEnumerable<FunctionMetadata> functions);
2121
}

src/WebJobs.Script/Rpc/IWebHostLanguageWorkerChannelManager.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ public interface IWebHostLanguageWorkerChannelManager
1010
{
1111
Task<ILanguageWorkerChannel> InitializeChannelAsync(string language);
1212

13-
IEnumerable<ILanguageWorkerChannel> GetChannels(string language);
13+
Dictionary<string, TaskCompletionSource<ILanguageWorkerChannel>> GetChannels(string language);
1414

1515
Task SpecializeAsync();
1616

17-
bool ShutdownChannelIfExists(string language, string workerId);
17+
Task<bool> ShutdownChannelIfExistsAsync(string language, string workerId);
1818

19-
void ShutdownChannels();
19+
Task ShutdownChannelsAsync();
2020
}
2121
}

src/WebJobs.Script/Rpc/JobHostLanguageWorkerChannelManager.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,20 @@
44
using System;
55
using System.Collections.Concurrent;
66
using System.Collections.Generic;
7+
using Microsoft.Extensions.Logging;
78

89
namespace Microsoft.Azure.WebJobs.Script.Rpc
910
{
1011
internal class JobHostLanguageWorkerChannelManager : IJobHostLanguageWorkerChannelManager
1112
{
13+
private readonly ILogger _logger;
1214
private ConcurrentDictionary<string, ILanguageWorkerChannel> _channels = new ConcurrentDictionary<string, ILanguageWorkerChannel>();
1315

16+
public JobHostLanguageWorkerChannelManager(ILoggerFactory loggerFactory)
17+
{
18+
_logger = loggerFactory.CreateLogger<JobHostLanguageWorkerChannelManager>();
19+
}
20+
1421
public void AddChannel(ILanguageWorkerChannel channel)
1522
{
1623
_channels.TryAdd(channel.Id, channel);
@@ -20,6 +27,7 @@ public void DisposeAndRemoveChannel(ILanguageWorkerChannel channel)
2027
{
2128
if (_channels.TryRemove(channel.Id, out ILanguageWorkerChannel removedChannel))
2229
{
30+
_logger.LogDebug("Disposing language worker channel with id:{workerId}", removedChannel.Id);
2331
(removedChannel as IDisposable)?.Dispose();
2432
}
2533
}
@@ -30,6 +38,7 @@ public void DisposeAndRemoveChannels()
3038
{
3139
if (_channels.TryRemove(channelId, out ILanguageWorkerChannel removedChannel))
3240
{
41+
_logger.LogDebug("Disposing language worker channel with id:{workerId}", removedChannel.Id);
3342
(removedChannel as IDisposable)?.Dispose();
3443
}
3544
}

src/WebJobs.Script/Rpc/RpcInitializationService.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,10 @@ public async Task StartAsync(CancellationToken cancellationToken)
6565
_logger.LogDebug("Rpc Initialization Service started.");
6666
}
6767

68-
public Task StopAsync(CancellationToken cancellationToken)
68+
public async Task StopAsync(CancellationToken cancellationToken)
6969
{
7070
_logger.LogDebug("Shuttingdown Rpc Channels Manager");
71-
_webHostlanguageWorkerChannelManager.ShutdownChannels();
72-
return Task.CompletedTask;
71+
await _webHostlanguageWorkerChannelManager.ShutdownChannelsAsync();
7372
}
7473

7574
public async Task OuterStopAsync(CancellationToken cancellationToken)

0 commit comments

Comments
 (0)