Skip to content

Commit 9ebdf4d

Browse files
authored
Merge pull request #293 from DiamondLightSource/commands
Refactor Methods to allow overridding `@command` methods
2 parents d3a4858 + 3257e71 commit 9ebdf4d

File tree

10 files changed

+98
-63
lines changed

10 files changed

+98
-63
lines changed

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989
("py:class", "p4p.nt.ndarray.NTNDArray"),
9090
("py:class", "p4p.nt.NTTable"),
9191
# Problems in FastCS itself
92-
("py:class", "T"),
92+
("py:class", "BaseController"),
9393
("py:class", "AttrIOUpdateCallback"),
9494
("py:class", "fastcs.transports.epics.pva.pvi_tree._PviSignalInfo"),
9595
("py:class", "fastcs.logging._logging.LogLevel"),

src/fastcs/control_system.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from fastcs.controllers import BaseController, Controller
1010
from fastcs.logging import bind_logger
11-
from fastcs.methods import Command, Scan, ScanCallback
11+
from fastcs.methods import ScanCallback
1212
from fastcs.tracer import Tracer
1313
from fastcs.transports import ControllerAPI, Transport
1414

@@ -174,23 +174,11 @@ def build_controller_api(controller: Controller) -> ControllerAPI:
174174

175175

176176
def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI:
177-
scan_methods: dict[str, Scan] = {}
178-
command_methods: dict[str, Command] = {}
179-
for attr_name in dir(controller):
180-
attr = getattr(controller, attr_name)
181-
match attr:
182-
case Scan(enabled=True):
183-
scan_methods[attr_name] = attr
184-
case Command(enabled=True):
185-
command_methods[attr_name] = attr
186-
case _:
187-
pass
188-
189177
return ControllerAPI(
190178
path=path,
191179
attributes=controller.attributes,
192-
scan_methods=scan_methods,
193-
command_methods=command_methods,
180+
command_methods=controller.command_methods,
181+
scan_methods=controller.scan_methods,
194182
sub_apis={
195183
name: _build_controller_api(sub_controller, path + [name])
196184
for name, sub_controller in controller.sub_controllers.items()

src/fastcs/controllers/base_controller.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from fastcs.attributes import AnyAttributeIO, Attribute, AttrR, AttrW, HintedAttribute
99
from fastcs.logging import bind_logger
10+
from fastcs.methods import Command, Scan, UnboundCommand, UnboundScan
1011
from fastcs.tracer import Tracer
1112

1213
logger = bind_logger(logger_name=__name__)
@@ -46,6 +47,8 @@ def __init__(
4647
# Internal state that should not be accessed directly by base classes
4748
self.__attributes: dict[str, Attribute] = {}
4849
self.__sub_controllers: dict[str, BaseController] = {}
50+
self.__command_methods: dict[str, Command] = {}
51+
self.__scan_methods: dict[str, Scan] = {}
4952

5053
self.__hinted_attributes: dict[str, HintedAttribute] = {}
5154
self.__hinted_sub_controllers: dict[str, type[BaseController]] = {}
@@ -95,10 +98,6 @@ class method and a controller instance, so that it can be called from any
9598
context with the controller instance passed as the ``self`` argument.
9699
97100
"""
98-
# Lazy import to avoid circular references
99-
from fastcs.methods.command import UnboundCommand
100-
from fastcs.methods.scan import UnboundScan
101-
102101
# Using a dictionary instead of a set to maintain order.
103102
class_dir = {key: None for key in dir(type(self)) if not key.startswith("_")}
104103
class_type_hints = {
@@ -114,8 +113,21 @@ class method and a controller instance, so that it can be called from any
114113
attr = getattr(self, attr_name, None)
115114
if isinstance(attr, Attribute):
116115
setattr(self, attr_name, deepcopy(attr))
117-
elif isinstance(attr, UnboundScan | UnboundCommand):
118-
setattr(self, attr_name, attr.bind(self))
116+
else:
117+
if isinstance(attr, Command):
118+
self.add_command(attr_name, attr)
119+
elif isinstance(attr, Scan):
120+
self.add_scan(attr_name, attr)
121+
elif isinstance(
122+
unbound_command := getattr(attr, "__unbound_command__", None),
123+
UnboundCommand,
124+
):
125+
self.add_command(attr_name, unbound_command.bind(self))
126+
elif isinstance(
127+
unbound_scan := getattr(attr, "__unbound_scan__", None),
128+
UnboundScan,
129+
):
130+
self.add_scan(attr_name, unbound_scan.bind(self))
119131

120132
def _validate_io(self, ios: Sequence[AnyAttributeIO]):
121133
"""Validate that there is exactly one AttributeIO class registered to the
@@ -137,6 +149,10 @@ def __repr__(self):
137149
def __setattr__(self, name, value):
138150
if isinstance(value, Attribute):
139151
self.add_attribute(name, value)
152+
elif isinstance(value, Command):
153+
self.add_command(name, value)
154+
elif isinstance(value, Scan):
155+
self.add_scan(name, value)
140156
elif isinstance(value, BaseController):
141157
self.add_sub_controller(name, value)
142158
else:
@@ -300,3 +316,19 @@ def add_sub_controller(self, name: str, sub_controller: BaseController):
300316
@property
301317
def sub_controllers(self) -> dict[str, BaseController]:
302318
return self.__sub_controllers
319+
320+
def add_command(self, name: str, command: Command):
321+
self.__command_methods[name] = command
322+
super().__setattr__(name, command)
323+
324+
@property
325+
def command_methods(self) -> dict[str, Command]:
326+
return self.__command_methods
327+
328+
def add_scan(self, name: str, scan: Scan):
329+
self.__scan_methods[name] = scan
330+
super().__setattr__(name, scan)
331+
332+
@property
333+
def scan_methods(self) -> dict[str, Scan]:
334+
return self.__scan_methods

src/fastcs/methods/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from .command import Command as Command
22
from .command import CommandCallback as CommandCallback
3+
from .command import UnboundCommand as UnboundCommand
34
from .command import command as command
45
from .scan import Scan as Scan
56
from .scan import ScanCallback as ScanCallback
7+
from .scan import UnboundScan as UnboundScan
68
from .scan import scan as scan

src/fastcs/methods/command.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from collections.abc import Callable, Coroutine
22
from types import MethodType
3+
from typing import TYPE_CHECKING
34

4-
from fastcs.controllers import BaseController
5+
from fastcs.logging import bind_logger
56
from fastcs.methods.method import Controller_T, Method
67

8+
if TYPE_CHECKING:
9+
from fastcs.controllers import BaseController # noqa: F401
10+
11+
logger = bind_logger(logger_name=__name__)
12+
713
UnboundCommandCallback = Callable[[Controller_T], Coroutine[None, None, None]]
814
"""A Command callback that is unbound and must be called with a `Controller` instance"""
915
CommandCallback = Callable[[], Coroutine[None, None, None]]
1016
"""A Command callback that is bound and can be called without `self`"""
1117

1218

13-
class Command(Method[BaseController]):
19+
class Command(Method["BaseController"]):
1420
"""A `Controller` `Method` that performs a single action when called.
1521
1622
This class contains a function that is bound to a specific `Controller` instance and
@@ -28,7 +34,18 @@ def _validate(self, fn: CommandCallback) -> None:
2834
raise TypeError(f"Command method cannot have arguments: {fn}")
2935

3036
async def __call__(self):
31-
return await self._fn()
37+
return await self.fn()
38+
39+
@property
40+
def fn(self) -> CommandCallback:
41+
async def command():
42+
try:
43+
return await self._fn()
44+
except Exception:
45+
logger.exception("Command failed", fn=self._fn)
46+
raise
47+
48+
return command
3249

3350

3451
class UnboundCommand(Method[Controller_T]):
@@ -56,15 +73,12 @@ def _validate(self, fn: UnboundCommandCallback[Controller_T]) -> None:
5673
def bind(self, controller: Controller_T) -> Command:
5774
return Command(MethodType(self.fn, controller), group=self.group)
5875

59-
def __call__(self):
60-
raise NotImplementedError(
61-
"Method must be bound to a controller instance to be callable"
62-
)
63-
6476

6577
def command(
6678
*, group: str | None = None
67-
) -> Callable[[UnboundCommandCallback[Controller_T]], UnboundCommand[Controller_T]]:
79+
) -> Callable[
80+
[UnboundCommandCallback[Controller_T]], UnboundCommandCallback[Controller_T]
81+
]:
6882
"""Decorator to register a `Controller` method as a `Command`
6983
7084
The `Command` will be passed to the transport layer to expose in the API
@@ -75,7 +89,9 @@ def command(
7589

7690
def wrapper(
7791
fn: UnboundCommandCallback[Controller_T],
78-
) -> UnboundCommand[Controller_T]:
79-
return UnboundCommand(fn, group=group)
92+
) -> UnboundCommandCallback[Controller_T]:
93+
setattr(fn, "__unbound_command__", UnboundCommand(fn, group=group)) # noqa: B010
94+
95+
return fn
8096

8197
return wrapper

src/fastcs/methods/method.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from asyncio import iscoroutinefunction
22
from collections.abc import Callable, Coroutine
33
from inspect import Signature, getdoc, signature
4-
from typing import Generic, TypeVar
4+
from typing import TYPE_CHECKING, Generic, TypeVar
55

6-
from fastcs.controllers.base_controller import BaseController
76
from fastcs.tracer import Tracer
87

8+
if TYPE_CHECKING:
9+
from fastcs.controllers import BaseController # noqa: F401
10+
911
MethodCallback = Callable[..., Coroutine[None, None, None]]
1012
"""Generic protocol for all `Controller` Method callbacks"""
11-
Controller_T = TypeVar("Controller_T", bound=BaseController)
13+
Controller_T = TypeVar("Controller_T", bound="BaseController") # noqa: F821
1214
"""Generic `Controller` class that an unbound method must be called with as `self`"""
1315

1416

src/fastcs/methods/scan.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from collections.abc import Callable, Coroutine
22
from types import MethodType
3+
from typing import TYPE_CHECKING
34

4-
from fastcs.controllers import BaseController
55
from fastcs.logging import bind_logger
66
from fastcs.methods.method import Controller_T, Method
77

8+
if TYPE_CHECKING:
9+
from fastcs.controllers import BaseController # noqa: F401
10+
811
logger = bind_logger(logger_name=__name__)
912

1013
UnboundScanCallback = Callable[[Controller_T], Coroutine[None, None, None]]
@@ -13,7 +16,7 @@
1316
"""A Scan callback that is bound and can be called without `self`"""
1417

1518

16-
class Scan(Method[BaseController]):
19+
class Scan(Method["BaseController"]):
1720
"""A `Controller` `Method` that will be called periodically in the background.
1821
1922
This class contains a function that is bound to a specific `Controller` instance and
@@ -40,7 +43,7 @@ async def __call__(self):
4043
return await self._fn()
4144

4245
@property
43-
def fn(self):
46+
def fn(self) -> ScanCallback:
4447
async def scan():
4548
try:
4649
return await self._fn()
@@ -80,15 +83,10 @@ def _validate(self, fn: UnboundScanCallback[Controller_T]) -> None:
8083
def bind(self, controller: Controller_T) -> Scan:
8184
return Scan(MethodType(self.fn, controller), self._period)
8285

83-
def __call__(self):
84-
raise NotImplementedError(
85-
"Method must be bound to a controller instance to be callable"
86-
)
87-
8886

8987
def scan(
9088
period: float,
91-
) -> Callable[[UnboundScanCallback[Controller_T]], UnboundScan[Controller_T]]:
89+
) -> Callable[[UnboundScanCallback[Controller_T]], UnboundScanCallback[Controller_T]]:
9290
"""Decorator to register a `Controller` method as a `Scan`
9391
9492
The `Scan` method will be called periodically in the background.
@@ -97,7 +95,11 @@ def scan(
9795
if period <= 0:
9896
raise ValueError("Scan method must have a positive scan period")
9997

100-
def wrapper(fn: UnboundScanCallback[Controller_T]) -> UnboundScan[Controller_T]:
101-
return UnboundScan(fn, period)
98+
def wrapper(
99+
fn: UnboundScanCallback[Controller_T],
100+
) -> UnboundScanCallback[Controller_T]:
101+
setattr(fn, "__unbound_scan__", UnboundScan(fn, period=period)) # noqa: B010
102+
103+
return fn
102104

103105
return wrapper

src/fastcs/transports/controller_api.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _add_attribute_update_tasks(
8888

8989

9090
def _get_periodic_scan_coros(
91-
scan_dict: dict[float, list[Scan]],
91+
scan_dict: dict[float, list[ScanCallback]],
9292
) -> list[ScanCallback]:
9393
periodic_scan_coros: list[ScanCallback] = []
9494
for period, methods in scan_dict.items():
@@ -97,11 +97,13 @@ def _get_periodic_scan_coros(
9797
return periodic_scan_coros
9898

9999

100-
def _create_periodic_scan_coro(period: float, scans: list[Scan]) -> ScanCallback:
100+
def _create_periodic_scan_coro(
101+
period: float, scans: list[ScanCallback]
102+
) -> ScanCallback:
101103
async def _sleep():
102104
await asyncio.sleep(period)
103105

104-
methods = [_sleep] + scans # Create periodic behavior
106+
methods = [_sleep] + list(scans) # Create periodic behavior
105107

106108
async def scan_coro() -> None:
107109
while True:

tests/test_control_system.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ async def test_scan_tasks(controller):
2222

2323
for _ in range(3):
2424
count = controller.count
25-
await asyncio.sleep(controller.counter.period + 0.01)
25+
await asyncio.sleep(0.1)
2626
assert controller.count > count
2727

2828

tests/test_methods.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,13 @@ async def do_nothing(self):
3939
async def do_nothing_with_arg(self, arg):
4040
pass
4141

42-
unbound_command = UnboundCommand(TestController.do_nothing, group="Test")
43-
44-
with pytest.raises(NotImplementedError):
45-
await unbound_command()
46-
4742
with pytest.raises(TypeError):
4843
UnboundCommand(TestController.do_nothing_with_arg) # type: ignore
4944

5045
with pytest.raises(TypeError):
5146
Command(TestController().do_nothing_with_arg) # type: ignore
5247

48+
unbound_command = UnboundCommand(TestController.do_nothing, group="Test")
5349
command = unbound_command.bind(TestController())
5450
# Test that group is passed when binding commands
5551
assert command.group == "Test"
@@ -66,19 +62,14 @@ async def update_nothing(self):
6662
async def update_nothing_with_arg(self, arg):
6763
pass
6864

69-
unbound_scan = UnboundScan(TestController.update_nothing, 1.0)
70-
71-
assert unbound_scan.period == 1.0
72-
73-
with pytest.raises(NotImplementedError):
74-
await unbound_scan()
75-
7665
with pytest.raises(TypeError):
7766
UnboundScan(TestController.update_nothing_with_arg, 1.0) # type: ignore
7867

7968
with pytest.raises(TypeError):
8069
Scan(TestController().update_nothing_with_arg, 1.0) # type: ignore
8170

71+
unbound_scan = UnboundScan(TestController.update_nothing, 1.0)
72+
assert unbound_scan.period == 1.0
8273
scan = unbound_scan.bind(TestController())
8374

8475
assert scan.period == 1.0

0 commit comments

Comments
 (0)