diff --git a/src/WebJobs.Script/Http/DefaultHttpProxyService.cs b/src/WebJobs.Script/Http/DefaultHttpProxyService.cs index d7e6828024..7466147372 100644 --- a/src/WebJobs.Script/Http/DefaultHttpProxyService.cs +++ b/src/WebJobs.Script/Http/DefaultHttpProxyService.cs @@ -21,6 +21,7 @@ internal class DefaultHttpProxyService : IHttpProxyService, IDisposable private readonly HttpMessageInvoker _messageInvoker; private readonly ForwarderRequestConfig _forwarderRequestConfig; private readonly ILogger _logger; + private readonly HttpTransformer _httpTransformer; public DefaultHttpProxyService(IHttpForwarder httpForwarder, ILogger logger) { @@ -39,6 +40,8 @@ public DefaultHttpProxyService(IHttpForwarder httpForwarder, ILogger SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { + TaskCompletionSource resultSource = null; + if (request.Options.TryGetValue(ScriptConstants.HttpProxyScriptInvocationContext, out ScriptInvocationContext scriptInvocationContext)) + { + resultSource = scriptInvocationContext.ResultSource; + } + var currentDelay = InitialDelay; for (int attemptCount = 1; attemptCount <= MaxRetries; attemptCount++) { try { + if (resultSource is not null && resultSource.Task.IsFaulted) + { + throw resultSource.Task.Exception.InnerException ?? + new Exception($"The function invocation tied to this HTTP request failed. Invocation ID: {scriptInvocationContext.ExecutionContext.InvocationId}"); + } + return await base.SendAsync(request, cancellationToken); } catch (TaskCanceledException) when (cancellationToken.IsCancellationRequested) diff --git a/src/WebJobs.Script/Http/ScriptInvocationRequestTransformer.cs b/src/WebJobs.Script/Http/ScriptInvocationRequestTransformer.cs new file mode 100644 index 0000000000..5e2e9297e1 --- /dev/null +++ b/src/WebJobs.Script/Http/ScriptInvocationRequestTransformer.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Azure.WebJobs.Script.Description; +using Yarp.ReverseProxy.Forwarder; + +namespace Microsoft.Azure.WebJobs.Script.Http +{ + internal sealed class ScriptInvocationRequestTransformer : HttpTransformer + { + public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken) + { + // this preserves previous behavior (which called the default transformer) - base method is also called inside of here + await HttpTransformer.Default.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, cancellationToken); + + if (httpContext.Items.TryGetValue(ScriptConstants.HttpProxyScriptInvocationContext, out object result) + && result is ScriptInvocationContext scriptContext) + { + proxyRequest.Options.TryAdd(ScriptConstants.HttpProxyScriptInvocationContext, scriptContext); + } + } + } +} \ No newline at end of file diff --git a/src/WebJobs.Script/ScriptConstants.cs b/src/WebJobs.Script/ScriptConstants.cs index 857c5bd912..9946042e30 100644 --- a/src/WebJobs.Script/ScriptConstants.cs +++ b/src/WebJobs.Script/ScriptConstants.cs @@ -261,6 +261,7 @@ public static class ScriptConstants public static readonly string HttpProxyingEnabled = "HttpProxyingEnabled"; public static readonly string HttpProxyCorrelationHeader = "x-ms-invocation-id"; public static readonly string HttpProxyTask = "HttpProxyTask"; + public static readonly string HttpProxyScriptInvocationContext = "HttpProxyScriptInvocationContext"; public static readonly string OperationNameKey = "OperationName"; diff --git a/test/WebJobs.Script.Tests/HttpProxyService/DefaultHttpProxyServiceTests.cs b/test/WebJobs.Script.Tests/HttpProxyService/DefaultHttpProxyServiceTests.cs index 81313edb38..b265a7374c 100644 --- a/test/WebJobs.Script.Tests/HttpProxyService/DefaultHttpProxyServiceTests.cs +++ b/test/WebJobs.Script.Tests/HttpProxyService/DefaultHttpProxyServiceTests.cs @@ -12,7 +12,7 @@ using Xunit; using Yarp.ReverseProxy.Forwarder; -namespace Microsoft.Azure.WebJobs.Script.Tests +namespace Microsoft.Azure.WebJobs.Script.Tests.Http { public class DefaultHttpProxyServiceTests { diff --git a/test/WebJobs.Script.Tests/HttpProxyService/ScriptInvocationRequestTransformerTests.cs b/test/WebJobs.Script.Tests/HttpProxyService/ScriptInvocationRequestTransformerTests.cs new file mode 100644 index 0000000000..d6e1f5e1fe --- /dev/null +++ b/test/WebJobs.Script.Tests/HttpProxyService/ScriptInvocationRequestTransformerTests.cs @@ -0,0 +1,199 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Azure.WebJobs.Script.Description; +using Microsoft.Azure.WebJobs.Script.Http; +using Xunit; + +namespace Microsoft.Azure.WebJobs.Script.Tests.Http +{ + public class ScriptInvocationRequestTransformerTests + { + private readonly ScriptInvocationRequestTransformer _transformer; + + public ScriptInvocationRequestTransformerTests() + { + _transformer = new ScriptInvocationRequestTransformer(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task TransformRequestAsync_IncludesXForwardedHeaders(bool includeScriptInvocationContext) + { + var httpContext = new DefaultHttpContext(); + httpContext.Request.Scheme = "https"; + httpContext.Request.Host = new HostString("example.com", 443); + httpContext.Request.PathBase = "/api"; + httpContext.Request.Path = "/test"; + httpContext.Request.QueryString = new QueryString("?param=value"); + + var remoteAddress = "192.168.1.100"; + httpContext.Connection.RemoteIpAddress = System.Net.IPAddress.Parse(remoteAddress); + + if (includeScriptInvocationContext) + { + var scriptContext = new ScriptInvocationContext + { + FunctionMetadata = new FunctionMetadata { Name = "TestFunction" }, + ExecutionContext = new ExecutionContext { InvocationId = Guid.NewGuid() } + }; + + httpContext.Items[ScriptConstants.HttpProxyScriptInvocationContext] = scriptContext; + } + + var proxyRequest = new HttpRequestMessage(HttpMethod.Get, "http://localhost:7071/api/test"); + const string destinationPrefix = "http://localhost:7071"; + + await _transformer.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None); + + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-For"), "X-Forwarded-For header should be present"); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Host"), "X-Forwarded-Host header should be present"); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Proto"), "X-Forwarded-Proto header should be present"); + + var forwardedFor = proxyRequest.Headers.GetValues("X-Forwarded-For"); + Assert.Contains(remoteAddress, forwardedFor); + + var forwardedHost = proxyRequest.Headers.GetValues("X-Forwarded-Host"); + Assert.Contains("example.com:443", forwardedHost); + + var forwardedProto = proxyRequest.Headers.GetValues("X-Forwarded-Proto"); + Assert.Contains("https", forwardedProto); + } + + [Fact] + public async Task TransformRequestAsync_WithScriptInvocationContext_AddsContextToRequestOptions() + { + var httpContext = new DefaultHttpContext(); + httpContext.Request.Scheme = "http"; + httpContext.Request.Host = new HostString("localhost", 7071); + var remoteAddress = "192.168.1.100"; + httpContext.Connection.RemoteIpAddress = System.Net.IPAddress.Parse(remoteAddress); + + var scriptContext = new ScriptInvocationContext + { + FunctionMetadata = new FunctionMetadata { Name = "TestFunction" }, + ExecutionContext = new ExecutionContext { InvocationId = Guid.NewGuid() } + }; + + httpContext.Items[ScriptConstants.HttpProxyScriptInvocationContext] = scriptContext; + + var proxyRequest = new HttpRequestMessage(HttpMethod.Get, "http://localhost:7071/api/test"); + const string destinationPrefix = "http://localhost:7071"; + + await _transformer.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None); + + Assert.True(proxyRequest.Options.TryGetValue(ScriptConstants.HttpProxyScriptInvocationContext, out ScriptInvocationContext contextValue)); + Assert.Equal(scriptContext.ExecutionContext.InvocationId, contextValue.ExecutionContext.InvocationId); + + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-For")); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Host")); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Proto")); + } + + [Fact] + public async Task TransformRequestAsync_PreservesExistingXForwardedHeaders() + { + var httpContext = new DefaultHttpContext(); + httpContext.Request.Scheme = "https"; + httpContext.Request.Host = new HostString("proxy.example.com"); + var requestRemoteAddress = "172.16.0.1"; + httpContext.Connection.RemoteIpAddress = System.Net.IPAddress.Parse(requestRemoteAddress); + + // Add existing X-Forwarded headers to simulate request through multiple proxies + var originalFor = "203.0.113.195," + requestRemoteAddress; + var originalHost = "proxy.example.com"; + var originalProto = "https"; + httpContext.Request.Headers["X-Forwarded-For"] = originalFor; + httpContext.Request.Headers["X-Forwarded-Host"] = originalHost; + httpContext.Request.Headers["X-Forwarded-Proto"] = originalProto; + + var proxyRequest = new HttpRequestMessage(HttpMethod.Get, "http://localhost:7071/api/test"); + const string destinationPrefix = "http://localhost:7071"; + + await _transformer.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None); + + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-For")); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Host")); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Proto")); + + var forwardedFor = proxyRequest.Headers.GetValues("X-Forwarded-For"); + Assert.Contains(requestRemoteAddress, forwardedFor); + + var forwardedHost = proxyRequest.Headers.GetValues("X-Forwarded-Host"); + Assert.Contains(originalHost, forwardedHost); + + var forwardedProto = proxyRequest.Headers.GetValues("X-Forwarded-Proto"); + Assert.Contains(originalProto, forwardedProto); + } + + [Fact] + public async Task TransformRequestAsync_PreservesStandardRequestHeaders() + { + var httpContext = new DefaultHttpContext(); + httpContext.Request.Scheme = "https"; + httpContext.Request.Host = new HostString("example.com"); + httpContext.Request.Path = "/api/test"; + var remoteAddress = "192.168.1.100"; + httpContext.Connection.RemoteIpAddress = System.Net.IPAddress.Parse(remoteAddress); + + // Add various standard headers that should be preserved + httpContext.Request.Headers["Authorization"] = "Bearer token123"; + httpContext.Request.Headers["User-Agent"] = "TestClient/1.0"; + httpContext.Request.Headers["Accept"] = "application/json"; + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["X-Custom-Header"] = "custom-value"; + httpContext.Request.Headers["Cache-Control"] = "no-cache"; + httpContext.Request.Headers["Accept-Encoding"] = "gzip, deflate"; + + var proxyRequest = new HttpRequestMessage(HttpMethod.Post, "http://localhost:7071/api/test"); + const string destinationPrefix = "http://localhost:7071"; + + await _transformer.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None); + + // Verify that standard headers are preserved + Assert.True(proxyRequest.Headers.Contains("Authorization"), "Authorization header should be preserved"); + Assert.True(proxyRequest.Headers.Contains("User-Agent"), "User-Agent header should be preserved"); + Assert.True(proxyRequest.Headers.Contains("Accept"), "Accept header should be preserved"); + Assert.True(proxyRequest.Headers.Contains("X-Custom-Header"), "Custom headers should be preserved"); + Assert.True(proxyRequest.Headers.Contains("Cache-Control"), "Cache-Control header should be preserved"); + Assert.True(proxyRequest.Headers.Contains("Accept-Encoding"), "Accept-Encoding header should be preserved"); + + // Verify header values + var authHeader = proxyRequest.Headers.GetValues("Authorization"); + Assert.Contains("Bearer token123", authHeader); + + var userAgentHeader = proxyRequest.Headers.GetValues("User-Agent"); + Assert.Contains("TestClient/1.0", userAgentHeader); + + var acceptHeader = proxyRequest.Headers.GetValues("Accept"); + Assert.Contains("application/json", acceptHeader); + + var customHeader = proxyRequest.Headers.GetValues("X-Custom-Header"); + Assert.Contains("custom-value", customHeader); + + var cacheControlHeader = proxyRequest.Headers.GetValues("Cache-Control"); + Assert.Contains("no-cache", cacheControlHeader); + + var acceptEncodingHeader = proxyRequest.Headers.GetValues("Accept-Encoding"); + Assert.Contains("gzip", acceptEncodingHeader); + Assert.Contains("deflate", acceptEncodingHeader); + + // Verify that Content-Type is properly handled for the request content + if (proxyRequest.Content != null) + { + Assert.Equal("application/json", proxyRequest.Content.Headers.ContentType?.MediaType); + } + + // Also verify X-Forwarded headers are still added + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-For")); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Host")); + Assert.True(proxyRequest.Headers.Contains("X-Forwarded-Proto")); + } + } +} \ No newline at end of file diff --git a/test/WebJobs.Script.Tests/Workers/RetryProxyHandlerTests.cs b/test/WebJobs.Script.Tests/Workers/RetryProxyHandlerTests.cs index 1905c27423..82039f1004 100644 --- a/test/WebJobs.Script.Tests/Workers/RetryProxyHandlerTests.cs +++ b/test/WebJobs.Script.Tests/Workers/RetryProxyHandlerTests.cs @@ -1,10 +1,13 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the MIT License. See License.txt in the project root for license information. +using System; +using System.Collections.Generic; using System.Net.Http; using System.Reflection; using System.Threading; using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Script.Description; using Microsoft.Azure.WebJobs.Script.Http; using Microsoft.Extensions.Logging.Abstractions; using Xunit; @@ -32,6 +35,97 @@ public async Task SendAsync_RetriesToMax() Assert.Equal(RetryProxyHandler.MaxRetries, inner.Attempts); } + [Fact] + public async Task SendAsync_StopsRetriesWhenScriptInvocationResultIsFaulted() + { + var inner = new TestHandler(); + var handler = new RetryProxyHandler(inner, NullLogger.Instance); + var request = new HttpRequestMessage(); + + // Create a faulted TaskCompletionSource for ScriptInvocationResult + var faultedResultSource = new TaskCompletionSource(); + var invocationException = new InvalidOperationException("Function invocation failed"); + faultedResultSource.SetException(invocationException); + + // Create ScriptInvocationContext with faulted result source + var scriptInvocationContext = new ScriptInvocationContext + { + ExecutionContext = new ExecutionContext + { + InvocationId = Guid.NewGuid() + }, + ResultSource = faultedResultSource + }; + + // Add the context to the request options + request.Options.TryAdd(ScriptConstants.HttpProxyScriptInvocationContext, scriptInvocationContext); + + var response = typeof(RetryProxyHandler)! + .GetMethod("SendAsync", BindingFlags.NonPublic | BindingFlags.Instance)! + .Invoke(handler, new object[] { request, CancellationToken.None }) + as Task; + + var result = await response.ContinueWith(t => t); + + // Verify that the task is faulted due to the ScriptInvocationResult being faulted + Assert.True(result.IsFaulted); + Assert.True(result.Exception.InnerException is HttpRequestException); + Assert.Contains("The function invocation tied to this HTTP request failed", result.Exception.InnerException.Message); + Assert.Contains(scriptInvocationContext.ExecutionContext.InvocationId.ToString(), result.Exception.InnerException.Message); + + // Verify that the inner exception contains the original invocation exception + Assert.NotNull(result.Exception.InnerException.InnerException); + Assert.Contains(invocationException, GetAllInnerExceptions(result.Exception.InnerException.InnerException)); + + // Verify that no retries were attempted since the result source was already faulted + Assert.Equal(0, inner.Attempts); + } + + [Fact] + public async Task SendAsync_RetriesNormallyWhenScriptInvocationResultIsNotFaulted() + { + var inner = new TestHandler(); + var handler = new RetryProxyHandler(inner, NullLogger.Instance); + var request = new HttpRequestMessage(); + + // Create a non-faulted TaskCompletionSource for ScriptInvocationResult + var resultSource = new TaskCompletionSource(); + + // Create ScriptInvocationContext with non-faulted result source + var scriptInvocationContext = new ScriptInvocationContext + { + ExecutionContext = new ExecutionContext + { + InvocationId = Guid.NewGuid() + }, + ResultSource = resultSource + }; + + // Add the context to the request options + request.Options.TryAdd(ScriptConstants.HttpProxyScriptInvocationContext, scriptInvocationContext); + + var response = typeof(RetryProxyHandler)! + .GetMethod("SendAsync", BindingFlags.NonPublic | BindingFlags.Instance)! + .Invoke(handler, new object[] { request, CancellationToken.None }) + as Task; + + var result = await response.ContinueWith(t => t); + + // Verify that retries occurred normally since the result source was not faulted + Assert.True(result.IsFaulted); + Assert.True(result.Exception.InnerException is HttpRequestException); + Assert.Equal(RetryProxyHandler.MaxRetries, inner.Attempts); + } + + private static System.Collections.Generic.IEnumerable GetAllInnerExceptions(Exception exception) + { + while (exception != null) + { + yield return exception; + exception = exception.InnerException; + } + } + private class TestHandler : HttpMessageHandler { public int Attempts { get; set; }