Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using System;
using System.Collections.Generic;
using System.Net.Http;

using Azure.Core.Serialization;
using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
Expand All @@ -20,6 +20,8 @@ public BinaryPayloadContentBuilder(IReadOnlyList<IHubProtocol> hubProtocols)
_hubProtocols = hubProtocols;
}

public ObjectSerializer? ObjectSerializer => null;

public HttpContent? Build(HubMessage? payload, Type? typeHint)
{
return payload == null ? null : (HttpContent)new BinaryPayloadMessageContent(payload, _hubProtocols);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

using System;
using System.Net.Http;

using Azure.Core.Serialization;
using Microsoft.AspNetCore.SignalR.Protocol;

#nullable enable
Expand All @@ -13,4 +13,6 @@ namespace Microsoft.Azure.SignalR.Common;
internal interface IPayloadContentBuilder
{
HttpContent? Build(HubMessage? payload, Type? typeHint);

ObjectSerializer? ObjectSerializer { get; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ public JsonPayloadContentBuilder(ObjectSerializer jsonObjectSerializer)
{
return payload == null ? null : new JsonPayloadMessageContent(payload, _jsonObjectSerializer, typeHint);
}

public ObjectSerializer? ObjectSerializer => _jsonObjectSerializer;
}
8 changes: 6 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ internal class RestClient

private readonly IPayloadContentBuilder _payloadContentBuilder;

private static readonly ObjectSerializer DefaultObjectSerializer = new JsonObjectSerializer();

public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder)
{
_httpClientFactory = httpClientFactory;
Expand Down Expand Up @@ -81,10 +83,10 @@ public Task SendMessageWithRetryAsync(
HttpMethod httpMethod,
string methodName,
object?[] args,
Func<HttpResponseMessage, bool>? handleExpectedResponse = null,
Func<HttpResponseMessage, Task<bool>>? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, AsAsync(handleExpectedResponse), cancellationToken);
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new InvocationMessage(methodName, args), null, handleExpectedResponse, cancellationToken);
}

public Task SendStreamMessageWithRetryAsync(
Expand All @@ -99,6 +101,8 @@ public Task SendStreamMessageWithRetryAsync(
return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, new StreamItemMessage(streamId, arg), typeHint, AsAsync(handleExpectedResponse), cancellationToken);
}

public ObjectSerializer ObjectSerializer => _payloadContentBuilder.ObjectSerializer ?? DefaultObjectSerializer;

private static Uri GetUri(string url, IDictionary<string, StringValues>? query)
{
if (query == null || query.Count == 0)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Text.Json.Serialization;

namespace Microsoft.Azure.SignalR.Management.ClientInvocation;

sealed class InvocationResponse<T>
{
[JsonPropertyName("result")]
public T Result { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public IServiceHubLifetimeManager<THub> Create<THub>(string hubName) where THub
var httpClientFactory = _serviceProvider.GetRequiredService<IHttpClientFactory>();
var serviceEndpoint = _serviceProvider.GetRequiredService<IServiceEndpointManager>().Endpoints.First().Key;
var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder());
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient);
var protocolResolver = _serviceProvider.GetRequiredService<IHubProtocolResolver>();
return new RestHubLifetimeManager<THub>(hubName, serviceEndpoint, _options.ApplicationName, restClient, protocolResolver);
}
default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType));
}
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Management/RestApiProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public RestApiEndpoint GetListConnectionsInGroupEndpoint(string appName, string
return GenerateRestApiEndpoint(appName, hubName, $"/groups/{Uri.EscapeDataString(groupName)}/connections");
}

public RestApiEndpoint SendClientInvocation(string appName, string hubName, string connectionId)
{
return GenerateRestApiEndpoint(appName, hubName, $"/connections/{Uri.EscapeDataString(connectionId)}/:invoke");
}

private RestApiEndpoint GenerateRestApiEndpoint(string appName, string hubName, string pathAfterHub, IDictionary<string, StringValues> queries = null)
{
var requestPrefixWithHub = $"{_serverEndpoint}api/hubs/{Uri.EscapeDataString(hubName.ToLowerInvariant())}";
Expand Down
92 changes: 91 additions & 1 deletion src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
using Azure;

using Microsoft.AspNetCore.SignalR;
#if NET7_0_OR_GREATER
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Management.ClientInvocation;
#endif
using Microsoft.Extensions.Primitives;

using static Microsoft.Azure.SignalR.Constants;
Expand All @@ -31,13 +35,15 @@ internal class RestHubLifetimeManager<THub> : HubLifetimeManager<THub>, IService
private readonly RestApiProvider _restApiProvider;
private readonly string _hubName;
private readonly string _appName;
private readonly IHubProtocolResolver _protocolResolver;

public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient)
public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient, IHubProtocolResolver protocolResolver)
{
_restApiProvider = new RestApiProvider(endpoint);
_appName = appName;
_hubName = hubName;
_restClient = restClient;
_protocolResolver = protocolResolver;
}

public override async Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -353,6 +359,90 @@ public async Task SendStreamCompletionAsync(string connectionId, string streamId
await _restClient.SendWithRetryAsync(api, HttpMethod.Post, handleExpectedResponse: null, cancellationToken: cancellationToken);
}

#if NET7_0_OR_GREATER
public override async Task<T> InvokeConnectionAsync<T>(string connectionId, string methodName, object?[] args, CancellationToken cancellationToken = default)
{
// Validate input parameters
if (string.IsNullOrEmpty(methodName))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(methodName));
}
if (string.IsNullOrEmpty(connectionId))
{
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}
if (!_protocolResolver.AllProtocols.All(IsInvocationSupported))
{
throw new NotSupportedException("Non supported protocol for client invocation.");
}

// Get API endpoint and prepare for the request
var api = _restApiProvider.SendClientInvocation(_appName, _hubName, connectionId);
InvocationResponse<T>? wrapper = null;
string? errorContent = null;
bool isSuccess = false;
// Send request and capture the response
await _restClient.SendMessageWithRetryAsync(
api,
HttpMethod.Post,
methodName,
args,
async response =>
{
isSuccess = response.IsSuccessStatusCode;

if (isSuccess)
{
await using var contentStream = await response.Content.ReadAsStreamAsync(cancellationToken);

var deserialized = await _restClient.ObjectSerializer.DeserializeAsync(
contentStream,
typeof(InvocationResponse<T>),
cancellationToken);

wrapper = deserialized as InvocationResponse<T>
?? throw new HubException("Failed to deserialize response");
}
else
{
errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
}

return isSuccess || response.StatusCode == HttpStatusCode.BadRequest;
},
cancellationToken);

// Ensure we have a response
if (!isSuccess)
{
throw new HubException(errorContent ?? "Unknown error in response");
}

return wrapper!.Result;
}

public override Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
{
// This method won't get trigger because in transient we will wait for the returned completion message.
// this is to honor the interface
throw new NotImplementedException();
}

private static bool IsInvocationSupported(IHubProtocol protocol)
{
// Use protocol.Name to check for supported protocols
switch (protocol.Name)
{
case "json":
case "messagepack":
return true;
default:
return false;
}
}

#endif

private static bool FilterExpectedResponse(HttpResponseMessage response, string expectedErrorCode) =>
response.IsSuccessStatusCode
|| (response.StatusCode == HttpStatusCode.NotFound && response.Headers.TryGetValues(Headers.MicrosoftErrorCode, out var errorCodes) && errorCodes.First().Equals(expectedErrorCode, StringComparison.OrdinalIgnoreCase));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft. All rights reserved.
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

#nullable enable
Expand Down Expand Up @@ -88,8 +88,7 @@ public bool IsVersionSupported(int version)

public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{
//We don't need reading message with this protocol.
throw new NotSupportedException();
return new JsonHubProtocol().TryParseMessage(ref input, binder, out message!);
}

/// <inheritdoc />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
#nullable enable
using System;
#if NET7_0_OR_GREATER
using System.Linq;
#endif
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -269,6 +272,11 @@ public override async Task<T> InvokeConnectionAsync<T>(string connectionId, stri
throw new ArgumentNullException(nameof(methodName));
}

if (!ProtocolResolver.AllProtocols.All(IsInvocationSupported))
{
throw new NotSupportedException("Non supported protocol for client invocation.");
}

// cancellationToken is required to be cancellable.

using var cts = new CancellationTokenSource(DefaultInvocationTimeoutTimespan);
Expand Down Expand Up @@ -297,6 +305,19 @@ public override Task SetConnectionResultAsync(string connectionId, CompletionMes
// this is to honor the interface
throw new NotImplementedException();
}

private static bool IsInvocationSupported(IHubProtocol protocol)
{
// Use protocol.Name to check for supported protocols
switch (protocol.Name)
{
case "json":
case "messagepack":
return true;
default:
return false;
}
}
#endif

protected override T AppendMessageTracingId<T>(T message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ internal abstract class ServiceLifetimeManagerBase<THub> : HubLifetimeManager<TH
protected ILogger Logger { get; set; }

private readonly DefaultHubMessageSerializer _messageSerializer;
private readonly IHubProtocolResolver _protocolResolver;

public ServiceLifetimeManagerBase(IServiceConnectionManager<THub> serviceConnectionManager, IHubProtocolResolver protocolResolver, IOptions<HubOptions> globalHubOptions, IOptions<HubOptions<THub>> hubOptions, ILogger logger)
{
Logger = logger ?? throw new ArgumentNullException(nameof(logger));
ServiceConnectionContainer = serviceConnectionManager;
_messageSerializer = new DefaultHubMessageSerializer(protocolResolver, globalHubOptions.Value.SupportedProtocols, hubOptions.Value.SupportedProtocols);
_protocolResolver = protocolResolver;
}

public override Task OnConnectedAsync(HubConnectionContext connection)
Expand Down Expand Up @@ -326,6 +328,8 @@ protected virtual T AppendMessageTracingId<T>(T message) where T : ServiceMessag
return message.WithTracingId();
}

protected IHubProtocolResolver ProtocolResolver => _protocolResolver;

private async Task WriteCoreAsync<T>(T message, Func<T, Task> task) where T : ServiceMessage, IMessageWithTracingId
{
try
Expand Down
Loading
Loading