diff --git a/build/dependencies.props b/build/dependencies.props index acd008f..9da4e2a 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -5,7 +5,7 @@ 0.3.0 - 1.11.0-preview1-10835 + 1.11.0-preview1-10838 1.0.0 15.8.0 4.9.0 diff --git a/samples/bidirectional-chat/csharp/Function.cs b/samples/bidirectional-chat/csharp/Function.cs index 59242ea..151f323 100644 --- a/samples/bidirectional-chat/csharp/Function.cs +++ b/samples/bidirectional-chat/csharp/Function.cs @@ -14,26 +14,33 @@ namespace FunctionApp { - public class SimpleChat : ServerlessHub + public class SimpleChat : ServerlessHub { private const string NewMessageTarget = "newMessage"; private const string NewConnectionTarget = "newConnection"; + public interface IChatClient + { + public Task newConnection(NewConnection newConnection); + public Task newMessage(NewMessage newMessage); + } + [FunctionName("negotiate")] - public Task NegotiateAsync([HttpTrigger(AuthorizationLevel.Anonymous)] HttpRequest req) + public async Task NegotiateAsync([HttpTrigger(AuthorizationLevel.Anonymous)] HttpRequest req) { var claims = GetClaims(req.Headers["Authorization"]); - return NegotiateAsync(new NegotiationOptions + var result = await NegotiateAsync(new NegotiationOptions { UserId = claims.First(c => c.Type == ClaimTypes.NameIdentifier).Value, Claims = claims }); + return result; } [FunctionName(nameof(OnConnected))] public async Task OnConnected([SignalRTrigger]InvocationContext invocationContext, ILogger logger) { - await Clients.All.SendAsync(NewConnectionTarget, new NewConnection(invocationContext.ConnectionId)); + await Clients.All.newConnection(new NewConnection(invocationContext.ConnectionId)); logger.LogInformation($"{invocationContext.ConnectionId} has connected"); } @@ -41,26 +48,26 @@ public async Task OnConnected([SignalRTrigger]InvocationContext invocationContex [FunctionName(nameof(Broadcast))] public async Task Broadcast([SignalRTrigger]InvocationContext invocationContext, string message, ILogger logger) { - await Clients.All.SendAsync(NewMessageTarget, new NewMessage(invocationContext, message)); + await Clients.All.newMessage(new NewMessage(invocationContext, message)); logger.LogInformation($"{invocationContext.ConnectionId} broadcast {message}"); } [FunctionName(nameof(SendToGroup))] public async Task SendToGroup([SignalRTrigger]InvocationContext invocationContext, string groupName, string message) { - await Clients.Group(groupName).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message)); + await Clients.Group(groupName).newMessage(new NewMessage(invocationContext, message)); } [FunctionName(nameof(SendToUser))] public async Task SendToUser([SignalRTrigger]InvocationContext invocationContext, string userName, string message) { - await Clients.User(userName).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message)); + await Clients.User(userName).newMessage(new NewMessage(invocationContext, message)); } [FunctionName(nameof(SendToConnection))] public async Task SendToConnection([SignalRTrigger]InvocationContext invocationContext, string connectionId, string message) { - await Clients.Client(connectionId).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message)); + await Clients.Client(connectionId).newMessage(new NewMessage(invocationContext, message)); } [FunctionName(nameof(JoinGroup))] @@ -92,7 +99,7 @@ public void OnDisconnected([SignalRTrigger]InvocationContext invocationContext) { } - private class NewConnection + public class NewConnection { public string ConnectionId { get; } @@ -102,7 +109,7 @@ public NewConnection(string connectionId) } } - private class NewMessage + public class NewMessage { public string ConnectionId { get; } public string Sender { get; } diff --git a/src/SignalRServiceExtension/Config/IInternalServiceHubContextStore.cs b/src/SignalRServiceExtension/Config/IInternalServiceHubContextStore.cs index 70d441c..a1ba855 100644 --- a/src/SignalRServiceExtension/Config/IInternalServiceHubContextStore.cs +++ b/src/SignalRServiceExtension/Config/IInternalServiceHubContextStore.cs @@ -1,12 +1,17 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; +using System.Threading.Tasks; using Microsoft.Azure.SignalR; +using Microsoft.Azure.SignalR.Management; namespace Microsoft.Azure.WebJobs.Extensions.SignalRService { internal interface IInternalServiceHubContextStore : IServiceHubContextStore { AccessKey[] AccessKeys { get; } + + public dynamic GetAsync(Type THubType, Type TType); } } \ No newline at end of file diff --git a/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs b/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs index 54ca925..bc740a2 100644 --- a/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs +++ b/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs @@ -7,13 +7,16 @@ using System.Threading.Tasks; using Microsoft.Azure.SignalR; using Microsoft.Azure.SignalR.Management; +using Microsoft.Extensions.DependencyInjection; namespace Microsoft.Azure.WebJobs.Extensions.SignalRService { internal class ServiceHubContextStore : IInternalServiceHubContextStore { - private readonly ConcurrentDictionary> lazy, IServiceHubContext value)> store = new ConcurrentDictionary>, IServiceHubContext value)>(StringComparer.OrdinalIgnoreCase); + private readonly ConcurrentDictionary> lazy, IServiceHubContext value)> _weakTypedHubStore = new(StringComparer.OrdinalIgnoreCase); private readonly IServiceEndpointManager endpointManager; + + private readonly IServiceProvider _strongTypedHubServiceProvider; public IServiceManager ServiceManager { get; } @@ -23,11 +26,15 @@ public ServiceHubContextStore(IServiceEndpointManager endpointManager, IServiceM { this.endpointManager = endpointManager; ServiceManager = serviceManager; + _strongTypedHubServiceProvider = new ServiceCollection() + .AddSingleton(serviceManager as ServiceManager) + .AddSingleton(typeof(ServerlessHubContext<,>)) + .BuildServiceProvider(); } public ValueTask GetAsync(string hubName) { - var pair = store.GetOrAdd(hubName, + var pair = _weakTypedHubStore.GetOrAdd(hubName, (new Lazy>( () => ServiceManager.CreateHubContextAsync(hubName)), default)); return GetAsyncCore(hubName, pair); @@ -50,14 +57,37 @@ private async Task GetFromLazyAsync(string hubName, (Lazy> GetAsync() where THub : ServerlessHub where T : class + { + return _strongTypedHubServiceProvider.GetRequiredService>().HubContextTask; + } + + /// + /// The method actually does the following thing + /// + /// private Task> GetAsync() where THub : ServerlessHub where T : class + ///{ + /// return _serviceProvider.GetRequiredService>().HubContext; + ///} + /// + /// + public dynamic GetAsync(Type THubType, Type TType) + { + var genericType = typeof(ServerlessHubContext<,>); + Type[] typeArgs = { THubType, TType }; + var serverlessHubContextType = genericType.MakeGenericType(typeArgs); + dynamic serverlessHubContext = _strongTypedHubServiceProvider.GetRequiredService(serverlessHubContextType); + return serverlessHubContext.HubContextTask.GetAwaiter().GetResult(); + } } } \ No newline at end of file diff --git a/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj b/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj index 5af9d13..966a4bf 100644 --- a/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj +++ b/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj @@ -7,7 +7,10 @@ - + + + + diff --git a/src/SignalRServiceExtension/TriggerBindings/ServerlessHubContext`T.cs b/src/SignalRServiceExtension/TriggerBindings/ServerlessHubContext`T.cs new file mode 100644 index 0000000..80d6b1c --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/ServerlessHubContext`T.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Threading.Tasks; +using Microsoft.Azure.SignalR.Management; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + //A helper class so that nameof(THub) together with T can be used as a key to retrieve a ServiceHubContext from a ServiceProvider. + internal class ServerlessHubContext where THub : ServerlessHub where T : class + { + public Task> HubContextTask { get; } + + public ServerlessHubContext(ServiceManager serviceManager) + { + HubContextTask = serviceManager.CreateHubContextAsync(typeof(THub).Name, default); + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/ServerlessHub`T.cs b/src/SignalRServiceExtension/TriggerBindings/ServerlessHub`T.cs new file mode 100644 index 0000000..a3a5063 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/ServerlessHub`T.cs @@ -0,0 +1,76 @@ +using System; +using System.Collections.Generic; +using System.IdentityModel.Tokens.Jwt; +using System.Linq; +using System.Security.Claims; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Azure.SignalR.Management; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + public abstract class ServerlessHub where T : class + { + private static readonly Lazy JwtSecurityTokenHandler = new Lazy(() => new JwtSecurityTokenHandler()); + protected ServiceHubContext HubContext { get; } + + public ServerlessHub(ServiceHubContext hubContext) + { + HubContext = hubContext; + } + + public ServerlessHub() + { + HubContext = (StaticServiceHubContextStore.Get() as IInternalServiceHubContextStore).GetAsync(GetType(), typeof(T)); + Clients = HubContext.Clients; + Groups = HubContext.Groups; + UserGroups = HubContext.UserGroups; + ClientManager = HubContext?.ClientManager; + } + + /// + /// Gets client endpoint access information object for SignalR hub connections to connect to Azure SignalR Service + /// + protected async ValueTask NegotiateAsync(NegotiationOptions options) + { + var negotiateResponse = await HubContext.NegotiateAsync(options); + return new SignalRConnectionInfo + { + Url = negotiateResponse.Url, + AccessToken = negotiateResponse.AccessToken + }; + } + + /// + /// Get claim list from a JWT. + /// + protected IList GetClaims(string jwt) + { + if (jwt.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) + { + jwt = jwt.Substring("Bearer ".Length).Trim(); + } + return JwtSecurityTokenHandler.Value.ReadJwtToken(jwt).Claims.ToList(); + } + + /// + /// Gets an object that can be used to invoke methods on the clients connected to this hub. + /// + public IHubClients Clients { get; } + + /// + /// Get the group manager of this hub. + /// + public IGroupManager Groups { get; } + + /// + /// Get the user group manager of this hub. + /// + public IUserGroupManager UserGroups { get; } + + /// + /// Get the client manager of this hub. + /// + public ClientManager ClientManager { get; } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs index 61aeaa5..234ac88 100644 --- a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs +++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs @@ -75,7 +75,7 @@ internal SignalRTriggerAttribute GetParameterResolvedAttribute(SignalRTriggerAtt var declaredType = method.DeclaringType; string[] parameterNamesFromAttribute; - if (declaredType != null && declaredType.IsSubclassOf(typeof(ServerlessHub))) + if (IsServerlessHub(declaredType)) { // Class based model if (!string.IsNullOrEmpty(hubName) || @@ -116,6 +116,28 @@ internal SignalRTriggerAttribute GetParameterResolvedAttribute(SignalRTriggerAtt return new SignalRTriggerAttribute(hubName, category, @event, parameterNames) { ConnectionStringSetting = connectionStringSetting }; } + private bool IsServerlessHub(Type type) + { + if (type == null) + { + return false; + } + if(type.IsSubclassOf(typeof(ServerlessHub))) + { + return true; + } + var baseType = type.BaseType; + while (baseType != null) + { + if(baseType.IsGenericType && baseType.GetGenericTypeDefinition() == typeof(ServerlessHub<>)) + { + return true; + } + baseType = baseType.BaseType; + } + return false; + } + private void ValidateSignalRTriggerAttributeBinding(SignalRTriggerAttribute attribute) { if (string.IsNullOrWhiteSpace(attribute.ConnectionStringSetting))