diff --git a/src/Http/Http.Extensions/src/DefaultProblemDetailsWriter.cs b/src/Http/Http.Extensions/src/DefaultProblemDetailsWriter.cs index 4d47a1e54f06..a71924da8e9c 100644 --- a/src/Http/Http.Extensions/src/DefaultProblemDetailsWriter.cs +++ b/src/Http/Http.Extensions/src/DefaultProblemDetailsWriter.cs @@ -56,15 +56,17 @@ public ValueTask WriteAsync(ProblemDetailsContext context) ProblemDetailsDefaults.Apply(context.ProblemDetails, httpContext.Response.StatusCode); var traceId = Activity.Current?.Id ?? httpContext.TraceIdentifier; - context.ProblemDetails.Extensions["traceId"] = traceId; + + var traceIdKeyName = _serializerOptions.PropertyNamingPolicy?.ConvertName("traceId") ?? "traceId"; + context.ProblemDetails.Extensions[traceIdKeyName] = traceId; _options.CustomizeProblemDetails?.Invoke(context); var problemDetailsType = context.ProblemDetails.GetType(); return new ValueTask(httpContext.Response.WriteAsJsonAsync( - context.ProblemDetails, - _serializerOptions.GetTypeInfo(problemDetailsType), - contentType: "application/problem+json")); + context.ProblemDetails, + _serializerOptions.GetTypeInfo(problemDetailsType), + contentType: "application/problem+json")); } } diff --git a/src/Http/Http.Extensions/test/ProblemDetailsDefaultWriterTest.cs b/src/Http/Http.Extensions/test/ProblemDetailsDefaultWriterTest.cs index d57a440a4593..130d6805a2c0 100644 --- a/src/Http/Http.Extensions/test/ProblemDetailsDefaultWriterTest.cs +++ b/src/Http/Http.Extensions/test/ProblemDetailsDefaultWriterTest.cs @@ -494,6 +494,7 @@ public async Task WriteAsync_AddExtensions() Assert.Equal("traceId", extension.Key); Assert.Equal(expectedTraceId, extension.Value.ToString()); }); + } [Fact] @@ -503,6 +504,10 @@ public async Task WriteAsync_AddExtensions_WithJsonContext() var options = new JsonOptions(); options.SerializerOptions.TypeInfoResolver = JsonTypeInfoResolver.Combine(CustomProblemDetailsContext.Default, ProblemDetailsJsonContext.Default); + var mockNamingPolicy = new Mock(); + mockNamingPolicy.Setup(policy => policy.ConvertName("traceId")).Returns("custom_traceId"); + options.SerializerOptions.PropertyNamingPolicy = mockNamingPolicy.Object; + var writer = GetWriter(jsonOptions: options); var stream = new MemoryStream(); var context = CreateContext(stream); @@ -517,10 +522,10 @@ public async Task WriteAsync_AddExtensions_WithJsonContext() ProblemDetails = expectedProblem }; - //Act + // Act await writer.WriteAsync(problemDetailsContext); - //Assert + // Assert stream.Position = 0; var problemDetails = await JsonSerializer.DeserializeAsync(stream, options.SerializerOptions); @@ -537,7 +542,7 @@ public async Task WriteAsync_AddExtensions_WithJsonContext() }, (extension) => { - Assert.Equal("traceId", extension.Key); + Assert.Equal("custom_traceId", extension.Key); // Updated to reflect the custom naming policy Assert.Equal(expectedTraceId, extension.Value.ToString()); }); } @@ -670,6 +675,146 @@ public void CanWrite_ReturnsFalse_WhenJsonNotAccepted(string contentType) Assert.False(result); } + [Fact] + public async Task WriteAsync_Respects_CustomNamingPolicy_ForTraceId() + { + // Arrange + var writer = GetWriter(); + var stream = new MemoryStream(); + var context = CreateContext(stream); + var expectedTraceId = Activity.Current?.Id ?? context.TraceIdentifier; + + var mockNamingPolicy = new Mock(); + mockNamingPolicy.Setup(policy => policy.ConvertName("traceId")).Returns("custom_traceId"); + + var serializerOptions = new JsonSerializerOptions { PropertyNamingPolicy = mockNamingPolicy.Object }; + SetSerializerOptions(writer, serializerOptions); + + var expectedProblem = new ProblemDetails() + { + Status = StatusCodes.Status500InternalServerError + }; + + var problemDetailsContext = new ProblemDetailsContext() + { + HttpContext = context, + ProblemDetails = expectedProblem + }; + + // Act + await writer.WriteAsync(problemDetailsContext); + + // Assert + stream.Position = 0; + var problemDetails = await JsonSerializer.DeserializeAsync(stream, serializerOptions); + Assert.NotNull(problemDetails); + Assert.Equal(expectedTraceId, problemDetails.Extensions["custom_traceId"].ToString()); + Assert.DoesNotContain("traceId", problemDetails.Extensions.Keys); + } + + [Fact] + public async Task WriteAsync_FallsBack_WhenNamingPolicyReturnsNull() + { + // Arrange + var writer = GetWriter(); + var stream = new MemoryStream(); + var context = CreateContext(stream); + var expectedTraceId = Activity.Current?.Id ?? context.TraceIdentifier; + + var mockNamingPolicy = new Mock(); + mockNamingPolicy.Setup(policy => policy.ConvertName("traceId")).Returns((string)null); + + var serializerOptions = new JsonSerializerOptions { PropertyNamingPolicy = mockNamingPolicy.Object }; + SetSerializerOptions(writer, serializerOptions); + + var expectedProblem = new ProblemDetails() + { + Status = StatusCodes.Status500InternalServerError + }; + + var problemDetailsContext = new ProblemDetailsContext() + { + HttpContext = context, + ProblemDetails = expectedProblem + }; + + // Act + await writer.WriteAsync(problemDetailsContext); + + // Assert + stream.Position = 0; + var problemDetails = await JsonSerializer.DeserializeAsync(stream, serializerOptions); + Assert.NotNull(problemDetails); + Assert.Equal(expectedTraceId, problemDetails.Extensions["traceId"].ToString()); + } + + [Fact] + public async Task WriteAsync_Respects_CustomNamingPolicy_ForValidationProblemDetails() + { + // Arrange + var writer = GetWriter(); + var stream = new MemoryStream(); + var context = CreateContext(stream); + var expectedTraceId = Activity.Current?.Id ?? context.TraceIdentifier; + + var mockNamingPolicy = new Mock(); + mockNamingPolicy.Setup(policy => policy.ConvertName("traceId")).Returns("custom_traceId"); + + var serializerOptions = new JsonSerializerOptions { PropertyNamingPolicy = mockNamingPolicy.Object }; + SetSerializerOptions(writer, serializerOptions); + + var expectedProblem = new ValidationProblemDetails() + { + Errors = new Dictionary { { "sample", new[] { "error-message" } } } + }; + + var problemDetailsContext = new ProblemDetailsContext() + { + HttpContext = context, + ProblemDetails = expectedProblem + }; + + // Act + await writer.WriteAsync(problemDetailsContext); + + // Assert + stream.Position = 0; + var problemDetails = await JsonSerializer.DeserializeAsync(stream, serializerOptions); + Assert.NotNull(problemDetails); + Assert.Equal(expectedTraceId, problemDetails.Extensions["custom_traceId"].ToString()); + Assert.DoesNotContain("traceId", problemDetails.Extensions.Keys); + } + + [Fact] + public async Task WriteAsync_Uses_DefaultTraceIdKey_WhenNoNamingPolicy() + { + // Arrange + var writer = GetWriter(); + var stream = new MemoryStream(); + var context = CreateContext(stream); + var expectedTraceId = Activity.Current?.Id ?? context.TraceIdentifier; + + var expectedProblem = new ProblemDetails() + { + Status = StatusCodes.Status500InternalServerError + }; + + var problemDetailsContext = new ProblemDetailsContext() + { + HttpContext = context, + ProblemDetails = expectedProblem + }; + + // Act + await writer.WriteAsync(problemDetailsContext); + + // Assert + stream.Position = 0; + var problemDetails = await JsonSerializer.DeserializeAsync(stream, SerializerOptions); + Assert.NotNull(problemDetails); + Assert.Equal(expectedTraceId, problemDetails.Extensions["traceId"].ToString()); + } + private static HttpContext CreateContext( Stream body, int statusCode = StatusCodes.Status400BadRequest,