Skip to content

Commit 932c0c3

Browse files
committed
Use ParamSpec when passing through __call__
1 parent b687c5c commit 932c0c3

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/dodal/device_manager.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
V2 = TypeVar("V2", bound=OphydV2Device)
4242

4343
DeviceFactoryDecorator = Callable[[Callable[Args, V2]], "DeviceFactory[Args, V2]"]
44-
OphydInitialiser = Callable[Concatenate[V1, ...], V1 | None]
44+
OphydInitialiser = Callable[Concatenate[V1, Args], V1 | None]
4545

4646
_EMPTY = object()
4747
"""Sentinel value to distinguish between missing values and present but null values"""
@@ -166,11 +166,11 @@ def build(
166166
device.set_name(name)
167167
return device # type: ignore - it's us, honest
168168

169-
def create(self, *args, **kwargs) -> V2:
169+
def create(self, *args: Args.args, **kwargs: Args.kwargs) -> V2:
170170
# TODO: Remove when v1 support is no longer required - see #1718
171171
return self(*args, **kwargs)
172172

173-
def __call__(self, *args, **kwargs) -> V2:
173+
def __call__(self, *args: Args.args, **kwargs: Args.kwargs) -> V2:
174174
device = self.factory(*args, **kwargs)
175175
if self.use_factory_name:
176176
device.set_name(self.name)
@@ -182,7 +182,7 @@ def __repr__(self) -> str:
182182

183183

184184
# TODO: Remove when ophyd v1 support is no longer required - see #1718
185-
class V1DeviceFactory(Generic[V1]):
185+
class V1DeviceFactory(Generic[Args, V1]):
186186
"""
187187
Wrapper around an ophyd v1 device that holds a reference to a device
188188
manager that can provide dependencies, along with default connection
@@ -198,7 +198,7 @@ def __init__(
198198
skip: SkipType,
199199
wait: bool,
200200
timeout: int,
201-
init: OphydInitialiser[V1],
201+
init: OphydInitialiser[V1, Args],
202202
manager: "DeviceManager",
203203
):
204204
self.factory = factory
@@ -255,11 +255,11 @@ def mock_if_needed(self, mock=False) -> Self:
255255
manager=self._manager,
256256
)
257257

258-
def __call__(self, *args, **kwargs):
258+
def __call__(self, dev: V1, *args: Args.args, **kwargs: Args.kwargs):
259259
"""Call the wrapped function to make decorator transparent"""
260-
return self.post_create(*args, **kwargs)
260+
return self.post_create(dev, *args, **kwargs)
261261

262-
def create(self, *args, **kwargs) -> V1:
262+
def create(self, *args: Args.args, **kwargs: Args.kwargs) -> V1:
263263
device = self.factory(name=self.name, prefix=self.prefix)
264264
if self.wait:
265265
wait_for_connection(device, timeout=self.timeout)
@@ -384,7 +384,7 @@ def v1_init(
384384
and is not used to create the device.
385385
"""
386386

387-
def decorator(init: OphydInitialiser[V1]) -> V1DeviceFactory[V1]:
387+
def decorator(init: OphydInitialiser[V1, Args]) -> V1DeviceFactory[Args, V1]:
388388
name = init.__name__
389389
if name in self:
390390
raise ValueError(f"Duplicate factory name: {name}")
@@ -534,7 +534,7 @@ def __getitem__(self, name):
534534

535535
def _expand_dependencies(
536536
self,
537-
factories: Iterable[DeviceFactory[..., V2] | V1DeviceFactory[V1]],
537+
factories: Iterable[DeviceFactory[..., V2] | V1DeviceFactory[..., V1]],
538538
available_fixtures: Mapping[str, Any],
539539
) -> set[str]:
540540
"""
@@ -566,7 +566,7 @@ def _expand_dependencies(
566566

567567
def _build_order(
568568
self,
569-
factories: dict[str, DeviceFactory[..., V2] | V1DeviceFactory[V1]],
569+
factories: dict[str, DeviceFactory[..., V2] | V1DeviceFactory[..., V1]],
570570
fixtures: Mapping[str, Any],
571571
) -> list[str]:
572572
"""

0 commit comments

Comments
 (0)