11from __future__ import annotations
22
3- from typing import TYPE_CHECKING , Any
3+ from typing import TYPE_CHECKING , Any , Generic
4+
5+ from typing_extensions import TypeVar
46
57from .parameter import Parameter
8+ from .parameter_base import InstrumentType_co , ParameterDataTypeVar
69
710if TYPE_CHECKING :
811 from collections .abc import Sequence
912 from datetime import datetime
1013
14+ from qcodes .instrument import InstrumentBase
1115 from qcodes .validators .validators import Validator
1216
13- from .parameter_base import ParamDataType , ParamRawDataType
14-
15-
16- class DelegateParameter (Parameter ):
17+ from .parameter_base import (
18+ ParamDataType ,
19+ ParamRawDataType ,
20+ )
21+
22+ # Generic type variables for inner cache class
23+ # these need to be different variables such that both classes can be generic
24+ _local_ParameterDataTypeVar = TypeVar ("_local_ParameterDataTypeVar" , default = Any )
25+ _local_InstrumentType_co = TypeVar (
26+ "_local_InstrumentType_co" ,
27+ bound = "InstrumentBase | None" ,
28+ default = "InstrumentBase | None" ,
29+ covariant = True ,
30+ )
31+
32+
33+ class DelegateParameter (
34+ Parameter [ParameterDataTypeVar , InstrumentType_co ],
35+ Generic [ParameterDataTypeVar , InstrumentType_co ],
36+ ):
1737 """
1838 The :class:`.DelegateParameter` wraps a given `source` :class:`Parameter`.
1939 Setting/getting it results in a set/get of the source parameter with
@@ -51,8 +71,15 @@ class DelegateParameter(Parameter):
5171
5272 """
5373
54- class _DelegateCache :
55- def __init__ (self , parameter : DelegateParameter ):
74+ class _DelegateCache (
75+ Generic [_local_ParameterDataTypeVar , _local_InstrumentType_co ]
76+ ):
77+ def __init__ (
78+ self ,
79+ parameter : DelegateParameter [
80+ _local_ParameterDataTypeVar , _local_InstrumentType_co
81+ ],
82+ ):
5683 self ._parameter = parameter
5784 self ._marked_valid : bool = False
5885
@@ -99,7 +126,7 @@ def invalidate(self) -> None:
99126 if self ._parameter .source is not None :
100127 self ._parameter .source .cache .invalidate ()
101128
102- def get (self , get_if_invalid : bool = True ) -> ParamDataType :
129+ def get (self , get_if_invalid : bool = True ) -> _local_ParameterDataTypeVar :
103130 if self ._parameter .source is None :
104131 raise TypeError (
105132 "Cannot get the cache of a DelegateParameter that delegates to None"
@@ -108,7 +135,7 @@ def get(self, get_if_invalid: bool = True) -> ParamDataType:
108135 self ._parameter .source .cache .get (get_if_invalid = get_if_invalid )
109136 )
110137
111- def set (self , value : ParamDataType ) -> None :
138+ def set (self , value : _local_ParameterDataTypeVar ) -> None :
112139 if self ._parameter .source is None :
113140 raise TypeError (
114141 "Cannot set the cache of a DelegateParameter that delegates to None"
@@ -128,7 +155,7 @@ def _set_from_raw_value(self, raw_value: ParamRawDataType) -> None:
128155 def _update_with (
129156 self ,
130157 * ,
131- value : ParamDataType ,
158+ value : _local_ParameterDataTypeVar ,
132159 raw_value : ParamRawDataType ,
133160 timestamp : datetime | None = None ,
134161 ) -> None :
@@ -142,7 +169,7 @@ def _update_with(
142169 """
143170 pass
144171
145- def __call__ (self ) -> ParamDataType :
172+ def __call__ (self ) -> _local_ParameterDataTypeVar :
146173 return self .get (get_if_invalid = True )
147174
148175 def __init__ (
@@ -183,7 +210,7 @@ def __init__(
183210 # i.e. _SetParamContext overrides it
184211 self ._settable = True
185212
186- self .cache = self ._DelegateCache (self )
213+ self .cache = self ._DelegateCache [ ParameterDataTypeVar , InstrumentType_co ] (self )
187214 if initial_cache_value is not None :
188215 self .cache .set (initial_cache_value )
189216
0 commit comments