diff --git a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs index 7b897bfea0..fe8927f54f 100644 --- a/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs +++ b/src/WebJobs.Script.Grpc/Channel/GrpcWorkerChannel.cs @@ -16,12 +16,14 @@ using System.Threading.Tasks.Dataflow; using Google.Protobuf.Collections; using Google.Protobuf.WellKnownTypes; +using Microsoft.Azure.WebJobs.Host; using Microsoft.Azure.WebJobs.Logging; using Microsoft.Azure.WebJobs.Script.Config; using Microsoft.Azure.WebJobs.Script.Description; using Microsoft.Azure.WebJobs.Script.Diagnostics; using Microsoft.Azure.WebJobs.Script.Diagnostics.OpenTelemetry; using Microsoft.Azure.WebJobs.Script.Eventing; +using Microsoft.Azure.WebJobs.Script.Exceptions; using Microsoft.Azure.WebJobs.Script.Extensions; using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; using Microsoft.Azure.WebJobs.Script.Grpc.Extensions; @@ -1555,21 +1557,24 @@ public bool IsExecutingInvocation(string invocationId) return _executingInvocations.ContainsKey(invocationId); } - public bool TryFailExecutions(Exception workerException) + public void Shutdown(Exception workerException) { - if (workerException == null) + var shutdownException = workerException; + + if (workerException is null || workerException is FunctionTimeoutException) { - return false; + shutdownException = new FunctionAbortedException(workerException?.Message ?? "Worker channel is shutting down. Aborting function.", workerException); } foreach (var invocation in _executingInvocations?.Values) { string invocationId = invocation.Context?.ExecutionContext?.InvocationId.ToString(); _workerChannelLogger.LogDebug("Worker '{workerId}' encountered a fatal error. Failing invocation: '{invocationId}'", _workerId, invocationId); - invocation.Context?.ResultSource?.TrySetException(workerException); + + invocation.Context?.ResultSource?.TrySetException(shutdownException); + RemoveExecutingInvocation(invocationId); } - return true; } /// diff --git a/src/WebJobs.Script.WebHost/WebScriptHostExceptionHandler.cs b/src/WebJobs.Script.WebHost/WebScriptHostExceptionHandler.cs index cd8b364caf..0708764230 100644 --- a/src/WebJobs.Script.WebHost/WebScriptHostExceptionHandler.cs +++ b/src/WebJobs.Script.WebHost/WebScriptHostExceptionHandler.cs @@ -27,40 +27,46 @@ public WebScriptHostExceptionHandler(IApplicationLifetime applicationLifetime, I public async Task OnTimeoutExceptionAsync(ExceptionDispatchInfo exceptionInfo, TimeSpan timeoutGracePeriod) { - FunctionTimeoutException timeoutException = exceptionInfo.SourceException as FunctionTimeoutException; - - if (timeoutException?.Task != null) + if (exceptionInfo.SourceException is FunctionTimeoutException timeoutException) { - // We may double the timeoutGracePeriod here by first waiting to see if the initial - // function task that started the exception has completed. - Task completedTask = await Task.WhenAny(timeoutException.Task, Task.Delay(timeoutGracePeriod)); - - // If the function task has completed, simply return. The host has already logged the timeout. - if (completedTask == timeoutException.Task) + if (timeoutException?.Task != null) { - return; + // We may double the timeoutGracePeriod here by first waiting to see if the initial + // function task that started the exception has completed. + Task completedTask = await Task.WhenAny(timeoutException.Task, Task.Delay(timeoutGracePeriod)); + + // If the function task has completed, simply return. The host has already logged the timeout. + if (completedTask == timeoutException.Task) + { + return; + } } - } - // We can't wait on this as it may cause a deadlock if the timeout was fired - // by a Listener that cannot stop until it has completed. - // TODO: DI (FACAVAL) The shutdown call will invoke the host stop... but we may need to do this - // explicitly in order to pass the timeout. - // Task ignoreTask = _hostManager.StopAsync(); - // Give the manager and all running tasks some time to shut down gracefully. - //await Task.Delay(timeoutGracePeriod); - IFunctionInvocationDispatcher functionInvocationDispatcher = _functionInvocationDispatcherFactory.GetFunctionDispatcher(); - if (!functionInvocationDispatcher.State.Equals(FunctionInvocationDispatcherState.Default)) - { - _logger.LogWarning($"A function timeout has occurred. Restarting worker process executing invocationId '{timeoutException.InstanceId}'.", exceptionInfo.SourceException); - // If invocation id is not found in any of the workers => worker is already disposed. No action needed. - await functionInvocationDispatcher.RestartWorkerWithInvocationIdAsync(timeoutException.InstanceId.ToString()); - _logger.LogWarning("Restart of language worker process(es) completed.", exceptionInfo.SourceException); + // We can't wait on this as it may cause a deadlock if the timeout was fired + // by a Listener that cannot stop until it has completed. + // TODO: DI (FACAVAL) The shutdown call will invoke the host stop... but we may need to do this + // explicitly in order to pass the timeout. + // Task ignoreTask = _hostManager.StopAsync(); + // Give the manager and all running tasks some time to shut down gracefully. + // await Task.Delay(timeoutGracePeriod); + IFunctionInvocationDispatcher functionInvocationDispatcher = _functionInvocationDispatcherFactory.GetFunctionDispatcher(); + if (!functionInvocationDispatcher.State.Equals(FunctionInvocationDispatcherState.Default)) + { + _logger.LogWarning($"A function timeout has occurred. Restarting worker process executing invocationId '{timeoutException.InstanceId}'.", exceptionInfo.SourceException); + // If invocation id is not found in any of the workers => worker is already disposed. No action needed. + await functionInvocationDispatcher.RestartWorkerWithInvocationIdAsync(timeoutException.InstanceId.ToString(), timeoutException); + _logger.LogWarning("Restart of language worker process(es) completed.", exceptionInfo.SourceException); + } + else + { + LogErrorAndFlush("A function timeout has occurred. Host is shutting down.", exceptionInfo.SourceException); + _applicationLifetime.StopApplication(); + } } else { - LogErrorAndFlush("A function timeout has occurred. Host is shutting down.", exceptionInfo.SourceException); - _applicationLifetime.StopApplication(); + // still testing why this can occur, leaving a placeholder log message for now. + LogErrorAndFlush("An unexpected timeout exception has occurred. Host is shutting down.", exceptionInfo.SourceException); } } diff --git a/src/WebJobs.Script/Exceptions/FunctionAbortedException.cs b/src/WebJobs.Script/Exceptions/FunctionAbortedException.cs new file mode 100644 index 0000000000..950954028d --- /dev/null +++ b/src/WebJobs.Script/Exceptions/FunctionAbortedException.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using Microsoft.Azure.WebJobs.Host; + +namespace Microsoft.Azure.WebJobs.Script.Exceptions +{ + internal sealed class FunctionAbortedException : FunctionTimeoutException + { + public FunctionAbortedException() { } + + public FunctionAbortedException(string message) : base(message) { } + + public FunctionAbortedException(string message, Exception innerException) : base(message, innerException) + { + } + } +} diff --git a/src/WebJobs.Script/Http/DefaultHttpProxyService.cs b/src/WebJobs.Script/Http/DefaultHttpProxyService.cs index d7e6828024..5509743692 100644 --- a/src/WebJobs.Script/Http/DefaultHttpProxyService.cs +++ b/src/WebJobs.Script/Http/DefaultHttpProxyService.cs @@ -21,6 +21,7 @@ internal class DefaultHttpProxyService : IHttpProxyService, IDisposable private readonly HttpMessageInvoker _messageInvoker; private readonly ForwarderRequestConfig _forwarderRequestConfig; private readonly ILogger _logger; + private readonly HttpTransformer _httpTransformer; public DefaultHttpProxyService(IHttpForwarder httpForwarder, ILogger logger) { @@ -39,6 +40,8 @@ public DefaultHttpProxyService(IHttpForwarder httpForwarder, ILogger SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { + TaskCompletionSource resultSource = null; + if (request.Options.TryGetValue(ScriptConstants.HttpProxyScriptInvocationContext, out ScriptInvocationContext scriptInvocationContext)) + { + resultSource = scriptInvocationContext.ResultSource; + } + var currentDelay = InitialDelay; for (int attemptCount = 1; attemptCount <= MaxRetries; attemptCount++) { try { + if (resultSource is not null && resultSource.Task.IsFaulted) + { + throw resultSource.Task.Exception?.InnerException ?? new HttpRequestException("The function invocation tied to this HTTP request failed."); + } + return await base.SendAsync(request, cancellationToken); } catch (TaskCanceledException) when (cancellationToken.IsCancellationRequested) @@ -51,6 +64,11 @@ protected override async Task SendAsync(HttpRequestMessage currentDelay = Math.Min(currentDelay * 2, MaximumDelay); } + catch (FunctionAbortedException) + { + _logger.LogDebug("Function invocation aborted. Request will not be retried."); + throw; + } catch (Exception ex) { var message = attemptCount == MaxRetries diff --git a/src/WebJobs.Script/Http/ScriptInvocationRequestTransformer.cs b/src/WebJobs.Script/Http/ScriptInvocationRequestTransformer.cs new file mode 100644 index 0000000000..c2325e3075 --- /dev/null +++ b/src/WebJobs.Script/Http/ScriptInvocationRequestTransformer.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Azure.WebJobs.Script.Description; +using Yarp.ReverseProxy.Forwarder; + +namespace Microsoft.Azure.WebJobs.Script.Http +{ + internal class ScriptInvocationRequestTransformer : HttpTransformer + { + public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken) + { + // this preserves previous behavior (which called the default transformer) - base method is also called inside of here + await HttpTransformer.Default.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, cancellationToken); + + if (httpContext.Items.TryGetValue(ScriptConstants.HttpProxyScriptInvocationContext, out object result) + && result is ScriptInvocationContext scriptContext) + { + proxyRequest.Options.TryAdd(ScriptConstants.HttpProxyScriptInvocationContext, scriptContext); + } + } + } +} \ No newline at end of file diff --git a/src/WebJobs.Script/ScriptConstants.cs b/src/WebJobs.Script/ScriptConstants.cs index 3ce16d0575..31bafefa69 100644 --- a/src/WebJobs.Script/ScriptConstants.cs +++ b/src/WebJobs.Script/ScriptConstants.cs @@ -261,6 +261,7 @@ public static class ScriptConstants public static readonly string HttpProxyingEnabled = "HttpProxyingEnabled"; public static readonly string HttpProxyCorrelationHeader = "x-ms-invocation-id"; public static readonly string HttpProxyTask = "HttpProxyTask"; + public static readonly string HttpProxyScriptInvocationContext = "HttpProxyScriptInvocationContext"; public static readonly string OperationNameKey = "OperationName"; diff --git a/src/WebJobs.Script/Workers/Http/HttpFunctionInvocationDispatcher.cs b/src/WebJobs.Script/Workers/Http/HttpFunctionInvocationDispatcher.cs index f76eed4f5a..902a85255a 100644 --- a/src/WebJobs.Script/Workers/Http/HttpFunctionInvocationDispatcher.cs +++ b/src/WebJobs.Script/Workers/Http/HttpFunctionInvocationDispatcher.cs @@ -206,7 +206,7 @@ public Task ShutdownAsync() return Task.CompletedTask; } - public Task RestartWorkerWithInvocationIdAsync(string invocationId) + public Task RestartWorkerWithInvocationIdAsync(string invocationId, Exception exception = null) { // Since there's only one channel for httpworker DisposeAndRestartWorkerChannel(_httpWorkerChannel.Id); diff --git a/src/WebJobs.Script/Workers/IFunctionInvocationDispatcher.cs b/src/WebJobs.Script/Workers/IFunctionInvocationDispatcher.cs index ee40498bf4..b63e6f84c9 100644 --- a/src/WebJobs.Script/Workers/IFunctionInvocationDispatcher.cs +++ b/src/WebJobs.Script/Workers/IFunctionInvocationDispatcher.cs @@ -23,7 +23,7 @@ public interface IFunctionInvocationDispatcher : IDisposable Task ShutdownAsync(); - Task RestartWorkerWithInvocationIdAsync(string invocationId); + Task RestartWorkerWithInvocationIdAsync(string invocationId, Exception exception); Task StartWorkerChannel(); diff --git a/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs b/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs index 5025c58591..75ff359af5 100644 --- a/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs +++ b/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs @@ -676,7 +676,7 @@ public void Dispose() Dispose(true); } - public async Task RestartWorkerWithInvocationIdAsync(string invocationId) + public async Task RestartWorkerWithInvocationIdAsync(string invocationId, Exception exception = null) { // Dispose and restart errored channel with the particular invocation id var channels = await GetInitializedWorkerChannelsAsync(); @@ -685,7 +685,7 @@ public async Task RestartWorkerWithInvocationIdAsync(string invocationId) if (channel.IsExecutingInvocation(invocationId)) { _logger.LogDebug($"Restarting channel with workerId: '{channel.Id}' that is executing invocation: '{invocationId}' and timed out."); - await DisposeAndRestartWorkerChannel(_workerRuntime, channel.Id); + await DisposeAndRestartWorkerChannel(_workerRuntime, channel.Id, exception); return true; } } diff --git a/src/WebJobs.Script/Workers/Rpc/IRpcWorkerChannel.cs b/src/WebJobs.Script/Workers/Rpc/IRpcWorkerChannel.cs index d9d3e6eb9c..f187eae71f 100644 --- a/src/WebJobs.Script/Workers/Rpc/IRpcWorkerChannel.cs +++ b/src/WebJobs.Script/Workers/Rpc/IRpcWorkerChannel.cs @@ -32,6 +32,6 @@ public interface IRpcWorkerChannel : IWorkerChannel bool IsExecutingInvocation(string invocationId); - bool TryFailExecutions(Exception workerException); + void Shutdown(Exception workerException); } } diff --git a/src/WebJobs.Script/Workers/Rpc/JobHostRpcWorkerChannelManager.cs b/src/WebJobs.Script/Workers/Rpc/JobHostRpcWorkerChannelManager.cs index e326da8c68..a8a9be2cbd 100644 --- a/src/WebJobs.Script/Workers/Rpc/JobHostRpcWorkerChannelManager.cs +++ b/src/WebJobs.Script/Workers/Rpc/JobHostRpcWorkerChannelManager.cs @@ -49,7 +49,7 @@ public Task ShutdownChannelIfExistsAsync(string channelId, Exception worke { string id = rpcChannel.Id; _logger.LogDebug("Disposing language worker channel with id:{workerId}", id); - rpcChannel.TryFailExecutions(workerException); + rpcChannel.Shutdown(workerException); (rpcChannel as IDisposable)?.Dispose(); _logger.LogDebug("Disposed language worker channel with id:{workerId}", id); diff --git a/src/WebJobs.Script/Workers/Rpc/WebHostRpcWorkerChannelManager.cs b/src/WebJobs.Script/Workers/Rpc/WebHostRpcWorkerChannelManager.cs index 3dd214d343..06786f4931 100644 --- a/src/WebJobs.Script/Workers/Rpc/WebHostRpcWorkerChannelManager.cs +++ b/src/WebJobs.Script/Workers/Rpc/WebHostRpcWorkerChannelManager.cs @@ -269,7 +269,7 @@ public Task ShutdownChannelIfExistsAsync(string language, string workerId, if (workerChannel != null) { _logger.LogDebug("Disposing WebHost channel for workerId: {channelId}, for runtime:{language}", workerId, language); - workerChannel.TryFailExecutions(workerException); + workerChannel.Shutdown(workerException); (channelTask.Result as IDisposable)?.Dispose(); } } @@ -295,7 +295,7 @@ public Task ShutdownChannelIfExistsAsync(string language, string workerId, if (workerChannel != null) { _logger.LogDebug("Disposing WebHost channel for workerId: {channelId}, for runtime:{language}", workerId, language); - workerChannel.TryFailExecutions(workerException); + workerChannel.Shutdown(workerException); (channelTask.Result as IDisposable)?.Dispose(); } } diff --git a/test/WebJobs.Script.Tests/Handlers/WebScriptHostExceptionHandlerTests.cs b/test/WebJobs.Script.Tests/Handlers/WebScriptHostExceptionHandlerTests.cs new file mode 100644 index 0000000000..45617c0e3d --- /dev/null +++ b/test/WebJobs.Script.Tests/Handlers/WebScriptHostExceptionHandlerTests.cs @@ -0,0 +1,87 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Runtime.ExceptionServices; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Azure.WebJobs.Host; +using Microsoft.Azure.WebJobs.Script.WebHost; +using Microsoft.Azure.WebJobs.Script.Workers; +using Microsoft.Extensions.Logging; +using Moq; +using Xunit; + +namespace Microsoft.Azure.WebJobs.Script.Tests.Handlers +{ + public class WebScriptHostExceptionHandlerTests + { + private readonly Mock _mockApplicationLifetime; + private readonly Mock> _mockLogger; + private readonly Mock _mockDispatcherFactory; + private readonly Mock _mockDispatcher; + private readonly WebScriptHostExceptionHandler _exceptionHandler; + + public WebScriptHostExceptionHandlerTests() + { + _mockApplicationLifetime = new Mock(); + _mockLogger = new Mock>(); + _mockDispatcherFactory = new Mock(); + _mockDispatcher = new Mock(); + + _mockDispatcherFactory.Setup(f => f.GetFunctionDispatcher()) + .Returns(_mockDispatcher.Object); + + _exceptionHandler = new WebScriptHostExceptionHandler( + _mockApplicationLifetime.Object, + _mockLogger.Object, + _mockDispatcherFactory.Object); + } + + [Fact] + public async Task OnTimeoutExceptionAsync_CallsRestartWorkerWithInvocationIdAsync_WithTimeoutException() + { + var task = Task.CompletedTask; + var timeoutException = new FunctionTimeoutException("Test timeout"); + var exceptionInfo = ExceptionDispatchInfo.Capture(timeoutException); + var timeoutGracePeriod = TimeSpan.FromSeconds(5); + + _mockDispatcher.Setup(d => d.State) + .Returns(FunctionInvocationDispatcherState.Initialized); + _mockDispatcher.Setup(d => d.RestartWorkerWithInvocationIdAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(true)); + + await _exceptionHandler.OnTimeoutExceptionAsync(exceptionInfo, timeoutGracePeriod); + + _mockDispatcher.Verify(d => d.RestartWorkerWithInvocationIdAsync( + It.IsAny(), + timeoutException), Times.Once); + } + + [Fact] + public async Task OnTimeoutExceptionAsync_WhenTaskDoesNotCompleteWithinGracePeriod_RestartsWorker() + { + // Arrange + var invocationId = Guid.NewGuid(); + var taskCompletionSource = new TaskCompletionSource(); + var timeoutException = new FunctionTimeoutException("Test timeout"); + var exceptionInfo = ExceptionDispatchInfo.Capture(timeoutException); + var timeoutGracePeriod = TimeSpan.FromMilliseconds(100); // Short grace period + + _mockDispatcher.Setup(d => d.State) + .Returns(FunctionInvocationDispatcherState.Initialized); + _mockDispatcher.Setup(d => d.RestartWorkerWithInvocationIdAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(true)); + + // Don't complete the task to simulate it not finishing within the grace period + + // Act + await _exceptionHandler.OnTimeoutExceptionAsync(exceptionInfo, timeoutGracePeriod); + + // Assert + _mockDispatcher.Verify(d => d.RestartWorkerWithInvocationIdAsync( + It.IsAny(), + timeoutException), Times.Once); + } + } +} \ No newline at end of file diff --git a/test/WebJobs.Script.Tests/HttpProxyService/DefaultHttpProxyServiceTests.cs b/test/WebJobs.Script.Tests/HttpProxyService/DefaultHttpProxyServiceTests.cs index 81313edb38..b265a7374c 100644 --- a/test/WebJobs.Script.Tests/HttpProxyService/DefaultHttpProxyServiceTests.cs +++ b/test/WebJobs.Script.Tests/HttpProxyService/DefaultHttpProxyServiceTests.cs @@ -12,7 +12,7 @@ using Xunit; using Yarp.ReverseProxy.Forwarder; -namespace Microsoft.Azure.WebJobs.Script.Tests +namespace Microsoft.Azure.WebJobs.Script.Tests.Http { public class DefaultHttpProxyServiceTests { diff --git a/test/WebJobs.Script.Tests/HttpProxyService/ScriptInvocationRequestTransformerTests.cs b/test/WebJobs.Script.Tests/HttpProxyService/ScriptInvocationRequestTransformerTests.cs new file mode 100644 index 0000000000..5152e0050d --- /dev/null +++ b/test/WebJobs.Script.Tests/HttpProxyService/ScriptInvocationRequestTransformerTests.cs @@ -0,0 +1,135 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Azure.WebJobs.Script.Description; +using Microsoft.Azure.WebJobs.Script.Http; +using Xunit; + +namespace Microsoft.Azure.WebJobs.Script.Tests.Http +{ + public class ScriptInvocationRequestTransformerTests + { + private readonly ScriptInvocationRequestTransformer _transformer; + + public ScriptInvocationRequestTransformerTests() + { + _transformer = new ScriptInvocationRequestTransformer(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task TransformRequestAsync_IncludesXForwardedHeaders(bool includeScriptInvocationContext) + { + var httpContext = new DefaultHttpContext(); + httpContext.Request.Scheme = "https"; + httpContext.Request.Host = new HostString("example.com", 443); + httpContext.Request.PathBase = "/api"; + httpContext.Request.Path = "/test"; + httpContext.Request.QueryString = new QueryString("?param=value"); + + var remoteAddress = "192.168.1.100"; + httpContext.Connection.RemoteIpAddress = System.Net.IPAddress.Parse(remoteAddress); + + if (includeScriptInvocationContext) + { + var scriptContext = new ScriptInvocationContext + { + FunctionMetadata = new FunctionMetadata { Name = "TestFunction" }, + ExecutionContext = new ExecutionContext { InvocationId = Guid.NewGuid() } + }; + + httpContext.Items[ScriptConstants.HttpProxyScriptInvocationContext] = scriptContext; + } + + var proxyRequest = new HttpRequestMessage(HttpMethod.Get, "http://localhost:7071/api/test"); + const string destinationPrefix = "http://localhost:7071"; + + await _transformer.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None); + + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-For"), "X-Forwarded-For header should be present"); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Host"), "X-Forwarded-Host header should be present"); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Proto"), "X-Forwarded-Proto header should be present"); + + var forwardedFor = proxyRequest.Headers.GetValues("X-Forwarded-For"); + Assert.Contains(remoteAddress, forwardedFor); + + var forwardedHost = proxyRequest.Headers.GetValues("X-Forwarded-Host"); + Assert.Contains("example.com:443", forwardedHost); + + var forwardedProto = proxyRequest.Headers.GetValues("X-Forwarded-Proto"); + Assert.Contains("https", forwardedProto); + } + + [Fact] + public async Task TransformRequestAsync_WithScriptInvocationContext_AddsContextToRequestOptions() + { + var httpContext = new DefaultHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Host = new HostString("localhost", 7071); + var remoteAddress = "192.168.1.100"; + httpContext.Connection.RemoteIpAddress = System.Net.IPAddress.Parse(remoteAddress); + + var scriptContext = new ScriptInvocationContext + { + FunctionMetadata = new FunctionMetadata { Name = "TestFunction" }, + ExecutionContext = new ExecutionContext { InvocationId = Guid.NewGuid() } + }; + + httpContext.Items[ScriptConstants.HttpProxyScriptInvocationContext] = scriptContext; + + var proxyRequest = new HttpRequestMessage(HttpMethod.Get, "http://localhost:7071/api/test"); + const string destinationPrefix = "http://localhost:7071"; + + await _transformer.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None); + + Assert.True(proxyRequest.Options.TryGetValue(ScriptConstants.HttpProxyScriptInvocationContext, out ScriptInvocationContext contextValue)); + Assert.Equal(scriptContext.ExecutionContext.InvocationId, contextValue.ExecutionContext.InvocationId); + + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-For")); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Host")); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Proto")); + } + + [Fact] + public async Task TransformRequestAsync_PreservesExistingXForwardedHeaders() + { + var httpContext = new DefaultHttpContext(); + httpContext.Request.Scheme = "https"; + httpContext.Request.Host = new HostString("proxy.example.com"); + var requestRemoteAddress = "172.16.0.1"; + httpContext.Connection.RemoteIpAddress = System.Net.IPAddress.Parse(requestRemoteAddress); + + // Add existing X-Forwarded headers to simulate request through multiple proxies + var originalFor = "203.0.113.195," + requestRemoteAddress; + var originalHost = "proxy.example.com"; + var originalProto = "https"; + httpContext.Request.Headers["X-Forwarded-For"] = originalFor; + httpContext.Request.Headers["X-Forwarded-Host"] = originalHost; + httpContext.Request.Headers["X-Forwarded-Proto"] = originalProto; + + var proxyRequest = new HttpRequestMessage(HttpMethod.Get, "http://localhost:7071/api/test"); + const string destinationPrefix = "http://localhost:7071"; + + await _transformer.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None); + + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-For")); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Host")); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Proto")); + + var forwardedFor = proxyRequest.Headers.GetValues("X-Forwarded-For"); + Assert.Contains(requestRemoteAddress, forwardedFor); + + var forwardedHost = proxyRequest.Headers.GetValues("X-Forwarded-Host"); + Assert.Contains(originalHost, forwardedHost); + + var forwardedProto = proxyRequest.Headers.GetValues("X-Forwarded-Proto"); + Assert.Contains(originalProto, forwardedProto); + } + } +} \ No newline at end of file diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs b/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs index 767b9e232d..61ae5651d2 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/GrpcWorkerChannelTests.cs @@ -14,6 +14,7 @@ using Microsoft.Azure.WebJobs.Script.Description; using Microsoft.Azure.WebJobs.Script.Diagnostics; using Microsoft.Azure.WebJobs.Script.Eventing; +using Microsoft.Azure.WebJobs.Script.Exceptions; using Microsoft.Azure.WebJobs.Script.Grpc; using Microsoft.Azure.WebJobs.Script.Grpc.Eventing; using Microsoft.Azure.WebJobs.Script.Grpc.Messages; @@ -572,21 +573,6 @@ await channel.InvokeResponse(new InvocationResponse Assert.Equal(result.Status, TaskStatus.RanToCompletion); } - [Fact] - public async Task InFlight_Functions_FailedWithException() - { - await CreateDefaultWorkerChannel(); - var resultSource = new TaskCompletionSource(); - ScriptInvocationContext scriptInvocationContext = GetTestScriptInvocationContext(Guid.NewGuid(), resultSource); - await _workerChannel.SendInvocationRequest(scriptInvocationContext); - Assert.True(_workerChannel.IsExecutingInvocation(scriptInvocationContext.ExecutionContext.InvocationId.ToString())); - Exception workerException = new Exception("worker failed"); - _workerChannel.TryFailExecutions(workerException); - Assert.False(_workerChannel.IsExecutingInvocation(scriptInvocationContext.ExecutionContext.InvocationId.ToString())); - Assert.Equal(TaskStatus.Faulted, resultSource.Task.Status); - Assert.Equal(workerException, resultSource.Task.Exception.InnerException); - } - [Fact] public async Task SendLoadRequests_PublishesOutboundEvents() { @@ -1561,6 +1547,66 @@ public async Task NullOutputBinding_DoesNotThrow() Assert.Equal(TaskStatus.RanToCompletion, resultSource.Task.Status); } + [Fact] + public async Task Shutdown_WithNoExecutingInvocations_DoesNotThrow() + { + await CreateDefaultWorkerChannel(); + var workerException = new Exception("Worker process crashed"); + + // Should not throw even if there are no executing invocations + _workerChannel.Shutdown(workerException); + } + + [Theory] + [InlineData(1, true)] + [InlineData(3, true)] + [InlineData(1, false)] + [InlineData(3, false)] + public async Task Shutdown_FailsInFlightInvocations(int numberOfInvocations, bool hasFailureException) + { + await CreateDefaultWorkerChannel(); + + var invocationContexts = new List(); + var invocationIds = new List(); + + for (int i = 0; i < numberOfInvocations; i++) + { + var invocationId = Guid.NewGuid(); + var resultSource = new TaskCompletionSource(); + + var invocationContext = GetTestScriptInvocationContext( + invocationId, + resultSource, + logger: _logger, + scriptRootPath: _scriptRootPath); + + await _workerChannel.SendInvocationRequest(invocationContext); + + invocationContexts.Add(invocationContext); + invocationIds.Add(invocationId); + } + + for (int i = 0; i < numberOfInvocations; i++) + { + Assert.True(_workerChannel.IsExecutingInvocation(invocationIds[i].ToString()), + $"Invocation {i} should be executing"); + } + + var workerException = hasFailureException ? new Exception("Worker process crashed") : null; + + _workerChannel.Shutdown(workerException); + + for (int i = 0; i < numberOfInvocations; i++) + { + Assert.False(_workerChannel.IsExecutingInvocation(invocationIds[i].ToString()), + $"Invocation {i} should no longer be executing"); + + var resultSource = invocationContexts[i].ResultSource; + Assert.Equal(TaskStatus.Faulted, resultSource.Task.Status); + Assert.IsType(resultSource.Task.Exception.InnerException); + } + } + private static IEnumerable GetTestFunctionsList(string runtime, bool addWorkerProperties = false) { return GetTestFunctionsList(runtime, numberOfFunctions: 2, addWorkerProperties); diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannel.cs b/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannel.cs index c1244380d4..dcda85f7b1 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannel.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannel.cs @@ -150,11 +150,10 @@ public bool IsExecutingInvocation(string invocationId) return _executingInvocations.Contains(invocationId); } - public bool TryFailExecutions(Exception exception) + public void Shutdown(Exception exception) { // Executions are no longer executing _executingInvocations = new HashSet(); - return true; } public void SendWorkerMetadataRequest() diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannelManager.cs b/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannelManager.cs index fe3a3187e9..de788f75b4 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannelManager.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannelManager.cs @@ -96,7 +96,7 @@ public async Task ShutdownChannelIfExistsAsync(string language, string wor IRpcWorkerChannel channel = await value?.Task; if (channel != null) { - channel.TryFailExecutions(workerException); + channel.Shutdown(workerException); (channel as IDisposable)?.Dispose(); rpcWorkerChannels.Remove(workerId); return true;