Skip to content

Commit 812c427

Browse files
rtuck99DominicOramDiamondJoseph
authored
feat: Support for Device Composite injection (#1231)
This introduces support for composite devices in blueapi plan signatures. Composite devices must extend from pydantic `BaseModel`, and the devices contained in them will be obtained from the `BlueskyContext` --------- Co-authored-by: Dominic Oram <[email protected]> Co-authored-by: Joseph Ware <[email protected]>
1 parent d894018 commit 812c427

File tree

2 files changed

+83
-9
lines changed

2 files changed

+83
-9
lines changed

src/blueapi/core/context.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections.abc import Callable
33
from dataclasses import InitVar, dataclass, field
44
from importlib import import_module
5-
from inspect import Parameter, signature
5+
from inspect import Parameter, isclass, signature
66
from types import ModuleType, NoneType, UnionType
77
from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints
88

@@ -11,7 +11,7 @@
1111
from dodal.common.beamlines.beamline_utils import get_path_provider, set_path_provider
1212
from dodal.utils import AnyDevice, make_all_devices
1313
from ophyd_async.core import NotConnected
14-
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, create_model
14+
from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler, create_model
1515
from pydantic.fields import FieldInfo
1616
from pydantic.json_schema import JsonSchemaValue, SkipJsonSchema
1717
from pydantic_core import CoreSchema, core_schema
@@ -82,6 +82,9 @@ def is_bluesky_type(typ: type) -> bool:
8282
return typ in BLUESKY_PROTOCOLS or isinstance(typ, BLUESKY_PROTOCOLS)
8383

8484

85+
C = TypeVar("C", bound=BaseModel, covariant=True)
86+
87+
8588
@dataclass
8689
class BlueskyContext:
8790
"""
@@ -383,7 +386,14 @@ def _type_spec_for_function(
383386
)
384387

385388
no_default = para.default is Parameter.empty
386-
factory = None if no_default else DefaultFactory(para.default)
389+
default_factory = (
390+
self._composite_factory(arg_type)
391+
if isclass(arg_type)
392+
and issubclass(arg_type, BaseModel)
393+
and isinstance(para.default, str)
394+
else DefaultFactory(para.default)
395+
)
396+
factory = None if no_default else default_factory
387397
new_args[name] = (
388398
self._convert_type(arg_type, no_default),
389399
FieldInfo(default_factory=factory),
@@ -419,6 +429,20 @@ def _convert_type(self, typ: type | Any, no_default: bool = True) -> type:
419429
return root[new_types] if root else typ # type: ignore
420430
return typ
421431

432+
def _composite_factory(self, composite_class: type[C]) -> Callable[[], C]:
433+
def _inject_composite():
434+
devices = {
435+
field: self.find_device(info.default)
436+
if info.annotation is not None
437+
and is_bluesky_type(info.annotation)
438+
and isinstance(info.default, str)
439+
else info.default
440+
for field, info in composite_class.model_fields.items()
441+
}
442+
return composite_class(**devices)
443+
444+
return _inject_composite
445+
422446

423447
D = TypeVar("D")
424448

tests/unit_tests/worker/test_task_worker.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from blueapi.config import EnvironmentConfig, Source, SourceKind
2222
from blueapi.core import BlueskyContext, EventStream
2323
from blueapi.core.bluesky_types import DataEvent
24+
from blueapi.utils.base_model import BlueapiBaseModel
2425
from blueapi.worker import (
2526
Task,
2627
TaskStatus,
@@ -53,12 +54,9 @@
5354
class FakeDevice(Movable[float]):
5455
event: threading.Event
5556

56-
@property
57-
def name(self) -> str:
58-
return "fake_device"
59-
60-
def __init__(self) -> None:
57+
def __init__(self, name: str = "fake_device") -> None:
6158
self.event = threading.Event()
59+
self.name = name
6260

6361
def set(self, value: float) -> Status:
6462
def when_done(_: Status):
@@ -84,14 +82,31 @@ def fake_device() -> FakeDevice:
8482

8583

8684
@pytest.fixture
87-
def context(fake_device: FakeDevice) -> BlueskyContext:
85+
def second_fake_device() -> FakeDevice:
86+
return FakeDevice("second_fake_device")
87+
88+
89+
@pytest.fixture
90+
def context(fake_device: FakeDevice, second_fake_device: FakeDevice) -> BlueskyContext:
8891
ctx = BlueskyContext()
8992
ctx_config = EnvironmentConfig()
9093
ctx_config.sources.append(
9194
Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices")
9295
)
9396
ctx.register_plan(failing_plan)
9497
ctx.register_device(fake_device)
98+
ctx.register_device(second_fake_device)
99+
ctx.with_config(ctx_config)
100+
return ctx
101+
102+
103+
@pytest.fixture
104+
def context_without_devices() -> BlueskyContext:
105+
ctx = BlueskyContext()
106+
ctx_config = EnvironmentConfig()
107+
ctx_config.sources.append(
108+
Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices")
109+
)
95110
ctx.with_config(ctx_config)
96111
return ctx
97112

@@ -681,3 +696,38 @@ def test_cycle_without_otel_context(mock_logger: Mock, inert_worker: TaskWorker)
681696
task.is_complete = False
682697
task.is_pending = True
683698
mock_logger.info.assert_called_with(f"Got new task: {task}")
699+
700+
701+
class MyComposite(BlueapiBaseModel):
702+
dev_a: FakeDevice = inject(fake_device.name)
703+
dev_b: FakeDevice = inject(second_fake_device.name)
704+
705+
model_config = {"arbitrary_types_allowed": True}
706+
707+
708+
def injected_device_plan(composite: MyComposite = inject("")) -> MsgGenerator:
709+
yield from ()
710+
711+
712+
def test_injected_composite_devices_are_found(
713+
fake_device: FakeDevice,
714+
second_fake_device: FakeDevice,
715+
context: BlueskyContext,
716+
):
717+
context.register_plan(injected_device_plan)
718+
params = Task(name="injected_device_plan").prepare_params(context)
719+
assert params["composite"].dev_a == fake_device
720+
assert params["composite"].dev_b == second_fake_device
721+
722+
723+
def test_plan_module_with_composite_devices_can_be_loaded_before_device_module(
724+
context_without_devices: BlueskyContext,
725+
fake_device: FakeDevice,
726+
second_fake_device: FakeDevice,
727+
):
728+
context_without_devices.register_plan(injected_device_plan)
729+
context_without_devices.register_device(fake_device)
730+
context_without_devices.register_device(second_fake_device)
731+
params = Task(name="injected_device_plan").prepare_params(context_without_devices)
732+
assert params["composite"].dev_a == fake_device
733+
assert params["composite"].dev_b == second_fake_device

0 commit comments

Comments
 (0)