Skip to content

Commit 3da35f0

Browse files
authored
Merge pull request #240 from DiamondLightSource/split-launch
Split launch.py into control_system.py and controller_api.py
2 parents 7d9dcb1 + d69bd70 commit 3da35f0

File tree

15 files changed

+481
-479
lines changed

15 files changed

+481
-479
lines changed

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595
("py:class", "fastcs.logging._graylog.GraylogEndpoint"),
9696
("py:class", "fastcs.logging._graylog.GraylogStaticFields"),
9797
("py:class", "fastcs.logging._graylog.GraylogEnvFields"),
98-
("py:obj", "fastcs.launch.build_controller_api"),
98+
("py:obj", "fastcs.control_system.build_controller_api"),
9999
("py:obj", "fastcs.transport.epics.util.controller_pv_prefix"),
100100
("docutils", "fastcs.demo.controllers.TemperatureControllerSettings"),
101101
# TypeVar without docstrings still give warnings

src/fastcs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
from . import datatypes as datatypes
1313
from . import transport as transport
1414
from ._version import __version__ as __version__
15-
from .launch import FastCS as FastCS
15+
from .control_system import FastCS as FastCS

src/fastcs/control_system.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import asyncio
2+
import signal
3+
from collections.abc import Coroutine, Sequence
4+
from functools import partial
5+
from typing import Any
6+
7+
from IPython.terminal.embed import InteractiveShellEmbed
8+
9+
from fastcs.controller import BaseController, Controller
10+
from fastcs.controller_api import ControllerAPI
11+
from fastcs.cs_methods import Command, Put, Scan
12+
from fastcs.exceptions import FastCSError
13+
from fastcs.logging import logger as _fastcs_logger
14+
from fastcs.tracer import Tracer
15+
from fastcs.transport import Transport
16+
from fastcs.util import validate_hinted_attributes
17+
18+
tracer = Tracer(name=__name__)
19+
logger = _fastcs_logger.bind(logger_name=__name__)
20+
21+
22+
class FastCS:
23+
"""Entrypoint for a FastCS application.
24+
25+
This class takes a ``Controller``, creates asyncio tasks to run its update loops and
26+
builds its API to serve over the given transports.
27+
28+
:param: controller: The controller to serve in the control system
29+
:param: transports: A list of transports to serve the API over
30+
:param: loop: Optional event loop to run the control system in
31+
"""
32+
33+
def __init__(
34+
self,
35+
controller: Controller,
36+
transports: Sequence[Transport],
37+
loop: asyncio.AbstractEventLoop | None = None,
38+
):
39+
self._loop = loop or asyncio.get_event_loop()
40+
self._controller = controller
41+
42+
self._scan_tasks: set[asyncio.Task] = set()
43+
44+
# these initialise the controller & build its APIs
45+
self._loop.run_until_complete(controller.initialise())
46+
self._loop.run_until_complete(controller.attribute_initialise())
47+
validate_hinted_attributes(controller)
48+
self.controller_api = build_controller_api(controller)
49+
self._link_process_tasks()
50+
51+
self._scan_coros, self._initial_coros = (
52+
self.controller_api.get_scan_and_initial_coros()
53+
)
54+
self._initial_coros.append(controller.connect)
55+
56+
self._transports = transports
57+
for transport in self._transports:
58+
transport.initialise(controller_api=self.controller_api, loop=self._loop)
59+
60+
def create_docs(self) -> None:
61+
for transport in self._transports:
62+
transport.create_docs()
63+
64+
def create_gui(self) -> None:
65+
for transport in self._transports:
66+
transport.create_gui()
67+
68+
def run(self):
69+
serve = asyncio.ensure_future(self.serve())
70+
71+
self._loop.add_signal_handler(signal.SIGINT, serve.cancel)
72+
self._loop.add_signal_handler(signal.SIGTERM, serve.cancel)
73+
self._loop.run_until_complete(serve)
74+
75+
def _link_process_tasks(self):
76+
for controller_api in self.controller_api.walk_api():
77+
controller_api.link_put_tasks()
78+
79+
async def _run_initial_coros(self):
80+
for coro in self._initial_coros:
81+
await coro()
82+
83+
async def _start_scan_tasks(self):
84+
self._scan_tasks = {self._loop.create_task(coro()) for coro in self._scan_coros}
85+
86+
for task in self._scan_tasks:
87+
task.add_done_callback(self._scan_done)
88+
89+
def _scan_done(self, task: asyncio.Task):
90+
try:
91+
task.result()
92+
except Exception as e:
93+
raise FastCSError(
94+
"Exception raised in scan method of "
95+
f"{self._controller.__class__.__name__}"
96+
) from e
97+
98+
def _stop_scan_tasks(self):
99+
for task in self._scan_tasks:
100+
if not task.done():
101+
try:
102+
task.cancel()
103+
except (asyncio.CancelledError, RuntimeError):
104+
pass
105+
except Exception as e:
106+
raise RuntimeError("Unhandled exception in stop scan tasks") from e
107+
108+
async def serve(self) -> None:
109+
context = {
110+
"controller": self._controller,
111+
"controller_api": self.controller_api,
112+
"transports": [
113+
transport.__class__.__name__ for transport in self._transports
114+
],
115+
}
116+
117+
coros = []
118+
for transport in self._transports:
119+
coros.append(transport.serve())
120+
common_context = context.keys() & transport.context.keys()
121+
if common_context:
122+
raise RuntimeError(
123+
"Duplicate context keys found between "
124+
f"current context { ({k: context[k] for k in common_context}) } "
125+
f"and {transport.__class__.__name__} context: "
126+
f"{ ({k: transport.context[k] for k in common_context}) }"
127+
)
128+
context.update(transport.context)
129+
130+
coros.append(self._interactive_shell(context))
131+
132+
logger.info(
133+
"Starting FastCS",
134+
controller=self._controller,
135+
transports=f"[{', '.join(str(t) for t in self._transports)}]",
136+
)
137+
138+
await self._run_initial_coros()
139+
await self._start_scan_tasks()
140+
141+
try:
142+
await asyncio.gather(*coros)
143+
except asyncio.CancelledError:
144+
pass
145+
except Exception as e:
146+
raise RuntimeError("Unhandled exception in serve") from e
147+
148+
async def _interactive_shell(self, context: dict[str, Any]):
149+
"""Spawn interactive shell in another thread and wait for it to complete."""
150+
151+
def run(coro: Coroutine[None, None, None]):
152+
"""Run coroutine on FastCS event loop from IPython thread."""
153+
154+
def wrapper():
155+
asyncio.create_task(coro)
156+
157+
self._loop.call_soon_threadsafe(wrapper)
158+
159+
async def interactive_shell(
160+
context: dict[str, object], stop_event: asyncio.Event
161+
):
162+
"""Run interactive shell in a new thread."""
163+
shell = InteractiveShellEmbed()
164+
await asyncio.to_thread(partial(shell.mainloop, local_ns=context))
165+
166+
stop_event.set()
167+
168+
context["run"] = run
169+
170+
stop_event = asyncio.Event()
171+
self._loop.create_task(interactive_shell(context, stop_event))
172+
await stop_event.wait()
173+
174+
def __del__(self):
175+
self._stop_scan_tasks()
176+
177+
178+
def build_controller_api(controller: Controller) -> ControllerAPI:
179+
return _build_controller_api(controller, [])
180+
181+
182+
def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI:
183+
scan_methods: dict[str, Scan] = {}
184+
put_methods: dict[str, Put] = {}
185+
command_methods: dict[str, Command] = {}
186+
for attr_name in dir(controller):
187+
attr = getattr(controller, attr_name)
188+
match attr:
189+
case Put(enabled=True):
190+
put_methods[attr_name] = attr
191+
case Scan(enabled=True):
192+
scan_methods[attr_name] = attr
193+
case Command(enabled=True):
194+
command_methods[attr_name] = attr
195+
case _:
196+
pass
197+
198+
return ControllerAPI(
199+
path=path,
200+
attributes=controller.attributes,
201+
scan_methods=scan_methods,
202+
put_methods=put_methods,
203+
command_methods=command_methods,
204+
sub_apis={
205+
name: _build_controller_api(sub_controller, path + [name])
206+
for name, sub_controller in controller.get_sub_controllers().items()
207+
},
208+
description=controller.description,
209+
)

src/fastcs/controller_api.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
1-
from collections.abc import Iterator
1+
import asyncio
2+
from collections import defaultdict
3+
from collections.abc import Callable, Iterator
24
from dataclasses import dataclass, field
35

4-
from fastcs.attributes import Attribute
6+
from fastcs.attribute_io_ref import AttributeIORef
7+
from fastcs.attributes import ONCE, Attribute, AttrR, AttrW
58
from fastcs.cs_methods import Command, Put, Scan
9+
from fastcs.exceptions import FastCSError
10+
from fastcs.logging import logger as _fastcs_logger
11+
from fastcs.tracer import Tracer
12+
13+
tracer = Tracer(name=__name__)
14+
logger = _fastcs_logger.bind(logger_name=__name__)
615

716

817
@dataclass
@@ -34,3 +43,72 @@ def __repr__(self):
3443
return f"""\
3544
ControllerAPI(path={self.path}, sub_apis=[{", ".join(self.sub_apis.keys())}])\
3645
"""
46+
47+
def link_put_tasks(self) -> None:
48+
for name, method in self.put_methods.items():
49+
name = name.removeprefix("put_")
50+
51+
attribute = self.attributes[name]
52+
match attribute:
53+
case AttrW():
54+
attribute.set_on_put_callback(method.fn)
55+
case _:
56+
raise FastCSError(
57+
f"Attribute type {type(attribute)} does not"
58+
f"support put operations for {name}"
59+
)
60+
61+
def get_scan_and_initial_coros(self) -> tuple[list[Callable], list[Callable]]:
62+
scan_dict: dict[float, list[Callable]] = defaultdict(list)
63+
initial_coros: list[Callable] = []
64+
65+
for controller_api in self.walk_api():
66+
_add_scan_method_tasks(scan_dict, controller_api)
67+
_add_attribute_update_tasks(scan_dict, initial_coros, controller_api)
68+
69+
scan_coros = _get_periodic_scan_coros(scan_dict)
70+
return scan_coros, initial_coros
71+
72+
73+
def _add_scan_method_tasks(
74+
scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI
75+
):
76+
for method in controller_api.scan_methods.values():
77+
scan_dict[method.period].append(method.fn)
78+
79+
80+
def _add_attribute_update_tasks(
81+
scan_dict: dict[float, list[Callable]],
82+
initial_coros: list[Callable],
83+
controller_api: ControllerAPI,
84+
):
85+
for attribute in controller_api.attributes.values():
86+
match attribute:
87+
case (
88+
AttrR(_io_ref=AttributeIORef(update_period=update_period)) as attribute
89+
):
90+
if update_period is ONCE:
91+
initial_coros.append(attribute.bind_update_callback())
92+
elif update_period is not None:
93+
scan_dict[update_period].append(attribute.bind_update_callback())
94+
95+
96+
def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]:
97+
periodic_scan_coros: list[Callable] = []
98+
for period, methods in scan_dict.items():
99+
periodic_scan_coros.append(_create_periodic_scan_coro(period, methods))
100+
101+
return periodic_scan_coros
102+
103+
104+
def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable:
105+
async def _sleep():
106+
await asyncio.sleep(period)
107+
108+
methods.append(_sleep) # Create periodic behavior
109+
110+
async def scan_coro() -> None:
111+
while True:
112+
await asyncio.gather(*[method() for method in methods])
113+
114+
return scan_coro

0 commit comments

Comments
 (0)