Skip to content

Commit cd26aae

Browse files
committed
Fix some pyright strict errors
Refactor Method classes to be generic over ControllerType
1 parent bdab4fa commit cd26aae

File tree

19 files changed

+303
-189
lines changed

19 files changed

+303
-189
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies = [
1414
"aioserial",
1515
"numpy",
1616
"pydantic",
17-
"pvi~=0.10.0",
17+
"pvi~=0.10.1",
1818
"pytango",
1919
"softioc",
2020
]

src/fastcs/attributes.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class AttrMode(Enum):
1818
class Sender(Protocol):
1919
"""Protocol for setting the value of an ``Attribute``."""
2020

21-
async def put(self, controller: Any, attr: AttrW, value: Any) -> None:
21+
async def put(self, controller: Any, attr: AttrW[T], value: Any) -> None:
2222
pass
2323

2424

@@ -28,7 +28,7 @@ class Updater(Protocol):
2828

2929
update_period: float
3030

31-
async def update(self, controller: Any, attr: AttrR) -> None:
31+
async def update(self, controller: Any, attr: AttrR[T]) -> None:
3232
pass
3333

3434

@@ -89,7 +89,7 @@ class AttrR(Attribute[T]):
8989
def __init__(
9090
self,
9191
datatype: DataType[T],
92-
access_mode=AttrMode.READ,
92+
access_mode: AttrMode = AttrMode.READ,
9393
group: str | None = None,
9494
handler: Updater | None = None,
9595
allowed_values: list[T] | None = None,
@@ -128,7 +128,7 @@ class AttrW(Attribute[T]):
128128
def __init__(
129129
self,
130130
datatype: DataType[T],
131-
access_mode=AttrMode.WRITE,
131+
access_mode: AttrMode = AttrMode.WRITE,
132132
group: str | None = None,
133133
handler: Sender | None = None,
134134
allowed_values: list[T] | None = None,
@@ -176,7 +176,7 @@ class AttrRW(AttrW[T], AttrR[T]):
176176
def __init__(
177177
self,
178178
datatype: DataType[T],
179-
access_mode=AttrMode.READ_WRITE,
179+
access_mode: AttrMode = AttrMode.READ_WRITE,
180180
group: str | None = None,
181181
handler: Handler | None = None,
182182
allowed_values: list[T] | None = None,

src/fastcs/backend.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
import asyncio
22
from collections import defaultdict
3-
from collections.abc import Callable
3+
from collections.abc import Callable, Coroutine
44
from concurrent.futures import Future
5-
from types import MethodType
65

76
from softioc.asyncio_dispatcher import AsyncioDispatcher
87

8+
from fastcs.datatypes import T
9+
910
from .attributes import AttrR, AttrW, Sender, Updater
10-
from .controller import Controller
11+
from .controller import BaseController, Controller
1112
from .exceptions import FastCSException
1213
from .mapping import Mapping, SingleMapping
1314

15+
Callback = Callable[[], Coroutine[None, None, None]]
16+
1417

1518
class Backend:
1619
def __init__(
1720
self, controller: Controller, loop: asyncio.AbstractEventLoop | None = None
1821
):
1922
self._dispatcher = AsyncioDispatcher(loop)
20-
self._loop = self._dispatcher.loop
23+
self._loop: asyncio.AbstractEventLoop = self._dispatcher.loop # type: ignore
2124
self._controller = controller
2225

2326
self._initial_tasks = [controller.connect]
@@ -58,20 +61,20 @@ def _start_scan_tasks(self):
5861
for task in scan_tasks:
5962
asyncio.run_coroutine_threadsafe(task(), self._loop)
6063

61-
def _run(self):
64+
def _run(self) -> None:
6265
raise NotImplementedError("Specific Backend must implement _run")
6366

6467

65-
def _link_single_controller_put_tasks(single_mapping: SingleMapping) -> None:
66-
for name, method in single_mapping.put_methods.items():
68+
def _link_single_controller_put_tasks(
69+
single_mapping: SingleMapping,
70+
) -> None:
71+
for name, put in single_mapping.put_methods.items():
6772
name = name.removeprefix("put_")
6873

6974
attribute = single_mapping.attributes[name]
7075
match attribute:
7176
case AttrW():
72-
attribute.set_process_callback(
73-
MethodType(method.fn, single_mapping.controller)
74-
)
77+
attribute.set_process_callback(put)
7578
case _:
7679
raise FastCSException(
7780
f"Mode {attribute.access_mode} does not "
@@ -89,17 +92,28 @@ def _link_attribute_sender_class(single_mapping: SingleMapping) -> None:
8992

9093
callback = _create_sender_callback(attribute, single_mapping.controller)
9194
attribute.set_process_callback(callback)
95+
case _:
96+
pass
9297

9398

94-
def _create_sender_callback(attribute, controller):
95-
async def callback(value):
96-
await attribute.sender.put(controller, attribute, value)
99+
def _create_sender_callback(
100+
attribute: AttrW[T], controller: BaseController
101+
) -> Callable[[T], Coroutine[None, None, None]]:
102+
match attribute.sender:
103+
case Sender() as sender:
97104

98-
return callback
105+
async def put_callback(value: T):
106+
await sender.put(controller, attribute, value)
107+
case _:
108+
109+
async def put_callback(value: T):
110+
pass
99111

112+
return put_callback
100113

101-
def _get_scan_tasks(mapping: Mapping) -> list[Callable]:
102-
scan_dict: dict[float, list[Callable]] = defaultdict(list)
114+
115+
def _get_scan_tasks(mapping: Mapping) -> list[Callback]:
116+
scan_dict: dict[float, list[Callback]] = defaultdict(list)
103117

104118
for single_mapping in mapping.get_controller_mappings():
105119
_add_scan_method_tasks(scan_dict, single_mapping)
@@ -110,16 +124,15 @@ def _get_scan_tasks(mapping: Mapping) -> list[Callable]:
110124

111125

112126
def _add_scan_method_tasks(
113-
scan_dict: dict[float, list[Callable]], single_mapping: SingleMapping
127+
scan_dict: dict[float, list[Callback]], single_mapping: SingleMapping
114128
):
115-
for method in single_mapping.scan_methods.values():
116-
scan_dict[method.period].append(
117-
MethodType(method.fn, single_mapping.controller)
118-
)
129+
for scan in single_mapping.scan_methods.values():
130+
scan_dict[scan.period].append(scan)
119131

120132

121133
def _add_attribute_updater_tasks(
122-
scan_dict: dict[float, list[Callable]], single_mapping: SingleMapping
134+
scan_dict: dict[float, list[Callback]],
135+
single_mapping: SingleMapping,
123136
):
124137
for attribute in single_mapping.attributes.values():
125138
match attribute:
@@ -128,12 +141,20 @@ def _add_attribute_updater_tasks(
128141
attribute, single_mapping.controller
129142
)
130143
scan_dict[update_period].append(callback)
144+
case _:
145+
pass
131146

132147

133-
def _create_updater_callback(attribute, controller):
148+
def _create_updater_callback(
149+
attribute: AttrR[T], controller: BaseController
150+
) -> Callback:
134151
async def callback():
135152
try:
136-
await attribute.updater.update(controller, attribute)
153+
match attribute.updater:
154+
case Updater() as updater:
155+
await updater.update(controller, attribute)
156+
case _:
157+
pass
137158
except Exception as e:
138159
print(
139160
f"Update loop in {attribute.updater} stopped:\n"
@@ -144,15 +165,15 @@ async def callback():
144165
return callback
145166

146167

147-
def _get_periodic_scan_tasks(scan_dict: dict[float, list[Callable]]) -> list[Callable]:
148-
periodic_scan_tasks: list[Callable] = []
168+
def _get_periodic_scan_tasks(scan_dict: dict[float, list[Callback]]) -> list[Callback]:
169+
periodic_scan_tasks: list[Callback] = []
149170
for period, methods in scan_dict.items():
150171
periodic_scan_tasks.append(_create_periodic_scan_task(period, methods))
151172

152173
return periodic_scan_tasks
153174

154175

155-
def _create_periodic_scan_task(period, methods: list[Callable]) -> Callable:
176+
def _create_periodic_scan_task(period: float, methods: list[Callback]) -> Callback:
156177
async def scan_task() -> None:
157178
while True:
158179
await asyncio.gather(*[method() for method in methods])

src/fastcs/backends/epics/gui.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass
22
from enum import Enum
33
from pathlib import Path
4+
from typing import Any
45

56
from pvi._format.dls import DLSFormatter
67
from pvi.device import (
@@ -28,9 +29,9 @@
2829

2930
from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW
3031
from fastcs.cs_methods import Command
31-
from fastcs.datatypes import Bool, Float, Int, String
32+
from fastcs.datatypes import Bool, Float, Int, String, T
3233
from fastcs.exceptions import FastCSException
33-
from fastcs.mapping import Mapping, SingleMapping, _get_single_mapping
34+
from fastcs.mapping import Mapping, SingleMapping, get_single_mapping
3435
from fastcs.util import snake_to_pascal
3536

3637

@@ -56,7 +57,7 @@ def _get_pv(self, attr_path: list[str], name: str):
5657
return f"{attr_prefix}:{name.title().replace('_', '')}"
5758

5859
@staticmethod
59-
def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion:
60+
def _get_read_widget(attribute: AttrR[T]) -> ReadWidgetUnion:
6061
match attribute.datatype:
6162
case Bool():
6263
return LED()
@@ -68,10 +69,11 @@ def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion:
6869
raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}")
6970

7071
@staticmethod
71-
def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion:
72+
def _get_write_widget(attribute: AttrW[T]) -> WriteWidgetUnion:
7273
match attribute.allowed_values:
7374
case allowed_values if allowed_values is not None:
74-
return ComboBox(choices=allowed_values)
75+
choices = [v for v in allowed_values if isinstance(v, str)]
76+
return ComboBox(choices=choices)
7577
case _:
7678
pass
7779

@@ -86,7 +88,7 @@ def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion:
8688
raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}")
8789

8890
def _get_attribute_component(
89-
self, attr_path: list[str], name: str, attribute: Attribute
91+
self, attr_path: list[str], name: str, attribute: Attribute[Any]
9092
) -> SignalR | SignalW | SignalRW:
9193
pv = self._get_pv(attr_path, name)
9294
name = name.title().replace("_", "")
@@ -148,7 +150,7 @@ def extract_mapping_components(self, mapping: SingleMapping) -> Tree:
148150
name=snake_to_pascal(name),
149151
layout=SubScreen(),
150152
children=self.extract_mapping_components(
151-
_get_single_mapping(sub_controller)
153+
get_single_mapping(sub_controller)
152154
),
153155
)
154156
)

src/fastcs/backends/epics/ioc.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from collections.abc import Callable
1+
from collections.abc import Callable, Coroutine
22
from dataclasses import dataclass
3-
from types import MethodType
43
from typing import Any, Literal
54

65
from softioc import builder, softioc
@@ -156,6 +155,7 @@ def _create_and_link_attribute_pvs(pv_prefix: str, mapping: Mapping) -> None:
156155
def _create_and_link_read_pv(
157156
pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T]
158157
) -> None:
158+
record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute)
159159
if attr_is_enum(attribute):
160160

161161
async def async_record_set(value: T):
@@ -165,13 +165,12 @@ async def async_record_set(value: T):
165165
async def async_record_set(value: T):
166166
record.set(value)
167167

168-
record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute)
169168
_add_attr_pvi_info(record, pv_prefix, attr_name, "r")
170169

171170
attribute.set_update_callback(async_record_set)
172171

173172

174-
def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper:
173+
def _get_input_record(pv: str, attribute: AttrR[T]) -> RecordWrapper:
175174
if attr_is_enum(attribute):
176175
assert attribute.allowed_values is not None and all(
177176
isinstance(v, str) for v in attribute.allowed_values
@@ -199,31 +198,40 @@ def _create_and_link_write_pv(
199198
) -> None:
200199
if attr_is_enum(attribute):
201200

202-
async def on_update(value):
201+
async def on_update_enum(value: int):
203202
await attribute.process_without_display_update(
204203
enum_index_to_value(attribute, value)
205204
)
206205

206+
record = _get_output_record(
207+
f"{pv_prefix}:{pv_name}", attribute, on_update=on_update_enum
208+
)
209+
207210
async def async_write_display(value: T):
208211
record.set(enum_value_to_index(attribute, value), process=False)
209212

210213
else:
211214

212-
async def on_update(value):
215+
async def on_update(value: T):
213216
await attribute.process_without_display_update(value)
214217

218+
record = _get_output_record(
219+
f"{pv_prefix}:{pv_name}", attribute, on_update=on_update
220+
)
221+
215222
async def async_write_display(value: T):
216223
record.set(value, process=False)
217224

218-
record = _get_output_record(
219-
f"{pv_prefix}:{pv_name}", attribute, on_update=on_update
220-
)
221225
_add_attr_pvi_info(record, pv_prefix, attr_name, "w")
222226

223227
attribute.set_write_display_callback(async_write_display)
224228

225229

226-
def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any:
230+
def _get_output_record(
231+
pv: str,
232+
attribute: AttrW[T],
233+
on_update: Callable[..., Coroutine[None, None, None]],
234+
) -> Any:
227235
if attr_is_enum(attribute):
228236
assert attribute.allowed_values is not None and all(
229237
isinstance(v, str) for v in attribute.allowed_values
@@ -255,26 +263,24 @@ def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any:
255263
def _create_and_link_command_pvs(pv_prefix: str, mapping: Mapping) -> None:
256264
for single_mapping in mapping.get_controller_mappings():
257265
path = single_mapping.controller.path
258-
for attr_name, method in single_mapping.command_methods.items():
266+
for attr_name, command in single_mapping.command_methods.items():
259267
pv_name = attr_name.title().replace("_", "")
260268
_pv_prefix = ":".join([pv_prefix] + path)
261269
if len(f"{_pv_prefix}:{pv_name}") > EPICS_MAX_NAME_LENGTH:
262270
print(
263271
f"Not creating PV for {attr_name} as full name would exceed"
264272
f" {EPICS_MAX_NAME_LENGTH} characters"
265273
)
266-
method.enabled = False
274+
command.enabled = False
267275
else:
268-
_create_and_link_command_pv(
269-
_pv_prefix,
270-
pv_name,
271-
attr_name,
272-
MethodType(method.fn, single_mapping.controller),
273-
)
276+
_create_and_link_command_pv(_pv_prefix, pv_name, attr_name, command)
274277

275278

276279
def _create_and_link_command_pv(
277-
pv_prefix: str, pv_name: str, attr_name: str, method: Callable
280+
pv_prefix: str,
281+
pv_name: str,
282+
attr_name: str,
283+
method: Callable[[], Coroutine[None, None, None]],
278284
) -> None:
279285
async def wrapped_method(_: Any):
280286
await method()

src/fastcs/backends/epics/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
MBB_MAX_CHOICES = len(_MBB_FIELD_PREFIXES)
2626

2727

28-
def attr_is_enum(attribute: Attribute) -> bool:
28+
def attr_is_enum(attribute: Attribute[T]) -> bool:
2929
"""Check if the `Attribute` has a `String` datatype and has `allowed_values` set.
3030
3131
Args:

0 commit comments

Comments
 (0)