Skip to content

Commit b41cb4a

Browse files
authored
Merge pull request #7730 from jenshnielsen/generic_parameter
Generic parameter
2 parents 9dfc0a7 + 7025c17 commit b41cb4a

File tree

8 files changed

+215
-90
lines changed

8 files changed

+215
-90
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
The QCoDeS Parameter classes ``ParameterBase``, ``Parameter``, ``ParameterWithSetpoints``, ``DelegateParameter``, ``ArrayParameter`` and ``MultiParameter`` now
2+
takes two Optional Generic arguments to allow the data type and the type of the instrument the parameter is bound to to be fixed statically. This enables
3+
the type of the output of ``parameter.get()``, input of ``parameter.set()`` and value of ``parameter.instrument`` to be known statically such that type
4+
checkers and IDE's can make use of this information.

src/qcodes/parameters/array_parameter.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@
1212
has_loop = True
1313
except ImportError:
1414
has_loop = False
15+
from typing import Generic
1516

16-
from .parameter_base import ParameterBase
17+
from .parameter_base import InstrumentTypeVar_co, ParameterBase, ParameterDataTypeVar
1718
from .sequence_helpers import is_sequence_of
1819

1920
if TYPE_CHECKING:
2021
from collections.abc import Mapping, Sequence
2122

22-
from qcodes.instrument import InstrumentBase
23-
2423

2524
try:
2625
from qcodes_loop.data.data_array import DataArray
@@ -41,7 +40,10 @@
4140
)
4241

4342

44-
class ArrayParameter(ParameterBase):
43+
class ArrayParameter(
44+
ParameterBase[ParameterDataTypeVar, InstrumentTypeVar_co],
45+
Generic[ParameterDataTypeVar, InstrumentTypeVar_co],
46+
):
4547
"""
4648
A gettable parameter that returns an array of values.
4749
Not necessarily part of an instrument.
@@ -131,7 +133,9 @@ def __init__(
131133
self,
132134
name: str,
133135
shape: Sequence[int],
134-
instrument: InstrumentBase | None = None,
136+
# mypy seems to be confused here. The bound and default for InstrumentTypeVar_co
137+
# contains None but mypy will not allow it as a default as of v 1.19.0
138+
instrument: InstrumentTypeVar_co = None, # type: ignore[assignment]
135139
label: str | None = None,
136140
unit: str | None = None,
137141
setpoints: Sequence[Any] | None = None,

src/qcodes/parameters/cache.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
from __future__ import annotations
22

33
from datetime import datetime, timedelta
4-
from typing import TYPE_CHECKING, Protocol
4+
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, overload
5+
6+
from typing_extensions import TypeVar
7+
8+
# due to circular imports we cannot import the TypeVar from parameter_base
9+
ParameterDataTypeVar = TypeVar("ParameterDataTypeVar", default=Any)
510

611
if TYPE_CHECKING:
7-
from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType
12+
from .parameter_base import (
13+
ParameterBase,
14+
ParamRawDataType,
15+
)
816

917

1018
# The protocol is private to qcodes but used elsewhere in the codebase
11-
class _CacheProtocol(Protocol): # noqa: PYI046
19+
class _CacheProtocol(Protocol, Generic[ParameterDataTypeVar]): # noqa: PYI046
1220
"""
1321
This protocol defines the interface that a Parameter Cache implementation
1422
must implement. This is currently used for 2 implementations, one in
@@ -29,24 +37,36 @@ def valid(self) -> bool: ...
2937

3038
def invalidate(self) -> None: ...
3139

32-
def set(self, value: ParamDataType) -> None: ...
40+
def set(self, value: ParameterDataTypeVar) -> None: ...
3341

3442
def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None: ...
3543

36-
def get(self, get_if_invalid: bool = True) -> ParamDataType: ...
44+
@overload
45+
def get(self, get_if_invalid: Literal[True]) -> ParameterDataTypeVar: ...
46+
47+
@overload
48+
def get(self) -> ParameterDataTypeVar: ...
49+
50+
@overload
51+
def get(self, get_if_invalid: Literal[False]) -> ParameterDataTypeVar | None: ...
52+
53+
@overload
54+
def get(self, get_if_invalid: bool) -> ParameterDataTypeVar | None: ...
55+
56+
def get(self, get_if_invalid: bool = True) -> ParameterDataTypeVar | None: ...
3757

3858
def _update_with(
3959
self,
4060
*,
41-
value: ParamDataType,
61+
value: ParameterDataTypeVar,
4262
raw_value: ParamRawDataType,
4363
timestamp: datetime | None = None,
4464
) -> None: ...
4565

46-
def __call__(self) -> ParamDataType: ...
66+
def __call__(self) -> ParameterDataTypeVar: ...
4767

4868

49-
class _Cache:
69+
class _Cache(Generic[ParameterDataTypeVar]):
5070
"""
5171
Cache object for parameter to hold its value and raw value
5272
@@ -66,9 +86,11 @@ class _Cache:
6686
6787
"""
6888

69-
def __init__(self, parameter: ParameterBase, max_val_age: float | None = None):
89+
def __init__(
90+
self, parameter: ParameterBase, max_val_age: float | None = None
91+
) -> None:
7092
self._parameter = parameter
71-
self._value: ParamDataType = None
93+
self._value: ParameterDataTypeVar | None = None
7294
self._raw_value: ParamRawDataType = None
7395
self._timestamp: datetime | None = None
7496
self._max_val_age = max_val_age
@@ -115,7 +137,7 @@ def invalidate(self) -> None:
115137
"""
116138
self._marked_valid = False
117139

118-
def set(self, value: ParamDataType) -> None:
140+
def set(self, value: ParameterDataTypeVar) -> None:
119141
"""
120142
Set the cached value of the parameter without invoking the
121143
``set_cmd`` of the parameter (if it has one). For example, in case of
@@ -146,7 +168,7 @@ def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None:
146168
def _update_with(
147169
self,
148170
*,
149-
value: ParamDataType,
171+
value: ParameterDataTypeVar,
150172
raw_value: ParamRawDataType,
151173
timestamp: datetime | None = None,
152174
) -> None:
@@ -187,7 +209,19 @@ def _timestamp_expired(self) -> bool:
187209
# parameter is still valid
188210
return False
189211

190-
def get(self, get_if_invalid: bool = True) -> ParamDataType:
212+
@overload
213+
def get(self, get_if_invalid: Literal[True]) -> ParameterDataTypeVar: ...
214+
215+
@overload
216+
def get(self) -> ParameterDataTypeVar: ...
217+
218+
@overload
219+
def get(self, get_if_invalid: Literal[False]) -> ParameterDataTypeVar | None: ...
220+
221+
@overload
222+
def get(self, get_if_invalid: bool) -> ParameterDataTypeVar | None: ...
223+
224+
def get(self, get_if_invalid: bool = True) -> ParameterDataTypeVar | None:
191225
"""
192226
Return cached value if time since get was less than ``max_val_age``,
193227
or the parameter was explicitly marked invalid.
@@ -246,7 +280,7 @@ def _construct_error_msg(self) -> str:
246280
)
247281
return error_msg
248282

249-
def __call__(self) -> ParamDataType:
283+
def __call__(self) -> ParameterDataTypeVar:
250284
"""
251285
Same as :meth:`get` but always call ``get`` on parameter if the
252286
cache is not valid

src/qcodes/parameters/delegate_parameter.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,39 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any
3+
from typing import TYPE_CHECKING, Any, Generic
4+
5+
from typing_extensions import TypeVar
46

57
from .parameter import Parameter
8+
from .parameter_base import InstrumentTypeVar_co, ParameterDataTypeVar
69

710
if TYPE_CHECKING:
811
from collections.abc import Sequence
912
from datetime import datetime
1013

14+
from qcodes.instrument import InstrumentBase
1115
from qcodes.validators.validators import Validator
1216

13-
from .parameter_base import ParamDataType, ParamRawDataType
14-
15-
16-
class DelegateParameter(Parameter):
17+
from .parameter_base import (
18+
ParamDataType,
19+
ParamRawDataType,
20+
)
21+
22+
# Generic type variables for inner cache class
23+
# these need to be different variables such that both classes can be generic
24+
_local_ParameterDataTypeVar = TypeVar("_local_ParameterDataTypeVar", default=Any)
25+
_local_InstrumentTypeVar_co = TypeVar(
26+
"_local_InstrumentTypeVar_co",
27+
bound="InstrumentBase | None",
28+
default="InstrumentBase | None",
29+
covariant=True,
30+
)
31+
32+
33+
class DelegateParameter(
34+
Parameter[ParameterDataTypeVar, InstrumentTypeVar_co],
35+
Generic[ParameterDataTypeVar, InstrumentTypeVar_co],
36+
):
1737
"""
1838
The :class:`.DelegateParameter` wraps a given `source` :class:`Parameter`.
1939
Setting/getting it results in a set/get of the source parameter with
@@ -51,8 +71,15 @@ class DelegateParameter(Parameter):
5171
5272
"""
5373

54-
class _DelegateCache:
55-
def __init__(self, parameter: DelegateParameter):
74+
class _DelegateCache(
75+
Generic[_local_ParameterDataTypeVar, _local_InstrumentTypeVar_co]
76+
):
77+
def __init__(
78+
self,
79+
parameter: DelegateParameter[
80+
_local_ParameterDataTypeVar, _local_InstrumentTypeVar_co
81+
],
82+
):
5683
self._parameter = parameter
5784
self._marked_valid: bool = False
5885

@@ -99,7 +126,7 @@ def invalidate(self) -> None:
99126
if self._parameter.source is not None:
100127
self._parameter.source.cache.invalidate()
101128

102-
def get(self, get_if_invalid: bool = True) -> ParamDataType:
129+
def get(self, get_if_invalid: bool = True) -> _local_ParameterDataTypeVar:
103130
if self._parameter.source is None:
104131
raise TypeError(
105132
"Cannot get the cache of a DelegateParameter that delegates to None"
@@ -108,7 +135,7 @@ def get(self, get_if_invalid: bool = True) -> ParamDataType:
108135
self._parameter.source.cache.get(get_if_invalid=get_if_invalid)
109136
)
110137

111-
def set(self, value: ParamDataType) -> None:
138+
def set(self, value: _local_ParameterDataTypeVar) -> None:
112139
if self._parameter.source is None:
113140
raise TypeError(
114141
"Cannot set the cache of a DelegateParameter that delegates to None"
@@ -128,7 +155,7 @@ def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None:
128155
def _update_with(
129156
self,
130157
*,
131-
value: ParamDataType,
158+
value: _local_ParameterDataTypeVar,
132159
raw_value: ParamRawDataType,
133160
timestamp: datetime | None = None,
134161
) -> None:
@@ -142,7 +169,7 @@ def _update_with(
142169
"""
143170
pass
144171

145-
def __call__(self) -> ParamDataType:
172+
def __call__(self) -> _local_ParameterDataTypeVar:
146173
return self.get(get_if_invalid=True)
147174

148175
def __init__(
@@ -183,7 +210,9 @@ def __init__(
183210
# i.e. _SetParamContext overrides it
184211
self._settable = True
185212

186-
self.cache = self._DelegateCache(self)
213+
self.cache = self._DelegateCache[ParameterDataTypeVar, InstrumentTypeVar_co](
214+
self
215+
)
187216
if initial_cache_value is not None:
188217
self.cache.set(initial_cache_value)
189218

src/qcodes/parameters/multi_parameter.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,13 @@
22

33
import os
44
from collections.abc import Iterator, Mapping, Sequence
5-
from typing import TYPE_CHECKING, Any
5+
from typing import Any, Generic
66

77
import numpy as np
88

9-
from .parameter_base import ParameterBase
9+
from .parameter_base import InstrumentTypeVar_co, ParameterBase, ParameterDataTypeVar
1010
from .sequence_helpers import is_sequence_of
1111

12-
if TYPE_CHECKING:
13-
from qcodes.instrument import InstrumentBase
14-
1512
try:
1613
from qcodes_loop.data.data_array import DataArray
1714

@@ -50,7 +47,10 @@ def _is_nested_sequence_or_none(
5047
return True
5148

5249

53-
class MultiParameter(ParameterBase):
50+
class MultiParameter(
51+
ParameterBase[ParameterDataTypeVar, InstrumentTypeVar_co],
52+
Generic[ParameterDataTypeVar, InstrumentTypeVar_co],
53+
):
5454
"""
5555
A gettable parameter that returns multiple values with separate names,
5656
each of arbitrary shape. Not necessarily part of an instrument.
@@ -141,7 +141,9 @@ def __init__(
141141
name: str,
142142
names: Sequence[str],
143143
shapes: Sequence[Sequence[int]],
144-
instrument: InstrumentBase | None = None,
144+
# mypy seems to be confused here. The bound and default for InstrumentTypeVar_co
145+
# contains None but mypy will not allow it as a default as of v 1.19.0
146+
instrument: InstrumentTypeVar_co = None, # type: ignore[assignment]
145147
labels: Sequence[str] | None = None,
146148
units: Sequence[str] | None = None,
147149
setpoints: Sequence[Sequence[Any]] | None = None,

src/qcodes/parameters/parameter.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@
66
import logging
77
import os
88
from types import MethodType
9-
from typing import TYPE_CHECKING, Any, Literal
9+
from typing import TYPE_CHECKING, Any, Generic, Literal
1010

1111
from .command import Command
12-
from .parameter_base import ParamDataType, ParameterBase, ParamRawDataType
12+
from .parameter_base import (
13+
InstrumentTypeVar_co,
14+
ParameterBase,
15+
ParameterDataTypeVar,
16+
ParamRawDataType,
17+
)
1318
from .sweep_values import SweepFixedValues
1419

1520
if TYPE_CHECKING:
@@ -24,7 +29,10 @@
2429
log = logging.getLogger(__name__)
2530

2631

27-
class Parameter(ParameterBase):
32+
class Parameter(
33+
ParameterBase[ParameterDataTypeVar, InstrumentTypeVar_co],
34+
Generic[ParameterDataTypeVar, InstrumentTypeVar_co],
35+
):
2836
"""
2937
A parameter represents a single degree of freedom. Most often,
3038
this is the standard parameter for Instruments, though it can also be
@@ -172,16 +180,18 @@ class Parameter(ParameterBase):
172180
def __init__(
173181
self,
174182
name: str,
175-
instrument: InstrumentBase | None = None,
183+
# mypy seems to be confused here. The bound and default for InstrumentTypeVar_co
184+
# contains None but mypy will not allow None as a default as of v 1.19.0
185+
instrument: InstrumentTypeVar_co = None, # type: ignore[assignment]
176186
label: str | None = None,
177187
unit: str | None = None,
178188
get_cmd: str | Callable[..., Any] | Literal[False] | None = None,
179189
set_cmd: str | Callable[..., Any] | Literal[False] | None = False,
180-
initial_value: float | str | None = None,
190+
initial_value: ParameterDataTypeVar | None = None,
181191
max_val_age: float | None = None,
182192
vals: Validator[Any] | None = None,
183193
docstring: str | None = None,
184-
initial_cache_value: float | str | None = None,
194+
initial_cache_value: ParameterDataTypeVar | None = None,
185195
bind_to_instrument: bool = True,
186196
**kwargs: Any,
187197
) -> None:
@@ -396,14 +406,16 @@ def __getitem__(self, keys: Any) -> SweepFixedValues:
396406
"""
397407
return SweepFixedValues(self, keys)
398408

399-
def increment(self, value: ParamDataType) -> None:
409+
def increment(self, value: ParameterDataTypeVar) -> None:
400410
"""Increment the parameter with a value
401411
402412
Args:
403413
value: Value to be added to the parameter.
404414
405415
"""
406-
self.set(self.get() + value)
416+
# this method only works with parameters that support addition
417+
# however we don't currently enforce that via typing
418+
self.set(self.get() + value) # type: ignore[operator]
407419

408420
def sweep(
409421
self,

0 commit comments

Comments
 (0)