Skip to content

Commit 0e04570

Browse files
committed
add CallContext.ResponseHeadersAsync() - more reliable way of getting headers in the middle of requests
1 parent 557d06c commit 0e04570

File tree

3 files changed

+63
-12
lines changed

3 files changed

+63
-12
lines changed

src/protobuf-net.Grpc/CallContext.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.ComponentModel;
55
using System.Runtime.CompilerServices;
66
using System.Threading;
7+
using System.Threading.Tasks;
78

89
namespace ProtoBuf.Grpc
910
{
@@ -161,6 +162,13 @@ public CallContext(in CallOptions callOptions = default, CallContextFlags flags
161162
[MethodImpl(MethodImplOptions.AggressiveInlining)]
162163
public Metadata ResponseHeaders() => MetadataContext?.Headers ?? ThrowNoContext<Metadata>();
163164

165+
/// <summary>
166+
/// Get the response-headers from a client operation when they are available
167+
/// </summary>
168+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
169+
public Task<Metadata> ResponseHeadersAsync() => MetadataContext?.GetHeadersTask(true)
170+
?? Task.FromResult(ThrowNoContext<Metadata>()); // note this actually throws immediately; this is intentional
171+
164172
/// <summary>
165173
/// Get the response-trailers from a client operation
166174
/// </summary>

src/protobuf-net.Grpc/Internal/MetadataContext.cs

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Runtime.CompilerServices;
3+
using System.Threading;
34
using System.Threading.Tasks;
45
using Grpc.Core;
56

@@ -10,11 +11,42 @@ internal sealed class MetadataContext
1011
internal MetadataContext(object? state) => State = state;
1112

1213
internal object? State { get; }
13-
private Metadata? _headers, _trailers;
14+
private Metadata? _trailers;
15+
private object? _headersTaskOrSource;
16+
1417
internal Metadata Headers
1518
{
16-
get => _headers ?? Throw("Headers are not yet available");
19+
get
20+
{
21+
var pending = GetHeadersTask(false);
22+
return pending is object && pending.RanToCompletion()
23+
? pending.Result
24+
: Throw("Headers are not yet available");
25+
}
26+
}
27+
28+
internal Task<Metadata>? GetHeadersTask(bool createIfMissing)
29+
{
30+
return _headersTaskOrSource switch
31+
{
32+
Task<Metadata> task => task,
33+
TaskCompletionSource<Metadata> tcs => tcs.Task,
34+
_ => createIfMissing ? InterlockedCreateSource() : null,
35+
};
36+
37+
Task<Metadata> InterlockedCreateSource()
38+
{
39+
var newTcs = new TaskCompletionSource<Metadata>();
40+
var existing = Interlocked.CompareExchange(ref _headersTaskOrSource, newTcs, null);
41+
return existing switch
42+
{
43+
Task<Metadata> task => task,
44+
TaskCompletionSource<Metadata> tcs => tcs.Task,
45+
_ => newTcs.Task,
46+
};
47+
}
1748
}
49+
1850
internal Metadata Trailers
1951
{
2052
get => _trailers ?? Throw("Trailers are not yet available");
@@ -27,7 +59,8 @@ internal Metadata Trailers
2759
internal MetadataContext Reset()
2860
{
2961
Status = Status.DefaultSuccess;
30-
_headers = _trailers = null;
62+
_trailers = null;
63+
_headersTaskOrSource = null;
3164
return this;
3265
}
3366

@@ -57,24 +90,34 @@ internal void SetTrailers<T>(T call, Func<T, Status> getStatus, Func<T, Metadata
5790

5891
internal ValueTask SetHeadersAsync(Task<Metadata> headers)
5992
{
93+
var tcs = Interlocked.CompareExchange(ref _headersTaskOrSource, headers, null) as TaskCompletionSource<Metadata>;
6094
if (headers.RanToCompletion())
6195
{
62-
_headers = headers.Result;
96+
// headers are sync; update TCS if one
97+
tcs?.TrySetResult(headers.Result);
6398
return default;
6499
}
65100
else
66101
{
67-
return Awaited(this, headers);
102+
// headers are async (or faulted); pay the piper
103+
return Awaited(this, tcs, headers);
68104
}
69-
static async ValueTask Awaited(MetadataContext context, Task<Metadata> headers)
105+
106+
static async ValueTask Awaited(MetadataContext context, TaskCompletionSource<Metadata>? tcs, Task<Metadata> headers)
70107
{
71108
try
72109
{
73-
context._headers = await headers.ConfigureAwait(false);
110+
tcs?.TrySetResult(await headers.ConfigureAwait(false));
74111
}
75112
catch (RpcException fault)
76113
{
77114
context.SetTrailers(fault);
115+
tcs?.TrySetException(fault);
116+
throw;
117+
}
118+
catch (Exception ex)
119+
{
120+
tcs?.TrySetException(ex);
78121
throw;
79122
}
80123
}

tests/protobuf-net.Grpc.Test.Integration/StreamTests.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ public async Task DuplexEchoFault(Scenario scenario, int expectedCount, string m
517517
{
518518
await foreach (var item in client.DuplexEcho(For(scenario, DEFAULT_SIZE), ctx))
519519
{
520-
CheckHeaderState();
520+
await CheckHeaderStateAsync();
521521
values.Add(item.Bar);
522522
}
523523
});
@@ -526,7 +526,7 @@ public async Task DuplexEchoFault(Scenario scenario, int expectedCount, string m
526526
Assert.Equal(marker + " faultval", rpc.Trailers.GetString("faultkey"));
527527

528528
_fixture?.Log("after await foreach");
529-
CheckHeaderState();
529+
await CheckHeaderStateAsync();
530530
Assert.Equal(string.Join(",", Enumerable.Range(0, expectedCount)), string.Join(",", values));
531531

532532
if ((flags & CallContextFlags.CaptureMetadata) != 0)
@@ -538,7 +538,7 @@ public async Task DuplexEchoFault(Scenario scenario, int expectedCount, string m
538538
Assert.Equal(marker + " detail", status.Detail);
539539
}
540540

541-
void CheckHeaderState()
541+
async ValueTask CheckHeaderStateAsync()
542542
{
543543
if (haveCheckedHeaders) return;
544544
haveCheckedHeaders = true;
@@ -548,10 +548,10 @@ void CheckHeaderState()
548548
switch (scenario)
549549
{
550550
case Scenario.FaultBeforeHeaders:
551-
Assert.Null(ctx.ResponseHeaders().GetString("prekey"));
551+
Assert.Null((await ctx.ResponseHeadersAsync()).GetString("prekey"));
552552
break;
553553
default:
554-
Assert.Equal("preval", ctx.ResponseHeaders().GetString("prekey"));
554+
Assert.Equal("preval", (await ctx.ResponseHeadersAsync()).GetString("prekey"));
555555
break;
556556
}
557557
}

0 commit comments

Comments
 (0)