Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ dependencies = [
"deepdiff",
"matplotlib",
"cachetools",
"bluesky-stomp >= 0.2.0",

#
# These dependencies may be issued as pre-release versions and should have a pin constraint
# as by default pip-install will not upgrade to a pre-release.
Expand Down
152 changes: 110 additions & 42 deletions src/mx_bluesky/hyperion/external_interaction/callbacks/__main__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import logging
from collections.abc import Callable, Sequence
from abc import abstractmethod
from collections.abc import Callable
from contextlib import AbstractContextManager
from threading import Thread
from time import sleep # noqa
from urllib import request
from urllib.error import URLError

from blueapi.config import ApplicationConfig, ConfigLoader
from bluesky.callbacks import CallbackBase
from bluesky.callbacks.zmq import Proxy, RemoteDispatcher
from bluesky_stomp.messaging import StompClient
from bluesky_stomp.models import Broker
from dodal.log import LOGGER as DODAL_LOGGER
from dodal.log import set_up_all_logging_handlers

Expand Down Expand Up @@ -52,6 +57,9 @@
from mx_bluesky.hyperion.external_interaction.callbacks.snapshot_callback import (
BeamDrawingCallback,
)
from mx_bluesky.hyperion.external_interaction.callbacks.stomp.dispatcher import (
StompDispatcher,
)
from mx_bluesky.hyperion.parameters.cli import CallbackArgs, parse_callback_args
from mx_bluesky.hyperion.parameters.constants import CONST
from mx_bluesky.hyperion.parameters.gridscan import (
Expand Down Expand Up @@ -143,19 +151,6 @@ def log_debug(msg, *args, **kwargs):
NEXUS_LOGGER.debug(msg, *args, **kwargs)


def wait_for_threads_forever(threads: Sequence[Thread]):
alive = [t.is_alive() for t in threads]
try:
log_debug("Trying to wait forever on callback and dispatcher threads")
while all(alive):
sleep(LIVENESS_POLL_SECONDS)
alive = [t.is_alive() for t in threads]
except KeyboardInterrupt:
log_info("Main thread received interrupt - exiting.")
else:
log_info("Proxy or dispatcher thread ended - exiting.")


class HyperionCallbackRunner:
"""Runs Nexus, ISPyB and Zocalo callbacks in their own process."""

Expand All @@ -166,44 +161,26 @@ def __init__(self, callback_args: CallbackArgs) -> None:

self.callbacks = setup_callbacks()

self.proxy = Proxy(*CONST.CALLBACK_0MQ_PROXY_PORTS)
self.proxy_thread = Thread(
target=self.proxy.start, daemon=True, name="0MQ Proxy"
)

self.dispatcher = RemoteDispatcher(
f"localhost:{CONST.CALLBACK_0MQ_PROXY_PORTS[1]}"
)

def start_dispatcher(callbacks: list[Callable]):
for cb in callbacks:
self.dispatcher.subscribe(cb)
self.dispatcher.start()

self.dispatcher_thread = Thread(
target=start_dispatcher,
args=[self.callbacks],
daemon=True,
name="0MQ Dispatcher",
)

self.watchdog_thread = Thread(
target=run_watchdog,
daemon=True,
name="Watchdog",
args=[callback_args.watchdog_port],
)
log_info("Created 0MQ proxy and local RemoteDispatcher.")

self._dispatcher_cm: DispatcherContextMgr
if callback_args.stomp_config:
self._dispatcher_cm = StompDispatcherContextMgr(
callback_args, self.callbacks
)
else:
self._dispatcher_cm = RemoteDispatcherContextMgr(self.callbacks)

def start(self):
log_info(f"Launching threads, with callbacks: {self.callbacks}")
self.proxy_thread.start()
self.dispatcher_thread.start()
self.watchdog_thread.start()
log_info("Proxy and dispatcher thread launched.")
wait_for_threads_forever(
[self.proxy_thread, self.dispatcher_thread, self.watchdog_thread]
)
with self._dispatcher_cm:
ping_watchdog_while_alive(self._dispatcher_cm, self.watchdog_thread)


def run_watchdog(watchdog_port: int):
Expand Down Expand Up @@ -231,5 +208,96 @@ def main(dev_mode=False) -> None:
runner.start()


class DispatcherContextMgr(AbstractContextManager):
@abstractmethod
def is_alive(self) -> bool: ...


class RemoteDispatcherContextMgr(DispatcherContextMgr):
def __init__(self, callbacks: list[CallbackBase]):
super().__init__()

self.proxy = Proxy(*CONST.CALLBACK_0MQ_PROXY_PORTS)
self.proxy_thread = Thread(
target=self.proxy.start, daemon=True, name="0MQ Proxy"
)

self.dispatcher = RemoteDispatcher(
f"localhost:{CONST.CALLBACK_0MQ_PROXY_PORTS[1]}"
)

def start_dispatcher(callbacks: list[Callable]):
for cb in callbacks:
self.dispatcher.subscribe(cb)
self.dispatcher.start()

self.dispatcher_thread = Thread(
target=start_dispatcher,
args=[callbacks],
daemon=True,
name="0MQ Dispatcher",
)
log_info("Created 0MQ proxy and local RemoteDispatcher.")

def __enter__(self):
log_info("Proxy and dispatcher thread launched.")
self.proxy_thread.start()
self.dispatcher_thread.start()
return self

def __exit__(self, exc_type, exc_value, traceback, /):
self.dispatcher.stop()
# proxy has no way to stop

def is_alive(self):
return self.proxy_thread.is_alive() and self.dispatcher_thread.is_alive()


class StompDispatcherContextMgr(DispatcherContextMgr):
def __init__(self, args: CallbackArgs, callbacks: list[CallbackBase]):
super().__init__()
loader = ConfigLoader(ApplicationConfig)
loader.use_values_from_yaml(args.stomp_config)
config = loader.load()
log_info(
f"Stomp client configured on {config.stomp.url.host}:{config.stomp.url.port}"
)
self._stomp_client = StompClient.for_broker(
broker=Broker(
host=config.stomp.url.host,
port=config.stomp.url.port,
auth=config.stomp.auth,
)
)
self._dispatcher = StompDispatcher(self._stomp_client)
for cb in callbacks:
self._dispatcher.subscribe(cb)

def is_alive(self) -> bool:
return self._stomp_client.is_connected()

def __enter__(self):
self._dispatcher.__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback, /):
self._dispatcher.__exit__(exc_type, exc_value, traceback)


def ping_watchdog_while_alive(
dispatcher_cm: DispatcherContextMgr, watchdog_thread: Thread
):
alive = watchdog_thread.is_alive() and dispatcher_cm.is_alive()
try:
log_debug("Trying to wait forever on callback and dispatcher threads")
while alive:
sleep(LIVENESS_POLL_SECONDS)
alive = watchdog_thread.is_alive() and dispatcher_cm.is_alive()
except KeyboardInterrupt:
log_info("Main thread received interrupt - exiting.")
else:
log_info("Proxy or dispatcher thread ended - exiting.")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from blueapi.client.event_bus import AnyEvent
from blueapi.core import DataEvent
from bluesky.run_engine import Dispatcher
from bluesky_stomp.messaging import MessageContext, StompClient
from bluesky_stomp.models import MessageTopic
from event_model import DocumentNames

from mx_bluesky.common.utils.log import ISPYB_ZOCALO_CALLBACK_LOGGER as LOGGER

BLUEAPI_EVENT_TOPIC = "public.worker.event"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should: Feels like we should get this out of blueapi somewhere, can you add an issue/PR to expose it in blueapi?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



class StompDispatcher(Dispatcher):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could: This seems generic enough that it should live in stomp-bluesky can you make a PR in there for it? Or at least an issue

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an outstanding issue DiamondLightSource/bluesky-stomp#14 - I've referenced this PR in it.

Unfortunately this file would introduce a circular dependency between stomp-bluesky and blueapi if introduced there.

def __init__(self, stomp_client: StompClient):
super().__init__()
self._client = stomp_client

def __enter__(self):
self._subscription_id = self._client.subscribe(
MessageTopic(name=BLUEAPI_EVENT_TOPIC), self._on_event
)
LOGGER.info("Connecting to stomp broker...")
self._client.connect()

def __exit__(self, exc_type, exc_val, exc_tb):
LOGGER.info("Disconnecting from stomp and unsubscribing...")
self._client.disconnect()
self._client.unsubscribe(self._subscription_id)

def _on_event(self, event: AnyEvent, context: MessageContext):
match event:
case DataEvent(name=name, doc=doc): # type: ignore
self.process(DocumentNames[name], doc)
14 changes: 13 additions & 1 deletion src/mx_bluesky/hyperion/parameters/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
from enum import StrEnum
from pathlib import Path

from pydantic.dataclasses import dataclass

Expand All @@ -25,6 +26,7 @@ class HyperionArgs:
class CallbackArgs:
dev_mode: bool = False
watchdog_port: int = HyperionConstants.HYPERION_PORT
stomp_config: Path | None = None


def _add_callback_relevant_args(parser: argparse.ArgumentParser) -> None:
Expand All @@ -45,8 +47,18 @@ def parse_callback_args() -> CallbackArgs:
type=int,
help="Liveness port for callbacks to ping regularly",
)
parser.add_argument(
"--stomp-config",
type=Path,
default=None,
help="Specify config yaml for the STOMP backend (default is 0MQ)",
)
args = parser.parse_args()
return CallbackArgs(dev_mode=args.dev, watchdog_port=args.watchdog_port)
return CallbackArgs(
dev_mode=args.dev,
watchdog_port=args.watchdog_port,
stomp_config=args.stomp_config,
)


def parse_cli_args() -> HyperionArgs:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from unittest.mock import patch

from bluesky.callbacks import CallbackBase
from bluesky_stomp.models import MessageTopic
from event_model import Event, RunStart, RunStop

from mx_bluesky.hyperion.external_interaction.callbacks.__main__ import (
StompDispatcherContextMgr,
main,
)
from mx_bluesky.hyperion.parameters.cli import CallbackArgs

CALLBACK_TOPIC = "callback_test_events"


class PatchedStompCallbackMgr(StompDispatcherContextMgr):
def __init__(self, args: CallbackArgs, callbacks: list[CallbackBase]) -> None:
super().__init__(args, callbacks)
callback = callbacks[0]
assert isinstance(callback, StompTestCallback)
self.callback = callback

def __enter__(self):
super().__enter__()
self.callback.init(self._stomp_client)
return self


class StompTestCallback(CallbackBase):
def __init__(self) -> None:
super().__init__()
self.stomp_client = None
self.destination = MessageTopic(name=CALLBACK_TOPIC)

def init(self, stomp_client):
self.stomp_client = stomp_client
self.fire_event_back("init")

def start(self, doc: RunStart) -> RunStart | None:
self.fire_event_back(f"start: {doc['run_name']}") # type: ignore
return super().start(doc)

def stop(self, doc: RunStop) -> RunStop | None:
self.fire_event_back("stop")
return super().stop(doc)

def event(self, doc: Event) -> Event:
self.fire_event_back(f"event: {doc['data']['baton-requested_user']}")
return super().event(doc)

def fire_event_back(self, msg: str):
self.stomp_client.send(destination=self.destination, obj=msg) # type: ignore


if __name__ == "__main__":

def mock_setup_callbacks():
return [StompTestCallback()]

with (
patch(
"mx_bluesky.hyperion.external_interaction.callbacks.__main__.setup_callbacks",
return_value=mock_setup_callbacks(),
),
patch(
"mx_bluesky.hyperion.external_interaction.callbacks.__main__.StompDispatcherContextMgr",
PatchedStompCallbackMgr,
),
):
main()
Loading