@@ -124,7 +124,89 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
124
124
_IS_SYNC = False
125
125
126
126
127
- class AsyncConnection :
127
+ class AsyncBaseConnection :
128
+ """A base connection object for server and kms connections."""
129
+
130
+ def __init__ (self , conn : AsyncNetworkingInterface , opts : PoolOptions ):
131
+ self .conn = conn
132
+ self .socket_checker : SocketChecker = SocketChecker ()
133
+ self .cancel_context : _CancellationContext = _CancellationContext ()
134
+ self .is_sdam = False
135
+ self .closed = False
136
+ self .last_timeout : float | None = None
137
+ self .more_to_come = False
138
+ self .opts = opts
139
+ self .max_wire_version = - 1
140
+
141
+ def set_conn_timeout (self , timeout : Optional [float ]) -> None :
142
+ """Cache last timeout to avoid duplicate calls to conn.settimeout."""
143
+ if timeout == self .last_timeout :
144
+ return
145
+ self .last_timeout = timeout
146
+ self .conn .get_conn .settimeout (timeout )
147
+
148
+ def apply_timeout (
149
+ self , client : AsyncMongoClient , cmd : Optional [MutableMapping [str , Any ]]
150
+ ) -> Optional [float ]:
151
+ # CSOT: use remaining timeout when set.
152
+ timeout = _csot .remaining ()
153
+ if timeout is None :
154
+ # Reset the socket timeout unless we're performing a streaming monitor check.
155
+ if not self .more_to_come :
156
+ self .set_conn_timeout (self .opts .socket_timeout )
157
+ return None
158
+ # RTT validation.
159
+ rtt = _csot .get_rtt ()
160
+ if rtt is None :
161
+ rtt = self .connect_rtt
162
+ max_time_ms = timeout - rtt
163
+ if max_time_ms < 0 :
164
+ timeout_details = _get_timeout_details (self .opts )
165
+ formatted = format_timeout_details (timeout_details )
166
+ # CSOT: raise an error without running the command since we know it will time out.
167
+ errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
168
+ if self .max_wire_version != - 1 :
169
+ raise ExecutionTimeout (
170
+ errmsg ,
171
+ 50 ,
172
+ {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
173
+ self .max_wire_version ,
174
+ )
175
+ else :
176
+ raise TimeoutError (errmsg )
177
+ if cmd is not None :
178
+ cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
179
+ self .set_conn_timeout (timeout )
180
+ return timeout
181
+
182
+ async def close_conn (self , reason : Optional [str ]) -> None :
183
+ """Close this connection with a reason."""
184
+ if self .closed :
185
+ return
186
+ await self ._close_conn ()
187
+
188
+ async def _close_conn (self ) -> None :
189
+ """Close this connection."""
190
+ if self .closed :
191
+ return
192
+ self .closed = True
193
+ self .cancel_context .cancel ()
194
+ # Note: We catch exceptions to avoid spurious errors on interpreter
195
+ # shutdown.
196
+ try :
197
+ await self .conn .close ()
198
+ except Exception : # noqa: S110
199
+ pass
200
+
201
+ def conn_closed (self ) -> bool :
202
+ """Return True if we know socket has been closed, False otherwise."""
203
+ if _IS_SYNC :
204
+ return self .socket_checker .socket_closed (self .conn .get_conn )
205
+ else :
206
+ return self .conn .is_closing ()
207
+
208
+
209
+ class AsyncConnection (AsyncBaseConnection ):
128
210
"""Store a connection with some metadata.
129
211
130
212
:param conn: a raw connection object
@@ -142,29 +224,27 @@ def __init__(
142
224
id : int ,
143
225
is_sdam : bool ,
144
226
):
227
+ super ().__init__ (conn , pool .opts )
145
228
self .pool_ref = weakref .ref (pool )
146
- self .conn = conn
147
- self .address = address
148
- self .id = id
229
+ self .address : tuple [str , int ] = address
230
+ self .id : int = id
149
231
self .is_sdam = is_sdam
150
- self .closed = False
151
232
self .last_checkin_time = time .monotonic ()
152
233
self .performed_handshake = False
153
234
self .is_writable : bool = False
154
235
self .max_wire_version = MAX_WIRE_VERSION
155
- self .max_bson_size = MAX_BSON_SIZE
156
- self .max_message_size = MAX_MESSAGE_SIZE
157
- self .max_write_batch_size = MAX_WRITE_BATCH_SIZE
236
+ self .max_bson_size : int = MAX_BSON_SIZE
237
+ self .max_message_size : int = MAX_MESSAGE_SIZE
238
+ self .max_write_batch_size : int = MAX_WRITE_BATCH_SIZE
158
239
self .supports_sessions = False
159
240
self .hello_ok : bool = False
160
- self .is_mongos = False
241
+ self .is_mongos : bool = False
161
242
self .op_msg_enabled = False
162
243
self .listeners = pool .opts ._event_listeners
163
244
self .enabled_for_cmap = pool .enabled_for_cmap
164
245
self .enabled_for_logging = pool .enabled_for_logging
165
246
self .compression_settings = pool .opts ._compression_settings
166
247
self .compression_context : Union [SnappyContext , ZlibContext , ZstdContext , None ] = None
167
- self .socket_checker : SocketChecker = SocketChecker ()
168
248
self .oidc_token_gen_id : Optional [int ] = None
169
249
# Support for mechanism negotiation on the initial handshake.
170
250
self .negotiated_mechs : Optional [list [str ]] = None
@@ -175,9 +255,6 @@ def __init__(
175
255
self .pool_gen = pool .gen
176
256
self .generation = self .pool_gen .get_overall ()
177
257
self .ready = False
178
- self .cancel_context : _CancellationContext = _CancellationContext ()
179
- self .opts = pool .opts
180
- self .more_to_come : bool = False
181
258
# For load balancer support.
182
259
self .service_id : Optional [ObjectId ] = None
183
260
self .server_connection_id : Optional [int ] = None
@@ -193,44 +270,6 @@ def __init__(
193
270
# For gossiping $clusterTime from the connection handshake to the client.
194
271
self ._cluster_time = None
195
272
196
- def set_conn_timeout (self , timeout : Optional [float ]) -> None :
197
- """Cache last timeout to avoid duplicate calls to conn.settimeout."""
198
- if timeout == self .last_timeout :
199
- return
200
- self .last_timeout = timeout
201
- self .conn .get_conn .settimeout (timeout )
202
-
203
- def apply_timeout (
204
- self , client : AsyncMongoClient , cmd : Optional [MutableMapping [str , Any ]]
205
- ) -> Optional [float ]:
206
- # CSOT: use remaining timeout when set.
207
- timeout = _csot .remaining ()
208
- if timeout is None :
209
- # Reset the socket timeout unless we're performing a streaming monitor check.
210
- if not self .more_to_come :
211
- self .set_conn_timeout (self .opts .socket_timeout )
212
- return None
213
- # RTT validation.
214
- rtt = _csot .get_rtt ()
215
- if rtt is None :
216
- rtt = self .connect_rtt
217
- max_time_ms = timeout - rtt
218
- if max_time_ms < 0 :
219
- timeout_details = _get_timeout_details (self .opts )
220
- formatted = format_timeout_details (timeout_details )
221
- # CSOT: raise an error without running the command since we know it will time out.
222
- errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
223
- raise ExecutionTimeout (
224
- errmsg ,
225
- 50 ,
226
- {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
227
- self .max_wire_version ,
228
- )
229
- if cmd is not None :
230
- cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
231
- self .set_conn_timeout (timeout )
232
- return timeout
233
-
234
273
def pin_txn (self ) -> None :
235
274
self .pinned_txn = True
236
275
assert not self .pinned_cursor
@@ -555,9 +594,7 @@ def validate_session(
555
594
556
595
async def close_conn (self , reason : Optional [str ]) -> None :
557
596
"""Close this connection with a reason."""
558
- if self .closed :
559
- return
560
- await self ._close_conn ()
597
+ await super ().close_conn (reason )
561
598
if reason :
562
599
if self .enabled_for_cmap :
563
600
assert self .listeners is not None
@@ -574,26 +611,6 @@ async def close_conn(self, reason: Optional[str]) -> None:
574
611
error = reason ,
575
612
)
576
613
577
- async def _close_conn (self ) -> None :
578
- """Close this connection."""
579
- if self .closed :
580
- return
581
- self .closed = True
582
- self .cancel_context .cancel ()
583
- # Note: We catch exceptions to avoid spurious errors on interpreter
584
- # shutdown.
585
- try :
586
- await self .conn .close ()
587
- except Exception : # noqa: S110
588
- pass
589
-
590
- def conn_closed (self ) -> bool :
591
- """Return True if we know socket has been closed, False otherwise."""
592
- if _IS_SYNC :
593
- return self .socket_checker .socket_closed (self .conn .get_conn )
594
- else :
595
- return self .conn .is_closing ()
596
-
597
614
def send_cluster_time (
598
615
self ,
599
616
command : MutableMapping [str , Any ],
0 commit comments