Skip to content

Commit 6414ca7

Browse files
committed
Add support for post connection subscriptions
1 parent 264aa4d commit 6414ca7

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

fastapi_websocket_pubsub/pub_sub_client.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ async def _primary_on_connect(self, channel: RpcChannel):
251251

252252
def subscribe(self, topic: Topic, callback: Coroutine):
253253
"""
254-
Subscribe for events (prior to starting the client)
254+
Subscribe for events (before and after starting the client)
255255
@see fastapi_websocket_pubsub/rpc_event_methods.py :: RpcEventServerMethods.subscribe
256256
257257
Args:
@@ -260,18 +260,25 @@ def subscribe(self, topic: Topic, callback: Coroutine):
260260
'hello' or a complex path 'a/b/c/d' .
261261
Note: You can use ALL_TOPICS (event_notifier.ALL_TOPICS) to subscribe to all topics
262262
callback (Coroutine): the function to call upon relevant event publishing
263+
264+
Returns:
265+
Coroutine: awaitable task to subscribe to topic if connected.
263266
"""
264-
# TODO: add support for post connection subscriptions
265-
if not self.is_ready():
266-
self._topics.add(topic)
267-
# init to empty list if no entry
268-
callbacks = self._callbacks[topic] = self._callbacks.get(topic, [])
269-
# add callback to callbacks list of the topic
270-
callbacks.append(callback)
267+
topic_is_new = topic not in self._topics
268+
self._topics.add(topic)
269+
# init to empty list if no entry
270+
callbacks = self._callbacks[topic] = self._callbacks.get(topic, [])
271+
# add callback to callbacks list of the topic
272+
callbacks.append(callback)
273+
if topic_is_new and self.is_ready():
274+
return self._rpc_channel.other.subscribe(topics=[topic])
271275
else:
272-
raise PubSubClientInvalidStateException(
273-
"Client already connected and subscribed"
274-
)
276+
# If we can't return an RPC call future then we need
277+
# to supply something else to not fail when the
278+
# calling code awaits the result of this function.
279+
future = asyncio.Future()
280+
future.set_result(None)
281+
return future
275282

276283
async def publish(
277284
self, topics: TopicList, data=None, sync=True, notifier_id=None

0 commit comments

Comments
 (0)