19
19
import contextlib
20
20
import enum
21
21
import socket
22
+ import time as time # noqa: PLC0414 # needed in sync version
22
23
import uuid
23
24
import weakref
24
25
from copy import deepcopy
67
68
EncryptedCollectionError ,
68
69
EncryptionError ,
69
70
InvalidOperation ,
70
- PyMongoError ,
71
+ NetworkTimeout ,
71
72
ServerSelectionTimeoutError ,
72
73
)
73
74
from pymongo .network_layer import BLOCKING_IO_ERRORS , sendall
80
81
from pymongo .synchronous .cursor import Cursor
81
82
from pymongo .synchronous .database import Database
82
83
from pymongo .synchronous .mongo_client import MongoClient
83
- from pymongo .synchronous .pool import _configured_socket , _raise_connection_failure
84
+ from pymongo .synchronous .pool import (
85
+ _configured_socket ,
86
+ _get_timeout_details ,
87
+ _raise_connection_failure ,
88
+ )
84
89
from pymongo .typings import _DocumentType , _DocumentTypeArg
85
90
from pymongo .uri_parser import parse_host
86
91
from pymongo .write_concern import WriteConcern
87
92
88
93
if TYPE_CHECKING :
89
94
from pymongocrypt .mongocrypt import MongoCryptKmsContext
90
95
96
+ from pymongo .pyopenssl_context import _sslConn
97
+ from pymongo .typings import _Address
98
+
91
99
92
100
_IS_SYNC = True
93
101
103
111
_KEY_VAULT_OPTS = CodecOptions (document_class = RawBSONDocument )
104
112
105
113
114
+ def _connect_kms (address : _Address , opts : PoolOptions ) -> Union [socket .socket , _sslConn ]:
115
+ try :
116
+ return _configured_socket (address , opts )
117
+ except Exception as exc :
118
+ _raise_connection_failure (address , exc , timeout_details = _get_timeout_details (opts ))
119
+
120
+
106
121
@contextlib .contextmanager
107
122
def _wrap_encryption_errors () -> Iterator [None ]:
108
123
"""Context manager to wrap encryption related errors."""
@@ -166,18 +181,22 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
166
181
None , # crlfile
167
182
False , # allow_invalid_certificates
168
183
False , # allow_invalid_hostnames
169
- False ,
170
- ) # disable_ocsp_endpoint_check
184
+ False , # disable_ocsp_endpoint_check
185
+ )
171
186
# CSOT: set timeout for socket creation.
172
187
connect_timeout = max (_csot .clamp_remaining (_KMS_CONNECT_TIMEOUT ), 0.001 )
173
188
opts = PoolOptions (
174
189
connect_timeout = connect_timeout ,
175
190
socket_timeout = connect_timeout ,
176
191
ssl_context = ctx ,
177
192
)
178
- host , port = parse_host (endpoint , _HTTPS_PORT )
193
+ address = parse_host (endpoint , _HTTPS_PORT )
194
+ sleep_u = kms_context .usleep
195
+ if sleep_u :
196
+ sleep_sec = float (sleep_u ) / 1e6
197
+ time .sleep (sleep_sec )
179
198
try :
180
- conn = _configured_socket (( host , port ) , opts )
199
+ conn = _connect_kms ( address , opts )
181
200
try :
182
201
sendall (conn , message )
183
202
while kms_context .bytes_needed > 0 :
@@ -194,20 +213,29 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
194
213
if not data :
195
214
raise OSError ("KMS connection closed" )
196
215
kms_context .feed (data )
197
- # Async raises an OSError instead of returning empty bytes
198
- except OSError as err :
199
- raise OSError ("KMS connection closed" ) from err
200
- except BLOCKING_IO_ERRORS :
201
- raise socket .timeout ("timed out" ) from None
216
+ except MongoCryptError :
217
+ raise # Propagate MongoCryptError errors directly.
218
+ except Exception as exc :
219
+ # Wrap I/O errors in PyMongo exceptions.
220
+ if isinstance (exc , BLOCKING_IO_ERRORS ):
221
+ exc = socket .timeout ("timed out" )
222
+ _raise_connection_failure (address , exc , timeout_details = _get_timeout_details (opts ))
202
223
finally :
203
224
conn .close ()
204
- except (PyMongoError , MongoCryptError ):
205
- raise # Propagate pymongo errors directly.
206
- except asyncio .CancelledError :
207
- raise
208
- except Exception as error :
209
- # Wrap I/O errors in PyMongo exceptions.
210
- _raise_connection_failure ((host , port ), error )
225
+ except MongoCryptError :
226
+ raise # Propagate MongoCryptError errors directly.
227
+ except Exception as exc :
228
+ remaining = _csot .remaining ()
229
+ if isinstance (exc , NetworkTimeout ) or (remaining is not None and remaining <= 0 ):
230
+ raise
231
+ # Mark this attempt as failed and defer to libmongocrypt to retry.
232
+ try :
233
+ kms_context .fail ()
234
+ except MongoCryptError as final_err :
235
+ exc = MongoCryptError (
236
+ f"{ final_err } , last attempt failed with: { exc } " , final_err .code
237
+ )
238
+ raise exc from final_err
211
239
212
240
def collection_info (self , database : str , filter : bytes ) -> Optional [bytes ]:
213
241
"""Get the collection info for a namespace.
0 commit comments