@@ -30,6 +30,29 @@ class _Selected:
3030 inner : Optional [Any ]
3131
3232
33+ @dataclass
34+ class _ReadyReceiver :
35+ """A class for tracking receivers that have a message ready to be read.
36+
37+ Used to make sure that receivers are not consumed from until messages are accessed
38+ by user code, at which point, it will be converted into a `_Selected` object.
39+
40+ When a channel has closed, `recv` should be `None`.
41+ """
42+
43+ recv : Optional [Receiver [Any ]]
44+
45+ def get (self ) -> _Selected :
46+ """Consume a message from the receiver and return a `_Selected` object.
47+
48+ Returns:
49+ An instance of `_Selected` holding a value from the receiver.
50+ """
51+ if self .recv is None :
52+ return _Selected (None )
53+ return _Selected (self .recv ._get ()) # pylint: disable=protected-access
54+
55+
3356class Select :
3457 """Select the next available message from a group of Receivers.
3558
@@ -67,16 +90,16 @@ def __init__(self, **kwargs: Receiver[Any]) -> None:
6790 **kwargs: sequence of receivers
6891 """
6992 self ._receivers = kwargs
70- self ._pending : Set [asyncio .Task [Any ]] = set ()
93+ self ._pending : Set [asyncio .Task [None ]] = set ()
7194
7295 for name , recv in self ._receivers .items ():
7396 # can replace __anext__() to anext() (Only Python 3.10>=)
74- msg = recv .__anext__ () # pylint: disable=unnecessary-dunder-call
75- self ._pending .add (asyncio .create_task (msg , name = name )) # type: ignore
97+ ready = recv ._ready () # pylint: disable=unnecessary-dunder-call
98+ self ._pending .add (asyncio .create_task (ready , name = name ))
7699
77100 self ._ready_count = 0
78101 self ._prev_ready_count = 0
79- self ._result : Dict [str , Optional [_Selected ]] = {
102+ self ._result : Dict [str , Optional [_ReadyReceiver ]] = {
80103 name : None for name in self ._receivers
81104 }
82105
@@ -100,6 +123,8 @@ async def ready(self) -> bool:
100123 for name , value in self ._result .items ():
101124 if value is not None :
102125 dropped_names .append (name )
126+ if value .recv is not None :
127+ value .recv ._get () # pylint: disable=protected-access
103128 self ._result [name ] = None
104129 self ._ready_count = 0
105130 self ._prev_ready_count = 0
@@ -123,20 +148,19 @@ async def ready(self) -> bool:
123148 )
124149 for item in done :
125150 name = item .get_name ()
151+ recv = self ._receivers [name ]
126152 if isinstance (item .exception (), StopAsyncIteration ):
127153 result = None
128154 else :
129- result = item . result ()
155+ result = recv
130156 self ._ready_count += 1
131- self ._result [name ] = _Selected (result )
157+ self ._result [name ] = _ReadyReceiver (result )
132158 # if channel or Receiver is closed
133159 # don't add a task for it again.
134160 if result is None :
135161 continue
136- msg = self ._receivers [ # pylint: disable=unnecessary-dunder-call
137- name
138- ].__anext__ ()
139- self ._pending .add (asyncio .create_task (msg , name = name )) # type: ignore
162+ ready = recv ._ready () # pylint: disable=protected-access
163+ self ._pending .add (asyncio .create_task (ready , name = name ))
140164 return True
141165
142166 def __getattr__ (self , name : str ) -> Optional [Any ]:
@@ -157,4 +181,4 @@ def __getattr__(self, name: str) -> Optional[Any]:
157181 return result
158182 self ._result [name ] = None
159183 self ._ready_count -= 1
160- return result
184+ return result . get ()
0 commit comments