|
44 | 44 | from pymongo.monitoring import _is_speculative_authenticate
|
45 | 45 | from pymongo.network_layer import (
|
46 | 46 | _UNPACK_COMPRESSION_HEADER,
|
47 |
| - _UNPACK_HEADER, |
48 |
| - async_receive_data, |
49 |
| - async_sendall, async_sendall_stream, async_receive_data_stream, |
| 47 | + _UNPACK_HEADER, async_sendall_stream, async_receive_data_stream, |
50 | 48 | )
|
51 | 49 |
|
52 | 50 | if TYPE_CHECKING:
|
|
64 | 62 | _IS_SYNC = False
|
65 | 63 |
|
66 | 64 |
|
67 |
async def command(
    conn: AsyncConnection,
    dbname: str,
    spec: MutableMapping[str, Any],
    is_mongos: bool,
    read_preference: Optional[_ServerMode],
    codec_options: CodecOptions[_DocumentType],
    session: Optional[AsyncClientSession],
    client: Optional[AsyncMongoClient],
    check: bool = True,
    allowable_errors: Optional[Sequence[Union[str, int]]] = None,
    address: Optional[_Address] = None,
    listeners: Optional[_EventListeners] = None,
    max_bson_size: Optional[int] = None,
    read_concern: Optional[ReadConcern] = None,
    parse_write_concern_error: bool = False,
    collation: Optional[_CollationIn] = None,
    compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
    use_op_msg: bool = False,
    unacknowledged: bool = False,
    user_fields: Optional[Mapping[str, Any]] = None,
    exhaust_allowed: bool = False,
    write_concern: Optional[WriteConcern] = None,
) -> _DocumentType:
    """Execute a command over the socket, or raise socket.error.

    :param conn: a AsyncConnection instance
    :param dbname: name of the database on which to run the command
    :param spec: a command document as an ordered dict type, eg SON.
    :param is_mongos: are we connected to a mongos?
    :param read_preference: a read preference
    :param codec_options: a CodecOptions instance
    :param session: optional AsyncClientSession instance.
    :param client: optional AsyncMongoClient instance for updating $clusterTime.
    :param check: raise OperationFailure if there are errors
    :param allowable_errors: errors to ignore if `check` is True
    :param address: the (host, port) of `conn`
    :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners`
    :param max_bson_size: The maximum encoded bson size for this server
    :param read_concern: The read concern for this command.
    :param parse_write_concern_error: Whether to parse the ``writeConcernError``
        field in the command response.
    :param collation: The collation for this command.
    :param compression_ctx: optional compression Context.
    :param use_op_msg: True if we should use OP_MSG.
    :param unacknowledged: True if this is an unacknowledged command.
    :param user_fields: Response fields that should be decoded
        using the TypeDecoders from codec_options, passed to
        bson._decode_all_selective.
    :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
    :param write_concern: optional WriteConcern, applied to the command when a
        client-side operation timeout (CSOT) is in effect.
    """
    name = next(iter(spec))
    ns = dbname + ".$cmd"
    speculative_hello = False

    # Publish the original command document, perhaps with lsid and $clusterTime.
    orig = spec
    if is_mongos and not use_op_msg:
        assert read_preference is not None
        spec = message._maybe_add_read_preference(spec, read_preference)
    if read_concern and not (session and session.in_transaction):
        if read_concern.level:
            spec["readConcern"] = read_concern.document
        if session:
            session._update_read_concern(spec, conn)
    if collation is not None:
        spec["collation"] = collation

    publish = listeners is not None and listeners.enabled_for_commands
    start = datetime.datetime.now()
    if publish:
        speculative_hello = _is_speculative_authenticate(name, spec)

    # Some commands (e.g. auth) must never be compressed.
    if compression_ctx and name.lower() in _NO_COMPRESSION:
        compression_ctx = None

    if client and client._encrypter and not client._encrypter._bypass_auto_encryption:
        spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options)

    # Support CSOT
    if client:
        conn.apply_timeout(client, spec)
    _csot.apply_write_concern(spec, write_concern)

    if use_op_msg:
        flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
        flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
        request_id, msg, size, max_doc_size = message._op_msg(
            flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
        )
        # If this is an unacknowledged write then make sure the encoded doc(s)
        # are small enough, otherwise rely on the server to return an error.
        if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
            message._raise_document_too_large(name, size, max_bson_size)
    else:
        request_id, msg, size = message._query(
            0, ns, 0, -1, spec, None, codec_options, compression_ctx
        )

    if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD:
        message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD)
    if client is not None:
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.STARTED,
                command=spec,
                commandName=next(iter(spec)),
                databaseName=dbname,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
            )
    if publish:
        assert listeners is not None
        assert address is not None
        listeners.publish_command_start(
            orig,
            dbname,
            request_id,
            address,
            conn.server_connection_id,
            service_id=conn.service_id,
        )

    try:
        await async_sendall(conn.conn, msg)
        if use_op_msg and unacknowledged:
            # Unacknowledged, fake a successful command response.
            reply = None
            response_doc: _DocumentOut = {"ok": 1}
        else:
            reply = await receive_message(conn, request_id)
            conn.more_to_come = reply.more_to_come
            unpacked_docs = reply.unpack_response(
                codec_options=codec_options, user_fields=user_fields
            )

            response_doc = unpacked_docs[0]
            if client:
                await client._process_response(response_doc, session)
            if check:
                helpers_shared._check_command_response(
                    response_doc,
                    conn.max_wire_version,
                    allowable_errors,
                    parse_write_concern_error=parse_write_concern_error,
                )
    except Exception as exc:
        duration = datetime.datetime.now() - start
        if isinstance(exc, (NotPrimaryError, OperationFailure)):
            failure: _DocumentOut = exc.details  # type: ignore[assignment]
        else:
            failure = message._convert_exception(exc)
        if client is not None:
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.FAILED,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(spec)),
                    databaseName=dbname,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=conn.id,
                    serverConnectionId=conn.server_connection_id,
                    serverHost=conn.address[0],
                    serverPort=conn.address[1],
                    serviceId=conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
        if publish:
            assert listeners is not None
            assert address is not None
            listeners.publish_command_failure(
                duration,
                failure,
                name,
                request_id,
                address,
                conn.server_connection_id,
                service_id=conn.service_id,
                database_name=dbname,
            )
        raise
    duration = datetime.datetime.now() - start
    if client is not None:
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.SUCCEEDED,
                durationMS=duration,
                reply=response_doc,
                commandName=next(iter(spec)),
                databaseName=dbname,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
                speculative_authenticate="speculativeAuthenticate" in orig,
            )
    if publish:
        assert listeners is not None
        assert address is not None
        listeners.publish_command_success(
            duration,
            response_doc,
            name,
            request_id,
            address,
            conn.server_connection_id,
            service_id=conn.service_id,
            speculative_hello=speculative_hello,
            database_name=dbname,
        )

    if client and client._encrypter and reply:
        decrypted = await client._encrypter.decrypt(reply.raw_command_response())
        response_doc = cast(
            "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0]
        )

    return response_doc  # type: ignore[return-value]
302 |
| - |
303 | 65 | async def command_stream(
|
304 | 66 | conn: AsyncConnectionStream,
|
305 | 67 | dbname: str,
|
@@ -537,50 +299,6 @@ async def command_stream(
|
537 | 299 | return response_doc # type: ignore[return-value]
|
538 | 300 |
|
539 | 301 |
|
540 |
| - |
541 |
async def receive_message(
    conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
    """Receive a raw BSON message or raise socket.error."""
    # Pick the read deadline: a CSOT deadline wins; otherwise derive one from
    # the socket timeout; otherwise block without a deadline.
    deadline: Optional[float]
    if _csot.get_timeout():
        deadline = _csot.get_deadline()
    else:
        sock_timeout = conn.conn.gettimeout()
        deadline = time.monotonic() + sock_timeout if sock_timeout else None

    header = await async_receive_data(conn, 16, deadline)
    # Ignore the response's request id.
    length, _, response_to, op_code = _UNPACK_HEADER(header)
    # No request_id for exhaust cursor "getMore".
    if request_id is not None and request_id != response_to:
        raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
    if length <= 16:
        raise ProtocolError(
            f"Message length ({length!r}) not longer than standard message header size (16)"
        )
    if length > max_message_size:
        raise ProtocolError(
            f"Message length ({length!r}) is larger than server max "
            f"message size ({max_message_size!r})"
        )

    if op_code == 2012:
        # OP_COMPRESSED: read the 9-byte compression header, then inflate the
        # remaining payload (total length minus 16-byte header minus 9).
        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
            await async_receive_data(conn, 9, deadline)
        )
        data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
    else:
        data = await async_receive_data(conn, length - 16, deadline)

    try:
        handler = _UNPACK_REPLY[op_code]
    except KeyError:
        raise ProtocolError(
            f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
        ) from None
    return handler(data)
583 |
| - |
584 | 302 | async def receive_message_stream(
|
585 | 303 | conn: StreamReader, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
|
586 | 304 | ) -> Union[_OpReply, _OpMsg]:
|
@@ -611,7 +329,7 @@ async def receive_message_stream(
|
611 | 329 | )
|
612 | 330 | if op_code == 2012:
|
613 | 331 | op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
|
614 |
| - await async_receive_data(conn, 9, deadline) |
| 332 | + await async_receive_data_stream(conn, 9, deadline) |
615 | 333 | )
|
616 | 334 | data = decompress(await async_receive_data_stream(conn, length - 25, deadline), compressor_id)
|
617 | 335 | else:
|
|
0 commit comments