|
18 | 18 | import datetime
|
19 | 19 | import logging
|
20 | 20 | import time
|
| 21 | +from asyncio import streams, StreamReader |
21 | 22 | from typing import (
|
22 | 23 | TYPE_CHECKING,
|
23 | 24 | Any,
|
|
45 | 46 | _UNPACK_COMPRESSION_HEADER,
|
46 | 47 | _UNPACK_HEADER,
|
47 | 48 | async_receive_data,
|
48 |
| - async_sendall, |
| 49 | + async_sendall, async_sendall_stream, async_receive_data_stream, |
49 | 50 | )
|
50 | 51 |
|
51 | 52 | if TYPE_CHECKING:
|
52 | 53 | from bson import CodecOptions
|
53 | 54 | from pymongo.asynchronous.client_session import AsyncClientSession
|
54 | 55 | from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
55 |
| - from pymongo.asynchronous.pool import AsyncConnection |
| 56 | + from pymongo.asynchronous.pool import AsyncConnection, AsyncStreamConnection, AsyncConnectionStream |
56 | 57 | from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
|
57 | 58 | from pymongo.monitoring import _EventListeners
|
58 | 59 | from pymongo.read_concern import ReadConcern
|
@@ -114,6 +115,7 @@ async def command(
|
114 | 115 | bson._decode_all_selective.
|
115 | 116 | :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
|
116 | 117 | """
|
| 118 | + print("Running socket command!") |
117 | 119 | name = next(iter(spec))
|
118 | 120 | ns = dbname + ".$cmd"
|
119 | 121 | speculative_hello = False
|
@@ -298,6 +300,243 @@ async def command(
|
298 | 300 |
|
299 | 301 | return response_doc # type: ignore[return-value]
|
300 | 302 |
|
async def command_stream(
    conn: AsyncConnectionStream,
    dbname: str,
    spec: MutableMapping[str, Any],
    is_mongos: bool,
    read_preference: Optional[_ServerMode],
    codec_options: CodecOptions[_DocumentType],
    session: Optional[AsyncClientSession],
    client: Optional[AsyncMongoClient],
    check: bool = True,
    allowable_errors: Optional[Sequence[Union[str, int]]] = None,
    address: Optional[_Address] = None,
    listeners: Optional[_EventListeners] = None,
    max_bson_size: Optional[int] = None,
    read_concern: Optional[ReadConcern] = None,
    parse_write_concern_error: bool = False,
    collation: Optional[_CollationIn] = None,
    compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
    use_op_msg: bool = False,
    unacknowledged: bool = False,
    user_fields: Optional[Mapping[str, Any]] = None,
    exhaust_allowed: bool = False,
    write_concern: Optional[WriteConcern] = None,
) -> _DocumentType:
    """Execute a command over an asyncio stream pair, or raise socket.error.

    Stream-based twin of :func:`command`: identical logic, but I/O goes
    through ``conn.conn`` (assumed to be an asyncio ``(StreamReader,
    StreamWriter)`` pair — TODO confirm against AsyncConnectionStream)
    via ``async_sendall_stream`` / ``receive_message_stream``.

    :param conn: an AsyncConnectionStream instance
    :param dbname: name of the database on which to run the command
    :param spec: a command document as an ordered dict type, eg SON.
    :param is_mongos: are we connected to a mongos?
    :param read_preference: a read preference
    :param codec_options: a CodecOptions instance
    :param session: optional AsyncClientSession instance.
    :param client: optional AsyncMongoClient instance for updating $clusterTime.
    :param check: raise OperationFailure if there are errors
    :param allowable_errors: errors to ignore if `check` is True
    :param address: the (host, port) of `conn`
    :param listeners: An instance of :class:`~pymongo.monitoring.EventListeners`
    :param max_bson_size: The maximum encoded bson size for this server
    :param read_concern: The read concern for this command.
    :param parse_write_concern_error: Whether to parse the ``writeConcernError``
        field in the command response.
    :param collation: The collation for this command.
    :param compression_ctx: optional compression Context.
    :param use_op_msg: True if we should use OP_MSG.
    :param unacknowledged: True if this is an unacknowledged command.
    :param user_fields: Response fields that should be decoded
        using the TypeDecoders from codec_options, passed to
        bson._decode_all_selective.
    :param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
    :param write_concern: optional WriteConcern applied for CSOT support.
    """
    name = next(iter(spec))
    ns = dbname + ".$cmd"
    speculative_hello = False

    # Publish the original command document, perhaps with lsid and $clusterTime.
    orig = spec
    if is_mongos and not use_op_msg:
        assert read_preference is not None
        spec = message._maybe_add_read_preference(spec, read_preference)
    if read_concern and not (session and session.in_transaction):
        if read_concern.level:
            spec["readConcern"] = read_concern.document
        if session:
            session._update_read_concern(spec, conn)
    if collation is not None:
        spec["collation"] = collation

    publish = listeners is not None and listeners.enabled_for_commands
    start = datetime.datetime.now()
    if publish:
        speculative_hello = _is_speculative_authenticate(name, spec)

    # Some administrative commands must never be compressed.
    if compression_ctx and name.lower() in _NO_COMPRESSION:
        compression_ctx = None

    if client and client._encrypter and not client._encrypter._bypass_auto_encryption:
        spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options)

    # Support CSOT
    if client:
        conn.apply_timeout(client, spec)
    _csot.apply_write_concern(spec, write_concern)

    if use_op_msg:
        flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
        flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
        request_id, msg, size, max_doc_size = message._op_msg(
            flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
        )
        # If this is an unacknowledged write then make sure the encoded doc(s)
        # are small enough, otherwise rely on the server to return an error.
        if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
            message._raise_document_too_large(name, size, max_bson_size)
    else:
        request_id, msg, size = message._query(
            0, ns, 0, -1, spec, None, codec_options, compression_ctx
        )

    if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD:
        message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD)
    if client is not None:
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.STARTED,
                command=spec,
                commandName=next(iter(spec)),
                databaseName=dbname,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
            )
    if publish:
        assert listeners is not None
        assert address is not None
        listeners.publish_command_start(
            orig,
            dbname,
            request_id,
            address,
            conn.server_connection_id,
            service_id=conn.service_id,
        )

    try:
        # conn.conn[1]: the write half of the stream pair.
        await async_sendall_stream(conn.conn[1], msg)
        if use_op_msg and unacknowledged:
            # Unacknowledged, fake a successful command response.
            reply = None
            response_doc: _DocumentOut = {"ok": 1}
        else:
            # conn.conn[0]: the read half (StreamReader).
            reply = await receive_message_stream(conn.conn[0], request_id)
            conn.more_to_come = reply.more_to_come
            unpacked_docs = reply.unpack_response(
                codec_options=codec_options, user_fields=user_fields
            )

            response_doc = unpacked_docs[0]
            if client:
                await client._process_response(response_doc, session)
            if check:
                helpers_shared._check_command_response(
                    response_doc,
                    conn.max_wire_version,
                    allowable_errors,
                    parse_write_concern_error=parse_write_concern_error,
                )
    except Exception as exc:
        duration = datetime.datetime.now() - start
        if isinstance(exc, (NotPrimaryError, OperationFailure)):
            failure: _DocumentOut = exc.details  # type: ignore[assignment]
        else:
            failure = message._convert_exception(exc)
        if client is not None:
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.FAILED,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(spec)),
                    databaseName=dbname,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=conn.id,
                    serverConnectionId=conn.server_connection_id,
                    serverHost=conn.address[0],
                    serverPort=conn.address[1],
                    serviceId=conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
        if publish:
            assert listeners is not None
            assert address is not None
            listeners.publish_command_failure(
                duration,
                failure,
                name,
                request_id,
                address,
                conn.server_connection_id,
                service_id=conn.service_id,
                database_name=dbname,
            )
        raise
    duration = datetime.datetime.now() - start
    if client is not None:
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.SUCCEEDED,
                durationMS=duration,
                reply=response_doc,
                commandName=next(iter(spec)),
                databaseName=dbname,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=conn.id,
                serverConnectionId=conn.server_connection_id,
                serverHost=conn.address[0],
                serverPort=conn.address[1],
                serviceId=conn.service_id,
                speculative_authenticate="speculativeAuthenticate" in orig,
            )
    if publish:
        assert listeners is not None
        assert address is not None
        listeners.publish_command_success(
            duration,
            response_doc,
            name,
            request_id,
            address,
            conn.server_connection_id,
            service_id=conn.service_id,
            speculative_hello=speculative_hello,
            database_name=dbname,
        )

    if client and client._encrypter and reply:
        decrypted = await client._encrypter.decrypt(reply.raw_command_response())
        response_doc = cast(
            "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0]
        )

    return response_doc  # type: ignore[return-value]
| 538 | + |
| 539 | + |
301 | 540 |
|
302 | 541 | async def receive_message(
|
303 | 542 | conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
|
@@ -341,3 +580,48 @@ async def receive_message(
|
341 | 580 | f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
|
342 | 581 | ) from None
|
343 | 582 | return unpack_reply(data)
|
| 583 | + |
async def receive_message_stream(
    conn: StreamReader, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
) -> Union[_OpReply, _OpMsg]:
    """Receive a raw BSON message from an asyncio StreamReader or raise socket.error.

    Stream-based twin of :func:`receive_message`.

    :param conn: the StreamReader to read the reply from
    :param request_id: the request id the reply must match, or None for
        exhaust-cursor "getMore" replies.
    :param max_message_size: maximum reply size accepted from the server.
    :raises ProtocolError: on a mismatched request id, an invalid message
        length, or an unknown opcode.
    """
    # TODO: honor CSOT / socket timeouts on this path like receive_message
    # does; a StreamReader has no gettimeout(), so reads currently have no
    # deadline.
    deadline = None
    # Ignore the response's request id.
    length, _, response_to, op_code = _UNPACK_HEADER(
        await async_receive_data_stream(conn, 16, deadline)
    )
    # No request_id for exhaust cursor "getMore".
    if request_id is not None:
        if request_id != response_to:
            raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
    if length <= 16:
        raise ProtocolError(
            f"Message length ({length!r}) not longer than standard message header size (16)"
        )
    if length > max_message_size:
        raise ProtocolError(
            f"Message length ({length!r}) is larger than server max "
            f"message size ({max_message_size!r})"
        )
    if op_code == 2012:  # OP_COMPRESSED
        # Fix: read the 9-byte compression header from the StreamReader
        # (async_receive_data_stream); the socket-based async_receive_data
        # cannot be used with a StreamReader.
        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
            await async_receive_data_stream(conn, 9, deadline)
        )
        data = decompress(
            await async_receive_data_stream(conn, length - 25, deadline), compressor_id
        )
    else:
        data = await async_receive_data_stream(conn, length - 16, deadline)

    try:
        unpack_reply = _UNPACK_REPLY[op_code]
    except KeyError:
        raise ProtocolError(
            f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
        ) from None
    return unpack_reply(data)
| 627 | + |
0 commit comments