@@ -8,7 +8,6 @@ open Microsoft.Extensions.DependencyInjection
88open Microsoft.Extensions .Logging
99open System
1010open System.Collections .Concurrent
11- open System.Collections .Generic
1211open System.Net .WebSockets
1312open System.Threading
1413open System.Threading .Tasks
@@ -32,10 +31,10 @@ module Channels =
3231
3332 ///Type representing information about client that has executed some channel action
3433 ///It's passed as an argument in channel actions (`join`, `handle`, `terminate`)
35- type ClientInfo = { SocketId: SocketId }
34+ type ClientInfo = { SocketId: SocketId ; ChannelPath : ChannelPath }
3635 with
37- static member New socketId =
38- { SocketId = socketId }
36+ static member New channelPath socketId =
37+ { SocketId = socketId; ChannelPath = channelPath }
3938
4039 ///Type representing result of `join` action. It can be either succesful (`Ok`) or you can reject client connection (`Rejected`)
4140 type JoinResult =
@@ -53,11 +52,12 @@ module Channels =
5352 /// You can get instance of it with `ctx.GetService<Saturn.Channels.ISocketHub>()` from any place that has access to HttpContext instance (`controller` actions, `channel` actions, normal `HttpHandler`)
5453 type ISocketHub =
5554 abstract member SendMessageToClients: ChannelPath -> Topic -> 'a -> Task < unit >
56- abstract member SendMessageToClient: ChannelPath -> SocketId -> Topic -> 'a -> Task < unit >
55+ abstract member SendMessageToClient: SocketId -> Topic -> 'a -> Task < unit >
56+ abstract member SendMessageToClientsFilter: ( ClientInfo -> bool ) -> Topic -> 'a -> Task < unit >
5757
5858 /// A type that wraps access to connected websockets by endpoint
5959 type SocketHub ( serializer : IJsonSerializer ) =
60- let sockets = Dictionary < ChannelPath , ConcurrentDictionary< SocketId , Socket.ThreadSafeWebSocket> >()
60+ let sockets = ConcurrentDictionary< ClientInfo , Socket.ThreadSafeWebSocket>()
6161
6262 let sendMessage ( msg : 'a Message ) ( socket : Socket.ThreadSafeWebSocket ) = task {
6363 let text = serializer.SerializeToString msg
@@ -67,37 +67,48 @@ module Channels =
6767 | Error exn -> return exn.Throw()
6868 }
6969
70- member __.NewPath path =
71- match sockets.TryGetValue path with
72- | true , _ path -> ()
73- | false , _ -> sockets .[ path ] <- ConcurrentDictionary ()
70+ member __.ConnectSocketToPath path clientId socket =
71+ let ci = { SocketId = clientId ; ChannelPath = path }
72+ sockets.AddOrUpdate ( ci , socket , fun _ _ -> socket ) |> ignore
73+ ci
7474
75- member __.ConnectSocketToPath path id socket =
76- sockets.[ path]. AddOrUpdate( id, socket, fun _ _ -> socket) |> ignore
77- id
78-
79- member __.DisconnectSocketForPath path socketId =
80- sockets.[ path]. TryRemove socketId |> ignore
75+ member __.DisconnectSocketForPath path clientId =
76+ let ci = { SocketId = clientId; ChannelPath = path}
77+ sockets.TryRemove ci |> ignore
8178
8279 interface ISocketHub with
80+ member __.SendMessageToClientsFilter ( predicate : ClientInfo -> bool ) ( topic : Topic ) ( payload : 'a ): Task < unit > = task {
81+ let msg = { Topic = topic; Ref = " " ; Payload = payload }
82+ let tasks =
83+ sockets
84+ |> Seq.filter ( fun n -> predicate n.Key)
85+ |> Seq.map ( fun n -> sendMessage msg n.Value)
86+
87+ let! _results = Task.WhenAll tasks
88+ return ()
89+ }
90+
8391 member __.SendMessageToClients path topic payload = task {
8492 let msg = { Topic = topic; Ref = " " ; Payload = payload }
85- let tasks = [ for kvp in sockets.[ path] -> sendMessage msg kvp.Value ]
93+ let tasks =
94+ sockets
95+ |> Seq.filter ( fun n -> n.Key.ChannelPath = path)
96+ |> Seq.map ( fun n -> sendMessage msg n.Value)
97+
8698 let! _results = Task.WhenAll tasks
8799 return ()
88100 }
89101
90102 member __.SendMessageToClient path clientId topic payload = task {
91- match sockets.[ path]. TryGetValue clientId with
103+ let ci = { SocketId = clientId; ChannelPath = path}
104+ match sockets.TryGetValue ci with
92105 | true , socket ->
93106 let msg = { Topic = topic; Ref = " " ; Payload = payload }
94107 do ! sendMessage msg socket
95108 | _ -> ()
96109 }
97110
98111 type SocketMiddleware ( next : RequestDelegate , serializer : IJsonSerializer , path : string , channel : IChannel , sockets : SocketHub , logger : ILogger < SocketMiddleware >) =
99- do sockets.NewPath path
100-
101112 member __.Invoke ( ctx : HttpContext ) =
102113 task {
103114 if ctx.Request.Path = PathString( path) then
@@ -106,14 +117,14 @@ module Channels =
106117 let logger = ctx.RequestServices.GetRequiredService< ILogger< SocketMiddleware>>()
107118 logger.LogTrace( " Promoted websocket request" )
108119 let socketId = Guid.NewGuid()
109- let socketInfo = ClientInfo.New socketId
110- let! joinResult = channel.Join( ctx, socketInfo )
120+ let clientInfo = ClientInfo.New path socketId
121+ let! joinResult = channel.Join( ctx, clientInfo )
111122 match joinResult with
112123 | Ok ->
113124 logger.LogTrace( " Joined channel {path}" , path)
114125 let! webSocket = ctx.WebSockets.AcceptWebSocketAsync()
115126 let wrappedSocket = Socket.createFromWebSocket webSocket
116- let socketId = sockets.ConnectSocketToPath path socketId wrappedSocket
127+ let clientInfo = sockets.ConnectSocketToPath path socketId wrappedSocket
117128
118129 while wrappedSocket.State = WebSocketState.Open do
119130 match ! Socket.receiveMessageAsUTF8 wrappedSocket with
@@ -122,7 +133,7 @@ module Channels =
122133 | Result.Ok ( WebSocket.ReceiveUTF8Result.String msg) ->
123134 logger.LogTrace( " received message {0}" , msg)
124135 try
125- do ! channel.HandleMessage( ctx, socketInfo , serializer, msg)
136+ do ! channel.HandleMessage( ctx, clientInfo , serializer, msg)
126137 with
127138 | ex ->
128139 // typically a deserialization error, swallow
@@ -132,8 +143,8 @@ module Channels =
132143 logger.LogError( exn.SourceException, " Error while receiving message" )
133144 () // TODO: ?
134145
135- do ! channel.Terminate ( ctx, socketInfo )
136- sockets.DisconnectSocketForPath path socketId
146+ do ! channel.Terminate ( ctx, clientInfo )
147+ sockets.DisconnectSocketForPath path clientInfo.SocketId
137148 let! result = Socket.close wrappedSocket WebSocketCloseStatus.NormalClosure " Closing channel"
138149 match result with
139150 | Result.Ok () ->
0 commit comments