11from __future__ import annotations
22
33from datetime import datetime , timedelta
4- from typing import TYPE_CHECKING , Protocol
4+ from typing import TYPE_CHECKING , Any , Generic , Literal , Protocol , overload
5+
6+ from typing_extensions import TypeVar
7+
8+ # due to circular imports we cannot import the TypeVar from parameter_base
9+ ParameterDataTypeVar = TypeVar ("ParameterDataTypeVar" , default = Any )
510
611if TYPE_CHECKING :
7- from .parameter_base import ParamDataType , ParameterBase , ParamRawDataType
12+ from .parameter_base import (
13+ ParameterBase ,
14+ ParamRawDataType ,
15+ )
816
917
1018# The protocol is private to qcodes but used elsewhere in the codebase
11- class _CacheProtocol (Protocol ): # noqa: PYI046
19+ class _CacheProtocol (Protocol , Generic [ ParameterDataTypeVar ] ): # noqa: PYI046
1220 """
1321 This protocol defines the interface that a Parameter Cache implementation
1422 must implement. This is currently used for 2 implementations, one in
@@ -29,24 +37,36 @@ def valid(self) -> bool: ...
2937
3038 def invalidate (self ) -> None : ...
3139
32- def set (self , value : ParamDataType ) -> None : ...
40+ def set (self , value : ParameterDataTypeVar ) -> None : ...
3341
3442 def _set_from_raw_value (self , raw_value : ParamRawDataType ) -> None : ...
3543
36- def get (self , get_if_invalid : bool = True ) -> ParamDataType : ...
44+ @overload
45+ def get (self , get_if_invalid : Literal [True ]) -> ParameterDataTypeVar : ...
46+
47+ @overload
48+ def get (self ) -> ParameterDataTypeVar : ...
49+
50+ @overload
51+ def get (self , get_if_invalid : Literal [False ]) -> ParameterDataTypeVar | None : ...
52+
53+ @overload
54+ def get (self , get_if_invalid : bool ) -> ParameterDataTypeVar | None : ...
55+
56+ def get (self , get_if_invalid : bool = True ) -> ParameterDataTypeVar | None : ...
3757
3858 def _update_with (
3959 self ,
4060 * ,
41- value : ParamDataType ,
61+ value : ParameterDataTypeVar ,
4262 raw_value : ParamRawDataType ,
4363 timestamp : datetime | None = None ,
4464 ) -> None : ...
4565
46- def __call__ (self ) -> ParamDataType : ...
66+ def __call__ (self ) -> ParameterDataTypeVar : ...
4767
4868
49- class _Cache :
69+ class _Cache ( Generic [ ParameterDataTypeVar ]) :
5070 """
5171 Cache object for parameter to hold its value and raw value
5272
@@ -66,9 +86,11 @@ class _Cache:
6686
6787 """
6888
69- def __init__ (self , parameter : ParameterBase , max_val_age : float | None = None ):
89+ def __init__ (
90+ self , parameter : ParameterBase , max_val_age : float | None = None
91+ ) -> None :
7092 self ._parameter = parameter
71- self ._value : ParamDataType = None
93+ self ._value : ParameterDataTypeVar | None = None
7294 self ._raw_value : ParamRawDataType = None
7395 self ._timestamp : datetime | None = None
7496 self ._max_val_age = max_val_age
@@ -115,7 +137,7 @@ def invalidate(self) -> None:
115137 """
116138 self ._marked_valid = False
117139
118- def set (self , value : ParamDataType ) -> None :
140+ def set (self , value : ParameterDataTypeVar ) -> None :
119141 """
120142 Set the cached value of the parameter without invoking the
121143 ``set_cmd`` of the parameter (if it has one). For example, in case of
@@ -146,7 +168,7 @@ def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None:
146168 def _update_with (
147169 self ,
148170 * ,
149- value : ParamDataType ,
171+ value : ParameterDataTypeVar ,
150172 raw_value : ParamRawDataType ,
151173 timestamp : datetime | None = None ,
152174 ) -> None :
@@ -187,7 +209,19 @@ def _timestamp_expired(self) -> bool:
187209 # parameter is still valid
188210 return False
189211
190- def get (self , get_if_invalid : bool = True ) -> ParamDataType :
212+ @overload
213+ def get (self , get_if_invalid : Literal [True ]) -> ParameterDataTypeVar : ...
214+
215+ @overload
216+ def get (self ) -> ParameterDataTypeVar : ...
217+
218+ @overload
219+ def get (self , get_if_invalid : Literal [False ]) -> ParameterDataTypeVar | None : ...
220+
221+ @overload
222+ def get (self , get_if_invalid : bool ) -> ParameterDataTypeVar | None : ...
223+
224+ def get (self , get_if_invalid : bool = True ) -> ParameterDataTypeVar | None :
191225 """
192226 Return cached value if time since get was less than ``max_val_age``,
193227 or the parameter was explicitly marked invalid.
@@ -246,7 +280,7 @@ def _construct_error_msg(self) -> str:
246280 )
247281 return error_msg
248282
249- def __call__ (self ) -> ParamDataType :
283+ def __call__ (self ) -> ParameterDataTypeVar :
250284 """
251285 Same as :meth:`get` but always call ``get`` on parameter if the
252286 cache is not valid
0 commit comments