Skip to content

Commit 836e0f0

Browse files
[release/7.0] gRPC JSON transcoding: Fix known type messages in querystring values (#46752)
* Fix support known type messages in querystring values for gRPC JSON transcoding More tests Update before rebase * Clean up --------- Co-authored-by: James Newton-King <[email protected]>
1 parent 9a94197 commit 836e0f0

File tree

14 files changed

+216
-19
lines changed

14 files changed

+216
-19
lines changed

src/Grpc/JsonTranscoding/src/Microsoft.AspNetCore.Grpc.JsonTranscoding/Internal/Json/DurationConverter.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Text.Json;
55
using Google.Protobuf;
66
using Google.Protobuf.WellKnownTypes;
7+
using Grpc.Shared;
78
using Type = System.Type;
89

910
namespace Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal.Json;

src/Grpc/JsonTranscoding/src/Microsoft.AspNetCore.Grpc.JsonTranscoding/Internal/Json/EnumConverter.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Runtime.CompilerServices;
66
using System.Text.Json;
77
using Google.Protobuf.Reflection;
8+
using Grpc.Shared;
89
using Type = System.Type;
910

1011
namespace Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal.Json;

src/Grpc/JsonTranscoding/src/Microsoft.AspNetCore.Grpc.JsonTranscoding/Internal/Json/FieldMaskConverter.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Text.Json;
77
using Google.Protobuf;
88
using Google.Protobuf.WellKnownTypes;
9+
using Grpc.Shared;
910
using Type = System.Type;
1011

1112
namespace Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal.Json;
@@ -43,7 +44,7 @@ public override void Write(Utf8JsonWriter writer, TMessage value, JsonSerializer
4344
var firstInvalid = paths.FirstOrDefault(p => !Legacy.IsPathValid(p));
4445
if (firstInvalid == null)
4546
{
46-
writer.WriteStringValue(string.Join(",", paths.Select(Legacy.ToJsonName)));
47+
writer.WriteStringValue(Legacy.GetFieldMaskText(paths));
4748
}
4849
else
4950
{

src/Grpc/JsonTranscoding/src/Microsoft.AspNetCore.Grpc.JsonTranscoding/Internal/Json/TimestampConverter.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Text.Json;
55
using Google.Protobuf;
66
using Google.Protobuf.WellKnownTypes;
7+
using Grpc.Shared;
78
using Type = System.Type;
89

910
namespace Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal.Json;

src/Grpc/JsonTranscoding/src/Microsoft.AspNetCore.Grpc.JsonTranscoding/Microsoft.AspNetCore.Grpc.JsonTranscoding.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
<Compile Include="..\Shared\X509CertificateHelpers.cs" Link="Internal\Shared\X509CertificateHelpers.cs" />
2424
<Compile Include="..\Shared\HttpRoutePattern.cs" Link="Internal\Shared\HttpRoutePattern.cs" />
2525
<Compile Include="..\Shared\HttpRoutePatternParser.cs" Link="Internal\Shared\HttpRoutePatternParser.cs" />
26+
<Compile Include="..\Shared\Legacy.cs" Link="Internal\Shared\Legacy.cs" />
2627
<Compile Include="$(SharedSourceRoot)ValueTaskExtensions\**\*.cs" LinkBase="Internal\Shared" />
2728

2829
<Reference Include="Google.Api.CommonProtos" />

src/Grpc/JsonTranscoding/src/Microsoft.AspNetCore.Grpc.Swagger/Internal/GrpcDataContractResolver.cs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Diagnostics.CodeAnalysis;
45
using System.Linq;
56
using System.Reflection;
67
using Google.Protobuf;
@@ -65,44 +66,62 @@ public DataContract GetDataContractForType(Type type)
6566
return _innerContractResolver.GetDataContractForType(type);
6667
}
6768

68-
private DataContract ConvertMessage(MessageDescriptor messageDescriptor)
69+
private bool TryCustomizeMessage(MessageDescriptor messageDescriptor, [NotNullWhen(true)] out DataContract? dataContract)
6970
{
71+
// The messages serialized here should be kept in sync with SericeDescriptionHelper.IsCustomType.
7072
if (ServiceDescriptorHelpers.IsWellKnownType(messageDescriptor))
7173
{
7274
if (ServiceDescriptorHelpers.IsWrapperType(messageDescriptor))
7375
{
7476
var field = messageDescriptor.Fields[Int32Value.ValueFieldNumber];
7577

76-
return _innerContractResolver.GetDataContractForType(MessageDescriptorHelpers.ResolveFieldType(field));
78+
dataContract = _innerContractResolver.GetDataContractForType(MessageDescriptorHelpers.ResolveFieldType(field));
79+
return true;
7780
}
7881
if (messageDescriptor.FullName == Timestamp.Descriptor.FullName ||
7982
messageDescriptor.FullName == Duration.Descriptor.FullName ||
8083
messageDescriptor.FullName == FieldMask.Descriptor.FullName)
8184
{
82-
return DataContract.ForPrimitive(messageDescriptor.ClrType, DataType.String, dataFormat: null);
85+
dataContract = DataContract.ForPrimitive(messageDescriptor.ClrType, DataType.String, dataFormat: null);
86+
return true;
8387
}
8488
if (messageDescriptor.FullName == Struct.Descriptor.FullName)
8589
{
86-
return DataContract.ForObject(messageDescriptor.ClrType, Array.Empty<DataProperty>(), extensionDataType: typeof(Value));
90+
dataContract = DataContract.ForObject(messageDescriptor.ClrType, Array.Empty<DataProperty>(), extensionDataType: typeof(Value));
91+
return true;
8792
}
8893
if (messageDescriptor.FullName == ListValue.Descriptor.FullName)
8994
{
90-
return DataContract.ForArray(messageDescriptor.ClrType, typeof(Value));
95+
dataContract = DataContract.ForArray(messageDescriptor.ClrType, typeof(Value));
96+
return true;
9197
}
9298
if (messageDescriptor.FullName == Value.Descriptor.FullName)
9399
{
94-
return DataContract.ForPrimitive(messageDescriptor.ClrType, DataType.Unknown, dataFormat: null);
100+
dataContract = DataContract.ForPrimitive(messageDescriptor.ClrType, DataType.Unknown, dataFormat: null);
101+
return true;
95102
}
96103
if (messageDescriptor.FullName == Any.Descriptor.FullName)
97104
{
98105
var anyProperties = new List<DataProperty>
99106
{
100107
new DataProperty("@type", typeof(string), isRequired: true)
101108
};
102-
return DataContract.ForObject(messageDescriptor.ClrType, anyProperties, extensionDataType: typeof(Value));
109+
dataContract = DataContract.ForObject(messageDescriptor.ClrType, anyProperties, extensionDataType: typeof(Value));
110+
return true;
103111
}
104112
}
105113

114+
dataContract = null;
115+
return false;
116+
}
117+
118+
private DataContract ConvertMessage(MessageDescriptor messageDescriptor)
119+
{
120+
if (TryCustomizeMessage(messageDescriptor, out var dataContract))
121+
{
122+
return dataContract;
123+
}
124+
106125
var properties = new List<DataProperty>();
107126

108127
foreach (var field in messageDescriptor.Fields.InFieldNumberOrder())

src/Grpc/JsonTranscoding/src/Microsoft.AspNetCore.Grpc.Swagger/Microsoft.AspNetCore.Grpc.Swagger.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<Project Sdk="Microsoft.NET.Sdk">
1+
<Project Sdk="Microsoft.NET.Sdk">
22
<PropertyGroup>
33
<Description>Swagger for gRPC ASP.NET Core</Description>
44
<PackageTags>gRPC RPC HTTP/2 REST Swagger OpenAPI</PackageTags>
@@ -12,6 +12,7 @@
1212
<Compile Include="..\Shared\ServiceDescriptorHelpers.cs" Link="Internal\Shared\ServiceDescriptorHelpers.cs" />
1313
<Compile Include="..\Shared\HttpRoutePattern.cs" Link="Internal\Shared\HttpRoutePattern.cs" />
1414
<Compile Include="..\Shared\HttpRoutePatternParser.cs" Link="Internal\Shared\HttpRoutePatternParser.cs" />
15+
<Compile Include="..\Shared\Legacy.cs" Link="Internal\Shared\Legacy.cs" />
1516

1617
<Reference Include="Microsoft.AspNetCore.Grpc.JsonTranscoding" />
1718
<Reference Include="Swashbuckle.AspNetCore" />
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
using Google.Protobuf.WellKnownTypes;
4343
using Type = System.Type;
4444

45-
namespace Microsoft.AspNetCore.Grpc.JsonTranscoding.Internal.Json;
45+
namespace Grpc.Shared;
4646

4747
// Source here is from https://github.com/protocolbuffers/protobuf
4848
// Most of this code will be replaced over time with optimized implementations.
@@ -237,6 +237,11 @@ public static string GetDurationText(int nanos, long seconds)
237237
}
238238
}
239239

240+
public static string GetFieldMaskText(IList<string> paths)
241+
{
242+
return string.Join(",", paths.Select(ToJsonName));
243+
}
244+
240245
/// <summary>
241246
/// Appends a number of nanoseconds to a StringBuilder. Either 0 digits are added (in which
242247
/// case no "." is appended), or 3 6 or 9 digits. This is internal for use in Timestamp as well

src/Grpc/JsonTranscoding/src/Shared/ServiceDescriptorHelpers.cs

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,39 @@ public static bool TryResolveDescriptors(MessageDescriptor messageDescriptor, IL
148148
throw new InvalidOperationException("String required to convert to enum.");
149149
}
150150
case FieldType.Message:
151-
if (IsWrapperType(descriptor.MessageType))
151+
if (IsWellKnownType(descriptor.MessageType))
152152
{
153-
if (value == null)
153+
if (IsWrapperType(descriptor.MessageType))
154154
{
155-
return null;
155+
if (value == null)
156+
{
157+
return null;
158+
}
159+
160+
return ConvertValue(value, descriptor.MessageType.FindFieldByName("value"));
161+
}
162+
else if (descriptor.MessageType.FullName == FieldMask.Descriptor.FullName)
163+
{
164+
return FieldMask.FromString((string)value!);
156165
}
166+
else if (descriptor.MessageType.FullName == Duration.Descriptor.FullName)
167+
{
168+
var (seconds, nanos) = Legacy.ParseDuration((string)value!);
157169

158-
return ConvertValue(value, descriptor.MessageType.FindFieldByName("value"));
170+
var duration = new Duration();
171+
duration.Seconds = seconds;
172+
duration.Nanos = nanos;
173+
return duration;
174+
}
175+
else if (descriptor.MessageType.FullName == Timestamp.Descriptor.FullName)
176+
{
177+
var (seconds, nanos) = Legacy.ParseTimestamp((string)value!);
178+
179+
var timestamp = new Timestamp();
180+
timestamp.Seconds = seconds;
181+
timestamp.Nanos = nanos;
182+
return timestamp;
183+
}
159184
}
160185
break;
161186
}
@@ -247,15 +272,15 @@ public static void SetValue(IMessage message, FieldDescriptor field, object? val
247272
}
248273
}
249274

250-
public static bool TryGetHttpRule(MethodDescriptor methodDescriptor, [NotNullWhen(true)]out HttpRule? httpRule)
275+
public static bool TryGetHttpRule(MethodDescriptor methodDescriptor, [NotNullWhen(true)] out HttpRule? httpRule)
251276
{
252277
var options = methodDescriptor.GetOptions();
253278
httpRule = options?.GetExtension(AnnotationsExtensions.Http);
254279

255280
return httpRule != null;
256281
}
257282

258-
public static bool TryResolvePattern(HttpRule http, [NotNullWhen(true)]out string? pattern, [NotNullWhen(true)]out string? verb)
283+
public static bool TryResolvePattern(HttpRule http, [NotNullWhen(true)] out string? pattern, [NotNullWhen(true)] out string? verb)
259284
{
260285
switch (http.PatternCase)
261286
{
@@ -424,14 +449,21 @@ static void RecursiveVisitMessages(Dictionary<string, FieldDescriptor> queryPara
424449
case FieldType.SInt32:
425450
case FieldType.SInt64:
426451
case FieldType.Enum:
427-
var joinedPath = string.Join(".", path.Select(d => d.JsonName));
428-
queryParameters[joinedPath] = fieldDescriptor;
452+
{
453+
var joinedPath = string.Join(".", path.Select(d => d.JsonName));
454+
queryParameters[joinedPath] = fieldDescriptor;
455+
}
429456
break;
430457
case FieldType.Group:
431458
case FieldType.Message:
432459
default:
433460
// Complex repeated fields aren't valid query parameters.
434-
if (!fieldDescriptor.IsRepeated)
461+
if (IsCustomType(fieldDescriptor.MessageType))
462+
{
463+
var joinedPath = string.Join(".", path.Select(d => d.JsonName));
464+
queryParameters[joinedPath] = fieldDescriptor;
465+
}
466+
else if (!fieldDescriptor.IsRepeated)
435467
{
436468
RecursiveVisitMessages(queryParameters, existingParameters, fieldDescriptor.MessageType, path);
437469
}
@@ -444,6 +476,23 @@ static void RecursiveVisitMessages(Dictionary<string, FieldDescriptor> queryPara
444476
}
445477
}
446478

479+
private static bool IsCustomType(MessageDescriptor messageDescriptor)
480+
{
481+
// The messages flags here should be kept in sync with GrpcDataContractResolver.TryCustomizeMessage.
482+
if (IsWrapperType(messageDescriptor) ||
483+
messageDescriptor.FullName == Timestamp.Descriptor.FullName ||
484+
messageDescriptor.FullName == Duration.Descriptor.FullName ||
485+
messageDescriptor.FullName == FieldMask.Descriptor.FullName ||
486+
messageDescriptor.FullName == Struct.Descriptor.FullName ||
487+
messageDescriptor.FullName == ListValue.Descriptor.FullName ||
488+
messageDescriptor.FullName == Value.Descriptor.FullName ||
489+
messageDescriptor.FullName == Any.Descriptor.FullName)
490+
{
491+
return true;
492+
}
493+
return false;
494+
}
495+
447496
public sealed record BodyDescriptorInfo(
448497
MessageDescriptor Descriptor,
449498
FieldDescriptor? FieldDescriptor,

src/Grpc/JsonTranscoding/test/Microsoft.AspNetCore.Grpc.JsonTranscoding.Tests/UnaryServerCallHandlerTests.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,70 @@ public async Task HandleCallAsync_Any_Success()
12301230
Assert.Equal("A value!", anyMessage.GetProperty("value").GetString());
12311231
}
12321232

1233+
[Fact]
1234+
public async Task HandleCallAsync_MatchingQueryStringValues_CustomDeserialization_SetOnRequestMessage()
1235+
{
1236+
// Arrange
1237+
HelloRequest? request = null;
1238+
UnaryServerMethod<JsonTranscodingGreeterService, HelloRequest, HelloReply> invoker = (s, r, c) =>
1239+
{
1240+
request = r;
1241+
return Task.FromResult(new HelloReply());
1242+
};
1243+
1244+
var timestamp = Timestamp.FromDateTimeOffset(new DateTimeOffset(2023, 2, 14, 17, 32, 0, TimeSpan.FromHours(8)));
1245+
var duration = Duration.FromTimeSpan(TimeSpan.FromHours(1));
1246+
var fieldmask = FieldMask.FromString("one,two,three.sub");
1247+
1248+
var unaryServerCallHandler = CreateCallHandler(invoker);
1249+
var httpContext = TestHelpers.CreateHttpContext();
1250+
httpContext.Request.Query = new QueryCollection(new Dictionary<string, StringValues>
1251+
{
1252+
["timestamp_value"] = Legacy.GetTimestampText(timestamp.Nanos, timestamp.Seconds),
1253+
["duration_value"] = Legacy.GetDurationText(duration.Nanos, duration.Seconds),
1254+
["field_mask_value"] = Legacy.GetFieldMaskText(fieldmask.Paths),
1255+
["float_value"] = "1.5"
1256+
});
1257+
1258+
// Act
1259+
await unaryServerCallHandler.HandleCallAsync(httpContext);
1260+
1261+
// Assert
1262+
Assert.NotNull(request);
1263+
Assert.Equal(timestamp, request!.TimestampValue);
1264+
Assert.Equal(duration, request!.DurationValue);
1265+
Assert.Equal(fieldmask, request!.FieldMaskValue);
1266+
Assert.Equal(1.5f, request!.FloatValue);
1267+
}
1268+
1269+
[Fact]
1270+
public async Task HandleCallAsync_MatchingQueryStringValues_KnownType_FieldSetter_SetOnRequestMessage()
1271+
{
1272+
// Arrange
1273+
HelloRequest? request = null;
1274+
UnaryServerMethod<JsonTranscodingGreeterService, HelloRequest, HelloReply> invoker = (s, r, c) =>
1275+
{
1276+
request = r;
1277+
return Task.FromResult(new HelloReply());
1278+
};
1279+
1280+
var fieldmask = FieldMask.FromString("one,two,three.sub");
1281+
1282+
var unaryServerCallHandler = CreateCallHandler(invoker);
1283+
var httpContext = TestHelpers.CreateHttpContext();
1284+
httpContext.Request.Query = new QueryCollection(new Dictionary<string, StringValues>
1285+
{
1286+
["field_mask_value.paths"] = new StringValues(fieldmask.Paths.ToArray()),
1287+
});
1288+
1289+
// Act
1290+
await unaryServerCallHandler.HandleCallAsync(httpContext);
1291+
1292+
// Assert
1293+
Assert.NotNull(request);
1294+
Assert.Equal(fieldmask, request!.FieldMaskValue);
1295+
}
1296+
12331297
private UnaryServerCallHandler<JsonTranscodingGreeterService, HelloRequest, HelloReply> CreateCallHandler(
12341298
UnaryServerMethod<JsonTranscodingGreeterService, HelloRequest, HelloReply> invoker,
12351299
CallHandlerDescriptorInfo? descriptorInfo = null,

0 commit comments

Comments
 (0)