@@ -201,25 +201,24 @@ async def __aenter__(self) -> Self:
201
201
"""
202
202
async with self ._enter_lock :
203
203
if self ._running_count == 0 :
204
- self ._exit_stack = AsyncExitStack ()
205
-
206
- self ._read_stream , self ._write_stream = await self ._exit_stack .enter_async_context (
207
- self .client_streams ()
208
- )
209
- client = ClientSession (
210
- read_stream = self ._read_stream ,
211
- write_stream = self ._write_stream ,
212
- sampling_callback = self ._sampling_callback if self .allow_sampling else None ,
213
- logging_callback = self .log_handler ,
214
- read_timeout_seconds = timedelta (seconds = self .read_timeout ),
215
- )
216
- self ._client = await self ._exit_stack .enter_async_context (client )
217
-
218
- with anyio .fail_after (self .timeout ):
219
- await self ._client .initialize ()
220
-
221
- if log_level := self .log_level :
222
- await self ._client .set_logging_level (log_level )
204
+ async with AsyncExitStack () as exit_stack :
205
+ self ._read_stream , self ._write_stream = await exit_stack .enter_async_context (self .client_streams ())
206
+ client = ClientSession (
207
+ read_stream = self ._read_stream ,
208
+ write_stream = self ._write_stream ,
209
+ sampling_callback = self ._sampling_callback if self .allow_sampling else None ,
210
+ logging_callback = self .log_handler ,
211
+ read_timeout_seconds = timedelta (seconds = self .read_timeout ),
212
+ )
213
+ self ._client = await exit_stack .enter_async_context (client )
214
+
215
+ with anyio .fail_after (self .timeout ):
216
+ await self ._client .initialize ()
217
+
218
+ if log_level := self .log_level :
219
+ await self ._client .set_logging_level (log_level )
220
+
221
+ self ._exit_stack = exit_stack .pop_all ()
223
222
self ._running_count += 1
224
223
return self
225
224
0 commit comments