Skip to content

Added possibilty to obtain HttpStatusCode from IOperationResult.ContextData #8481

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/StrawberryShake/Client/src/Transport.Http/HttpConnection.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Immutable;
using System.Text;
using System.Text.Json;
using HotChocolate.Transport.Http;
Expand All @@ -10,6 +11,7 @@ public class HttpConnection : IHttpConnection
{
public const string RequestUri = "StrawberryShake.Transport.Http.HttpConnection.RequestUri";
public const string HttpClient = "StrawberryShake.Transport.Http.HttpConnection.HttpClient";
public const string HttpStatusCodeCaptureKey = "StrawberryShake.Transport.Http.HttpConnection.HttpStatusCodeCaptureKey";

private readonly Func<OperationRequest, object?, HttpClient> _createClient;
private readonly object? _clientFactoryState;
Expand Down Expand Up @@ -37,7 +39,7 @@ public IAsyncEnumerable<Response<JsonDocument>> ExecuteAsync(OperationRequest re
return Create(
CreateClient(request),
CreateHttpRequest(request),
CreateResponse);
context => CreateResponse(context, request));
}

protected virtual HttpClient CreateClient(OperationRequest request)
Expand Down Expand Up @@ -105,6 +107,34 @@ protected virtual GraphQLHttpRequest CreateHttpRequest(OperationRequest request)
return new GraphQLHttpRequest(operation) { EnableFileUploads = hasFiles };
}

private Response<JsonDocument> CreateResponse(
HttpResponseContext responseContext,
OperationRequest operationRequest)
{
// Capture the status code of the response if requested
if (operationRequest.ContextData.TryGetValue(HttpStatusCodeCaptureKey, out var key)
&& key is string stringKey
&& !string.IsNullOrEmpty(stringKey))
{
var responseContextData = responseContext.ContextData;
var mutableResponseContextData =
responseContextData as IImmutableDictionary<string, object?> ??
responseContextData?.ToImmutableDictionary() ??
ImmutableDictionary<string, object?>.Empty;

responseContext = new HttpResponseContext(
responseContext.Response,
responseContext.Body,
responseContext.Exception,
responseContext.IsPatch,
responseContext.HasNext,
responseContext.Extensions,
mutableResponseContextData.Add(stringKey, responseContext.Response.StatusCode));
}

return CreateResponse(responseContext);
}

protected virtual Response<JsonDocument> CreateResponse(
HttpResponseContext responseContext)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ private static IReadOnlyList<GeneratorResult> GenerateCSharpDocuments(
settings.NoStore,
settings.InputRecords,
settings.EntityRecords,
settings.RazorComponents);
settings.RazorComponents,
settings.GenerateWithHttpStatusCodeCaptureMethod);

var results = new List<GeneratorResult>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,10 @@ public class CSharpGeneratorSettings
subscription: TransportType.WebSocket)

];

/// <summary>
/// A value indicating whether to generate the
/// <c>WithHttpStatusCodeCapture(string key = "HttpStatusCode")</c>-method.
/// </summary>
public bool GenerateWithHttpStatusCodeCaptureMethod { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ protected override void Generate(

var serializerAssignments = UseInjectedSerializers(descriptor, privateConstructorBuilder);

foreach (var method in CreateWitherMethods(descriptor, serializerAssignments))
foreach (var method in CreateWitherMethods(descriptor, serializerAssignments, settings))
{
classBuilder.AddMethod(method);
}
Expand Down Expand Up @@ -387,9 +387,9 @@ private MethodBuilder CreateExecuteMethod(
.AddArgument("false")));
}

private static IEnumerable<MethodBuilder> CreateWitherMethods(
OperationDescriptor operationDescriptor,
string serializerAssignments)
private static IEnumerable<MethodBuilder> CreateWitherMethods(OperationDescriptor operationDescriptor,
string serializerAssignments,
CSharpSyntaxGeneratorSettings settings)
{
var withMethod = MethodBuilder
.New()
Expand Down Expand Up @@ -439,6 +439,26 @@ private static IEnumerable<MethodBuilder> CreateWitherMethods(
string.Format(
"return With(r => r.ContextData[\"{0}\"] = httpClient);" + Environment.NewLine,
"StrawberryShake.Transport.Http.HttpConnection.HttpClient")));

if (settings.GenerateWithHttpStatusCodeCaptureMethod)
{
var withHttpStatusCodeCaptureUriMethod = MethodBuilder
.New()
.SetPublic()
.SetReturnType(operationDescriptor.InterfaceType.ToString())
.SetName("WithHttpStatusCodeCapture");

withHttpStatusCodeCaptureUriMethod
.AddParameter("key")
.SetDefault("\"HttpStatusCode\"")
.SetType(TypeNames.String);

yield return withHttpStatusCodeCaptureUriMethod
.AddCode(CodeInlineBuilder.From(
string.Format(
"return With(r => r.ContextData[\"{0}\"] = key);" + Environment.NewLine,
"StrawberryShake.Transport.Http.HttpConnection.HttpStatusCodeCaptureKey")));
}
}

private static MethodBuilder CreateRequestVariablesMethod(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ protected override void Generate(OperationDescriptor descriptor,

if (descriptor is not SubscriptionOperationDescriptor)
{
foreach (var method in CreateWitherMethods(descriptor))
foreach (var method in CreateWitherMethods(descriptor, settings))
{
interfaceBuilder.AddMethod(method);
}
Expand Down Expand Up @@ -111,8 +111,8 @@ private static MethodBuilder CreateExecuteMethod(
return executeMethod;
}

private static IEnumerable<MethodBuilder> CreateWitherMethods(
OperationDescriptor operationDescriptor)
private static IEnumerable<MethodBuilder> CreateWitherMethods(OperationDescriptor operationDescriptor,
CSharpSyntaxGeneratorSettings settings)
{
var withMethod = MethodBuilder
.New()
Expand Down Expand Up @@ -149,5 +149,21 @@ private static IEnumerable<MethodBuilder> CreateWitherMethods(
.SetType("global::System.Net.Http.HttpClient");

yield return withHttpClientMethod;

if (settings.GenerateWithHttpStatusCodeCaptureMethod)
{
var withHttpStatusCodeCaptureUriMethod = MethodBuilder
.New()
.SetOnlyDeclaration()
.SetReturnType(operationDescriptor.InterfaceType.ToString())
.SetName("WithHttpStatusCodeCapture");

withHttpStatusCodeCaptureUriMethod
.AddParameter("key")
.SetDefault("\"HttpStatusCode\"")
.SetType(TypeNames.String);

yield return withHttpStatusCodeCaptureUriMethod;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ public CSharpSyntaxGeneratorSettings(
bool noStore,
bool inputRecords,
bool entityRecords,
bool razorComponents)
bool razorComponents,
bool generateWithHttpStatusCodeCaptureMethod)
{
AccessModifier = accessModifier;
NoStore = noStore;
InputRecords = inputRecords;
EntityRecords = entityRecords;
RazorComponents = razorComponents;
GenerateWithHttpStatusCodeCaptureMethod = generateWithHttpStatusCodeCaptureMethod;
}

/// <summary>
Expand All @@ -46,4 +48,10 @@ public CSharpSyntaxGeneratorSettings(
/// Generate Razor components.
/// </summary>
public bool RazorComponents { get; }

/// <summary>
/// A value indicating whether to generate the
/// <c>WithHttpStatusCodeCapture(string key = "HttpStatusCode")</c>-method.
/// </summary>
public bool GenerateWithHttpStatusCodeCaptureMethod { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ public static void AssertResult(
NoStore = settings.NoStore,
InputRecords = settings.InputRecords,
EntityRecords = settings.EntityRecords,
RazorComponents = settings.RazorComponents
RazorComponents = settings.RazorComponents,
GenerateWithHttpStatusCodeCaptureMethod = settings.GenerateWithHttpStatusCodeCaptureMethod
});

Assert.False(
Expand Down Expand Up @@ -198,6 +199,7 @@ public static AssertSettings CreateIntegrationTest(
TransportProfile[]? profiles = null,
AccessModifier accessModifier = AccessModifier.Public,
bool noStore = false,
bool generateWithHttpStatusCodeCaptureMethod = false,
[CallerMemberName] string? testName = null)
{
var snapshotFullName = Snapshot.FullName();
Expand Down Expand Up @@ -226,6 +228,7 @@ public static AssertSettings CreateIntegrationTest(
testName + "Test.Client.cs"),
RequestStrategy = requestStrategy,
NoStore = noStore,
GenerateWithHttpStatusCodeCaptureMethod = generateWithHttpStatusCodeCaptureMethod,
Profiles = (profiles ??
[
TransportProfile.Default
Expand Down Expand Up @@ -282,5 +285,7 @@ public class AssertSettings

public RequestStrategyGen RequestStrategy { get; set; } =
RequestStrategyGen.Default;

public bool GenerateWithHttpStatusCodeCaptureMethod { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,11 @@ private GetHeroQuery(global::StrawberryShake.IOperationExecutor<IGetHeroResult>
return With(r => r.ContextData["StrawberryShake.Transport.Http.HttpConnection.HttpClient"] = httpClient);
}

public global::StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsGetHero.IGetHeroQuery WithHttpStatusCodeCapture(global::System.String key = "HttpStatusCode")
{
return With(r => r.ContextData["StrawberryShake.Transport.Http.HttpConnection.HttpStatusCodeCaptureKey"] = key);
}

public async global::System.Threading.Tasks.Task<global::StrawberryShake.IOperationResult<IGetHeroResult>> ExecuteAsync(global::System.Threading.CancellationToken cancellationToken = default)
{
var request = CreateRequest();
Expand Down Expand Up @@ -575,6 +580,7 @@ public partial interface IGetHeroQuery : global::StrawberryShake.IOperationReque
global::StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsGetHero.IGetHeroQuery With(global::System.Action<global::StrawberryShake.OperationRequest> configure);
global::StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsGetHero.IGetHeroQuery WithRequestUri(global::System.Uri requestUri);
global::StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsGetHero.IGetHeroQuery WithHttpClient(global::System.Net.Http.HttpClient httpClient);
global::StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsGetHero.IGetHeroQuery WithHttpStatusCodeCapture(global::System.String key = "HttpStatusCode");
global::System.Threading.Tasks.Task<global::StrawberryShake.IOperationResult<IGetHeroResult>> ExecuteAsync(global::System.Threading.CancellationToken cancellationToken = default);
global::System.IObservable<global::StrawberryShake.IOperationResult<IGetHeroResult>> Watch(global::StrawberryShake.ExecutionStrategy? strategy = null);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using System.Net;
using HotChocolate.AspNetCore.Tests.Utilities;
using Microsoft.Data.Sqlite;
using Microsoft.Extensions.DependencyInjection;
using StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsGetHero.State;
using StrawberryShake.Transport.WebSockets;
using StrawberryShake.Persistence.SQLite;
using StrawberryShake.Transport.Http;

namespace StrawberryShake.CodeGeneration.CSharp.Integration.StarWarsGetHero;

Expand Down Expand Up @@ -40,6 +42,41 @@ public async Task Execute_StarWarsGetHero_Test()
Assert.Equal("R2-D2", result.Data!.Hero!.Name);
}

[Theory]
[InlineData(HttpStatusCode.OK)]
[InlineData(HttpStatusCode.Forbidden)]
public async Task Execute_StarWarsGetHero_ShouldCaptureExpectedStatusCode(HttpStatusCode expectedStatusCode)
{
// Arrange
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20));
using var host = TestServerHelper.CreateServer(_ => { }, out var port);

var services = new ServiceCollection();

services
.AddStarWarsGetHeroClient()
.ConfigureHttpClient(client =>
{
if (expectedStatusCode == HttpStatusCode.Forbidden)
{
client.DefaultRequestHeaders.Add("sendErrorStatusCode", "1");
}
});

var serviceProvider = services.BuildServiceProvider();
var client = serviceProvider.GetRequiredService<StarWarsGetHeroClient>();

// Act
var result = await client.GetHero
.WithRequestUri(new Uri($"http://localhost:{port}/graphql"))
.WithHttpStatusCodeCapture(key: "foo")
.ExecuteAsync(cts.Token);

// Assert
Assert.Equal("R2-D2", result.Data!.Hero!.Name);
Assert.Equal(expectedStatusCode, result.ContextData["foo"]);
}

[Fact]
public async Task Watch_StarWarsGetHero_Test()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class TestGeneration
[Fact]
public void StarWarsGetHero() =>
AssertStarWarsResult(
CreateIntegrationTest(),
CreateIntegrationTest(generateWithHttpStatusCodeCaptureMethod: true),
@"query GetHero {
hero(episode: NEW_HOPE) {
name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ public static void AssertResult(
NoStore = settings.NoStore,
InputRecords = settings.InputRecords,
EntityRecords = settings.EntityRecords,
RazorComponents = settings.RazorComponents
RazorComponents = settings.RazorComponents,
GenerateWithHttpStatusCodeCaptureMethod = settings.GenerateWithHttpStatusCodeCaptureMethod
});

Assert.False(
Expand Down Expand Up @@ -276,5 +277,7 @@ public class AssertSettings

public RequestStrategyGen RequestStrategy { get; set; } =
RequestStrategyGen.Default;

public bool GenerateWithHttpStatusCodeCaptureMethod { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ public class StrawberryShakeSettings
/// </summary>
public bool? RazorComponents { get; set; }

/// <summary>
/// Defines if the generator shall generate the
/// <c>WithHttpStatusCodeCapture(string key = "HttpStatusCode")</c>-method
/// that can be used to capture the HTTP status code.
/// </summary>
public bool? GenerateWithHttpStatusCodeCaptureMethod { get; set; }

/// <summary>
/// Gets the record generator settings.
/// </summary>
Expand Down
14 changes: 12 additions & 2 deletions src/StrawberryShake/Tooling/src/dotnet-graphql/GenerateCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public static void Build(CommandLineApplication generate)
"Console output as JSON.",
CommandOptionType.NoValue);

var generateWithHttpStatusCodeCaptureMethodArg = generate.Option(
"--generateWithHttpStatusCodeCaptureMethod",
"Generate the 'WithHttpStatusCodeCapture'-method.",
CommandOptionType.NoValue);

generate.OnExecuteAsync(
ct =>
{
Expand All @@ -86,7 +91,8 @@ public static void Build(CommandLineApplication generate)
outputDirArg.Value(),
strategy,
operationOutputDir,
relayFormatArg.HasValue());
relayFormatArg.HasValue(),
generateWithHttpStatusCodeCaptureMethodArg.HasValue());
var handler = CommandTools.CreateHandler<GenerateCommandHandler>(jsonArg);
return handler.ExecuteAsync(arguments, ct);
});
Expand Down Expand Up @@ -321,7 +327,8 @@ public GenerateCommandArguments(
string? outputDir,
RequestStrategy strategy,
string? operationOutputDir,
bool relayFormat)
bool relayFormat,
bool generateWithHttpStatusCodeCaptureMethod)
{
Path = path;
RootNamespace = rootNamespace;
Expand All @@ -334,6 +341,7 @@ public GenerateCommandArguments(
Strategy = strategy;
RelayFormat = relayFormat;
OperationOutputDir = operationOutputDir;
GenerateWithHttpStatusCodeCaptureMethod = generateWithHttpStatusCodeCaptureMethod;

if (operationOutputDir is null && outputDir is not null)
{
Expand Down Expand Up @@ -362,5 +370,7 @@ public GenerateCommandArguments(
public string? OperationOutputDir { get; }

public bool RelayFormat { get; }

public bool GenerateWithHttpStatusCodeCaptureMethod { get; }
}
}
Loading