58
58
WaitQueueTimeoutError ,
59
59
)
60
60
from pymongo .hello import Hello , HelloCompat
61
+ from pymongo .helpers_shared import _get_timeout_details , format_timeout_details
61
62
from pymongo .lock import (
62
63
_async_cond_wait ,
63
64
_async_create_condition ,
79
80
SSLErrors ,
80
81
_CancellationContext ,
81
82
_configured_protocol_interface ,
82
- _get_timeout_details ,
83
83
_raise_connection_failure ,
84
- format_timeout_details ,
85
84
)
86
85
from pymongo .read_preferences import ReadPreference
87
86
from pymongo .server_api import _add_to_command
@@ -124,7 +123,89 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
124
123
_IS_SYNC = False
125
124
126
125
127
- class AsyncConnection :
126
+ class AsyncBaseConnection :
127
+ """A base connection object for server and kms connections."""
128
+
129
+ def __init__ (self , conn : AsyncNetworkingInterface , opts : PoolOptions ):
130
+ self .conn = conn
131
+ self .socket_checker : SocketChecker = SocketChecker ()
132
+ self .cancel_context : _CancellationContext = _CancellationContext ()
133
+ self .is_sdam = False
134
+ self .closed = False
135
+ self .last_timeout : float | None = None
136
+ self .more_to_come = False
137
+ self .opts = opts
138
+ self .max_wire_version = - 1
139
+
140
+ def set_conn_timeout (self , timeout : Optional [float ]) -> None :
141
+ """Cache last timeout to avoid duplicate calls to conn.settimeout."""
142
+ if timeout == self .last_timeout :
143
+ return
144
+ self .last_timeout = timeout
145
+ self .conn .get_conn .settimeout (timeout )
146
+
147
+ def apply_timeout (
148
+ self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
149
+ ) -> Optional [float ]:
150
+ # CSOT: use remaining timeout when set.
151
+ timeout = _csot .remaining ()
152
+ if timeout is None :
153
+ # Reset the socket timeout unless we're performing a streaming monitor check.
154
+ if not self .more_to_come :
155
+ self .set_conn_timeout (self .opts .socket_timeout )
156
+ return None
157
+ # RTT validation.
158
+ rtt = _csot .get_rtt ()
159
+ if rtt is None :
160
+ rtt = self .connect_rtt
161
+ max_time_ms = timeout - rtt
162
+ if max_time_ms < 0 :
163
+ timeout_details = _get_timeout_details (self .opts )
164
+ formatted = format_timeout_details (timeout_details )
165
+ # CSOT: raise an error without running the command since we know it will time out.
166
+ errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
167
+ if self .max_wire_version != - 1 :
168
+ raise ExecutionTimeout (
169
+ errmsg ,
170
+ 50 ,
171
+ {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
172
+ self .max_wire_version ,
173
+ )
174
+ else :
175
+ raise TimeoutError (errmsg )
176
+ if cmd is not None :
177
+ cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
178
+ self .set_conn_timeout (timeout )
179
+ return timeout
180
+
181
+ async def close_conn (self , reason : Optional [str ]) -> None :
182
+ """Close this connection with a reason."""
183
+ if self .closed :
184
+ return
185
+ await self ._close_conn ()
186
+
187
+ async def _close_conn (self ) -> None :
188
+ """Close this connection."""
189
+ if self .closed :
190
+ return
191
+ self .closed = True
192
+ self .cancel_context .cancel ()
193
+ # Note: We catch exceptions to avoid spurious errors on interpreter
194
+ # shutdown.
195
+ try :
196
+ await self .conn .close ()
197
+ except Exception : # noqa: S110
198
+ pass
199
+
200
+ def conn_closed (self ) -> bool :
201
+ """Return True if we know socket has been closed, False otherwise."""
202
+ if _IS_SYNC :
203
+ return self .socket_checker .socket_closed (self .conn .get_conn )
204
+ else :
205
+ return self .conn .is_closing ()
206
+
207
+
208
+ class AsyncConnection (AsyncBaseConnection ):
128
209
"""Store a connection with some metadata.
129
210
130
211
:param conn: a raw connection object
@@ -142,29 +223,27 @@ def __init__(
142
223
id : int ,
143
224
is_sdam : bool ,
144
225
):
226
+ super ().__init__ (conn , pool .opts )
145
227
self .pool_ref = weakref .ref (pool )
146
- self .conn = conn
147
- self .address = address
148
- self .id = id
228
+ self .address : tuple [str , int ] = address
229
+ self .id : int = id
149
230
self .is_sdam = is_sdam
150
- self .closed = False
151
231
self .last_checkin_time = time .monotonic ()
152
232
self .performed_handshake = False
153
233
self .is_writable : bool = False
154
234
self .max_wire_version = MAX_WIRE_VERSION
155
- self .max_bson_size = MAX_BSON_SIZE
156
- self .max_message_size = MAX_MESSAGE_SIZE
157
- self .max_write_batch_size = MAX_WRITE_BATCH_SIZE
235
+ self .max_bson_size : int = MAX_BSON_SIZE
236
+ self .max_message_size : int = MAX_MESSAGE_SIZE
237
+ self .max_write_batch_size : int = MAX_WRITE_BATCH_SIZE
158
238
self .supports_sessions = False
159
239
self .hello_ok : bool = False
160
- self .is_mongos = False
240
+ self .is_mongos : bool = False
161
241
self .op_msg_enabled = False
162
242
self .listeners = pool .opts ._event_listeners
163
243
self .enabled_for_cmap = pool .enabled_for_cmap
164
244
self .enabled_for_logging = pool .enabled_for_logging
165
245
self .compression_settings = pool .opts ._compression_settings
166
246
self .compression_context : Union [SnappyContext , ZlibContext , ZstdContext , None ] = None
167
- self .socket_checker : SocketChecker = SocketChecker ()
168
247
self .oidc_token_gen_id : Optional [int ] = None
169
248
# Support for mechanism negotiation on the initial handshake.
170
249
self .negotiated_mechs : Optional [list [str ]] = None
@@ -175,9 +254,6 @@ def __init__(
175
254
self .pool_gen = pool .gen
176
255
self .generation = self .pool_gen .get_overall ()
177
256
self .ready = False
178
- self .cancel_context : _CancellationContext = _CancellationContext ()
179
- self .opts = pool .opts
180
- self .more_to_come : bool = False
181
257
# For load balancer support.
182
258
self .service_id : Optional [ObjectId ] = None
183
259
self .server_connection_id : Optional [int ] = None
@@ -193,44 +269,6 @@ def __init__(
193
269
# For gossiping $clusterTime from the connection handshake to the client.
194
270
self ._cluster_time = None
195
271
196
- def set_conn_timeout (self , timeout : Optional [float ]) -> None :
197
- """Cache last timeout to avoid duplicate calls to conn.settimeout."""
198
- if timeout == self .last_timeout :
199
- return
200
- self .last_timeout = timeout
201
- self .conn .get_conn .settimeout (timeout )
202
-
203
- def apply_timeout (
204
- self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
205
- ) -> Optional [float ]:
206
- # CSOT: use remaining timeout when set.
207
- timeout = _csot .remaining ()
208
- if timeout is None :
209
- # Reset the socket timeout unless we're performing a streaming monitor check.
210
- if not self .more_to_come :
211
- self .set_conn_timeout (self .opts .socket_timeout )
212
- return None
213
- # RTT validation.
214
- rtt = _csot .get_rtt ()
215
- if rtt is None :
216
- rtt = self .connect_rtt
217
- max_time_ms = timeout - rtt
218
- if max_time_ms < 0 :
219
- timeout_details = _get_timeout_details (self .opts )
220
- formatted = format_timeout_details (timeout_details )
221
- # CSOT: raise an error without running the command since we know it will time out.
222
- errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
223
- raise ExecutionTimeout (
224
- errmsg ,
225
- 50 ,
226
- {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
227
- self .max_wire_version ,
228
- )
229
- if cmd is not None :
230
- cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
231
- self .set_conn_timeout (timeout )
232
- return timeout
233
-
234
272
def pin_txn (self ) -> None :
235
273
self .pinned_txn = True
236
274
assert not self .pinned_cursor
@@ -574,26 +612,6 @@ async def close_conn(self, reason: Optional[str]) -> None:
574
612
error = reason ,
575
613
)
576
614
577
- async def _close_conn (self ) -> None :
578
- """Close this connection."""
579
- if self .closed :
580
- return
581
- self .closed = True
582
- self .cancel_context .cancel ()
583
- # Note: We catch exceptions to avoid spurious errors on interpreter
584
- # shutdown.
585
- try :
586
- await self .conn .close ()
587
- except Exception : # noqa: S110
588
- pass
589
-
590
- def conn_closed (self ) -> bool :
591
- """Return True if we know socket has been closed, False otherwise."""
592
- if _IS_SYNC :
593
- return self .socket_checker .socket_closed (self .conn .get_conn )
594
- else :
595
- return self .conn .is_closing ()
596
-
597
615
def send_cluster_time (
598
616
self ,
599
617
command : MutableMapping [str , Any ],
0 commit comments