diff --git a/build/dependencies.props b/build/dependencies.props
index acd008f..9da4e2a 100644
--- a/build/dependencies.props
+++ b/build/dependencies.props
@@ -5,7 +5,7 @@
0.3.0
- 1.11.0-preview1-10835
+ 1.11.0-preview1-10838
1.0.0
15.8.0
4.9.0
diff --git a/samples/bidirectional-chat/csharp/Function.cs b/samples/bidirectional-chat/csharp/Function.cs
index 59242ea..151f323 100644
--- a/samples/bidirectional-chat/csharp/Function.cs
+++ b/samples/bidirectional-chat/csharp/Function.cs
@@ -14,26 +14,33 @@
namespace FunctionApp
{
- public class SimpleChat : ServerlessHub
+ public class SimpleChat : ServerlessHub
{
private const string NewMessageTarget = "newMessage";
private const string NewConnectionTarget = "newConnection";
+ public interface IChatClient
+ {
+ public Task newConnection(NewConnection newConnection);
+ public Task newMessage(NewMessage newMessage);
+ }
+
[FunctionName("negotiate")]
- public Task NegotiateAsync([HttpTrigger(AuthorizationLevel.Anonymous)] HttpRequest req)
+ public async Task NegotiateAsync([HttpTrigger(AuthorizationLevel.Anonymous)] HttpRequest req)
{
var claims = GetClaims(req.Headers["Authorization"]);
- return NegotiateAsync(new NegotiationOptions
+ var result = await NegotiateAsync(new NegotiationOptions
{
UserId = claims.First(c => c.Type == ClaimTypes.NameIdentifier).Value,
Claims = claims
});
+ return result;
}
[FunctionName(nameof(OnConnected))]
public async Task OnConnected([SignalRTrigger]InvocationContext invocationContext, ILogger logger)
{
- await Clients.All.SendAsync(NewConnectionTarget, new NewConnection(invocationContext.ConnectionId));
+ await Clients.All.newConnection(new NewConnection(invocationContext.ConnectionId));
logger.LogInformation($"{invocationContext.ConnectionId} has connected");
}
@@ -41,26 +48,26 @@ public async Task OnConnected([SignalRTrigger]InvocationContext invocationContex
[FunctionName(nameof(Broadcast))]
public async Task Broadcast([SignalRTrigger]InvocationContext invocationContext, string message, ILogger logger)
{
- await Clients.All.SendAsync(NewMessageTarget, new NewMessage(invocationContext, message));
+ await Clients.All.newMessage(new NewMessage(invocationContext, message));
logger.LogInformation($"{invocationContext.ConnectionId} broadcast {message}");
}
[FunctionName(nameof(SendToGroup))]
public async Task SendToGroup([SignalRTrigger]InvocationContext invocationContext, string groupName, string message)
{
- await Clients.Group(groupName).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message));
+ await Clients.Group(groupName).newMessage(new NewMessage(invocationContext, message));
}
[FunctionName(nameof(SendToUser))]
public async Task SendToUser([SignalRTrigger]InvocationContext invocationContext, string userName, string message)
{
- await Clients.User(userName).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message));
+ await Clients.User(userName).newMessage(new NewMessage(invocationContext, message));
}
[FunctionName(nameof(SendToConnection))]
public async Task SendToConnection([SignalRTrigger]InvocationContext invocationContext, string connectionId, string message)
{
- await Clients.Client(connectionId).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message));
+ await Clients.Client(connectionId).newMessage(new NewMessage(invocationContext, message));
}
[FunctionName(nameof(JoinGroup))]
@@ -92,7 +99,7 @@ public void OnDisconnected([SignalRTrigger]InvocationContext invocationContext)
{
}
- private class NewConnection
+ public class NewConnection
{
public string ConnectionId { get; }
@@ -102,7 +109,7 @@ public NewConnection(string connectionId)
}
}
- private class NewMessage
+ public class NewMessage
{
public string ConnectionId { get; }
public string Sender { get; }
diff --git a/src/SignalRServiceExtension/Config/IInternalServiceHubContextStore.cs b/src/SignalRServiceExtension/Config/IInternalServiceHubContextStore.cs
index 70d441c..a1ba855 100644
--- a/src/SignalRServiceExtension/Config/IInternalServiceHubContextStore.cs
+++ b/src/SignalRServiceExtension/Config/IInternalServiceHubContextStore.cs
@@ -1,12 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+using System;
+using System.Threading.Tasks;
using Microsoft.Azure.SignalR;
+using Microsoft.Azure.SignalR.Management;
namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
{
internal interface IInternalServiceHubContextStore : IServiceHubContextStore
{
AccessKey[] AccessKeys { get; }
+
+ public dynamic GetAsync(Type THubType, Type TType);
}
}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs b/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs
index 54ca925..bc740a2 100644
--- a/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs
+++ b/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs
@@ -7,13 +7,16 @@
using System.Threading.Tasks;
using Microsoft.Azure.SignalR;
using Microsoft.Azure.SignalR.Management;
+using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
{
internal class ServiceHubContextStore : IInternalServiceHubContextStore
{
- private readonly ConcurrentDictionary> lazy, IServiceHubContext value)> store = new ConcurrentDictionary>, IServiceHubContext value)>(StringComparer.OrdinalIgnoreCase);
+ private readonly ConcurrentDictionary> lazy, IServiceHubContext value)> _weakTypedHubStore = new(StringComparer.OrdinalIgnoreCase);
private readonly IServiceEndpointManager endpointManager;
+
+ private readonly IServiceProvider _strongTypedHubServiceProvider;
public IServiceManager ServiceManager { get; }
@@ -23,11 +26,15 @@ public ServiceHubContextStore(IServiceEndpointManager endpointManager, IServiceM
{
this.endpointManager = endpointManager;
ServiceManager = serviceManager;
+ _strongTypedHubServiceProvider = new ServiceCollection()
+ .AddSingleton(serviceManager as ServiceManager)
+ .AddSingleton(typeof(ServerlessHubContext<,>))
+ .BuildServiceProvider();
}
public ValueTask GetAsync(string hubName)
{
- var pair = store.GetOrAdd(hubName,
+ var pair = _weakTypedHubStore.GetOrAdd(hubName,
(new Lazy>(
() => ServiceManager.CreateHubContextAsync(hubName)), default));
return GetAsyncCore(hubName, pair);
@@ -50,14 +57,37 @@ private async Task GetFromLazyAsync(string hubName, (Lazy> GetAsync() where THub : ServerlessHub where T : class
+ {
+ return _strongTypedHubServiceProvider.GetRequiredService>().HubContextTask;
+ }
+
+ ///
+ /// The method actually does the following thing
+ ///
+ /// private Task> GetAsync() where THub : ServerlessHub where T : class
+ ///{
+ /// return _serviceProvider.GetRequiredService>().HubContext;
+ ///}
+ ///
+ ///
+ public dynamic GetAsync(Type THubType, Type TType)
+ {
+ var genericType = typeof(ServerlessHubContext<,>);
+ Type[] typeArgs = { THubType, TType };
+ var serverlessHubContextType = genericType.MakeGenericType(typeArgs);
+ dynamic serverlessHubContext = _strongTypedHubServiceProvider.GetRequiredService(serverlessHubContextType);
+ return serverlessHubContext.HubContextTask.GetAwaiter().GetResult();
+ }
}
}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj b/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj
index 5af9d13..966a4bf 100644
--- a/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj
+++ b/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj
@@ -7,7 +7,10 @@
-
+
+
+
+
diff --git a/src/SignalRServiceExtension/TriggerBindings/ServerlessHubContext`T.cs b/src/SignalRServiceExtension/TriggerBindings/ServerlessHubContext`T.cs
new file mode 100644
index 0000000..80d6b1c
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/ServerlessHubContext`T.cs
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Threading.Tasks;
+using Microsoft.Azure.SignalR.Management;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ //A helper class so that nameof(THub) together with T can be used as a key to retrieve a ServiceHubContext from a ServiceProvider.
+ internal class ServerlessHubContext where THub : ServerlessHub where T : class
+ {
+ public Task> HubContextTask { get; }
+
+ public ServerlessHubContext(ServiceManager serviceManager)
+ {
+ HubContextTask = serviceManager.CreateHubContextAsync(typeof(THub).Name, default);
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/ServerlessHub`T.cs b/src/SignalRServiceExtension/TriggerBindings/ServerlessHub`T.cs
new file mode 100644
index 0000000..a3a5063
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/ServerlessHub`T.cs
@@ -0,0 +1,76 @@
+using System;
+using System.Collections.Generic;
+using System.IdentityModel.Tokens.Jwt;
+using System.Linq;
+using System.Security.Claims;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR;
+using Microsoft.Azure.SignalR.Management;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ public abstract class ServerlessHub where T : class
+ {
+ private static readonly Lazy JwtSecurityTokenHandler = new Lazy(() => new JwtSecurityTokenHandler());
+ protected ServiceHubContext HubContext { get; }
+
+ public ServerlessHub(ServiceHubContext hubContext)
+ {
+ HubContext = hubContext;
+ }
+
+ public ServerlessHub()
+ {
+ HubContext = (StaticServiceHubContextStore.Get() as IInternalServiceHubContextStore).GetAsync(GetType(), typeof(T));
+ Clients = HubContext.Clients;
+ Groups = HubContext.Groups;
+ UserGroups = HubContext.UserGroups;
+ ClientManager = HubContext?.ClientManager;
+ }
+
+ ///
+ /// Gets client endpoint access information object for SignalR hub connections to connect to Azure SignalR Service
+ ///
+ protected async ValueTask NegotiateAsync(NegotiationOptions options)
+ {
+ var negotiateResponse = await HubContext.NegotiateAsync(options);
+ return new SignalRConnectionInfo
+ {
+ Url = negotiateResponse.Url,
+ AccessToken = negotiateResponse.AccessToken
+ };
+ }
+
+ ///
+ /// Get claim list from a JWT.
+ ///
+ protected IList GetClaims(string jwt)
+ {
+ if (jwt.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
+ {
+ jwt = jwt.Substring("Bearer ".Length).Trim();
+ }
+ return JwtSecurityTokenHandler.Value.ReadJwtToken(jwt).Claims.ToList();
+ }
+
+ ///
+ /// Gets an object that can be used to invoke methods on the clients connected to this hub.
+ ///
+ public IHubClients Clients { get; }
+
+ ///
+ /// Get the group manager of this hub.
+ ///
+ public IGroupManager Groups { get; }
+
+ ///
+ /// Get the user group manager of this hub.
+ ///
+ public IUserGroupManager UserGroups { get; }
+
+ ///
+ /// Get the client manager of this hub.
+ ///
+ public ClientManager ClientManager { get; }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs
index 61aeaa5..234ac88 100644
--- a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs
+++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs
@@ -75,7 +75,7 @@ internal SignalRTriggerAttribute GetParameterResolvedAttribute(SignalRTriggerAtt
var declaredType = method.DeclaringType;
string[] parameterNamesFromAttribute;
- if (declaredType != null && declaredType.IsSubclassOf(typeof(ServerlessHub)))
+ if (IsServerlessHub(declaredType))
{
// Class based model
if (!string.IsNullOrEmpty(hubName) ||
@@ -116,6 +116,28 @@ internal SignalRTriggerAttribute GetParameterResolvedAttribute(SignalRTriggerAtt
return new SignalRTriggerAttribute(hubName, category, @event, parameterNames) { ConnectionStringSetting = connectionStringSetting };
}
+ private bool IsServerlessHub(Type type)
+ {
+ if (type == null)
+ {
+ return false;
+ }
+ if(type.IsSubclassOf(typeof(ServerlessHub)))
+ {
+ return true;
+ }
+ var baseType = type.BaseType;
+ while (baseType != null)
+ {
+ if(baseType.IsGenericType && baseType.GetGenericTypeDefinition() == typeof(ServerlessHub<>))
+ {
+ return true;
+ }
+ baseType = baseType.BaseType;
+ }
+ return false;
+ }
+
private void ValidateSignalRTriggerAttributeBinding(SignalRTriggerAttribute attribute)
{
if (string.IsNullOrWhiteSpace(attribute.ConnectionStringSetting))