Skip to content

Commit 4bff4fd

Browse files
committed
Standalone commands only
1 parent f45b35a commit 4bff4fd

File tree

3 files changed

+958
-16
lines changed

3 files changed

+958
-16
lines changed

pymongo/asynchronous/network.py

Lines changed: 286 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import datetime
1919
import logging
2020
import time
21+
from asyncio import streams, StreamReader
2122
from typing import (
2223
TYPE_CHECKING,
2324
Any,
@@ -45,14 +46,14 @@
4546
_UNPACK_COMPRESSION_HEADER,
4647
_UNPACK_HEADER,
4748
async_receive_data,
48-
async_sendall,
49+
async_sendall, async_sendall_stream, async_receive_data_stream,
4950
)
5051

5152
if TYPE_CHECKING:
5253
from bson import CodecOptions
5354
from pymongo.asynchronous.client_session import AsyncClientSession
5455
from pymongo.asynchronous.mongo_client import AsyncMongoClient
55-
from pymongo.asynchronous.pool import AsyncConnection
56+
from pymongo.asynchronous.pool import AsyncConnection, AsyncStreamConnection, AsyncConnectionStream
5657
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
5758
from pymongo.monitoring import _EventListeners
5859
from pymongo.read_concern import ReadConcern
@@ -114,6 +115,7 @@ async def command(
114115
bson._decode_all_selective.
115116
:param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
116117
"""
118+
print("Running socket command!")
117119
name = next(iter(spec))
118120
ns = dbname + ".$cmd"
119121
speculative_hello = False
@@ -298,6 +300,243 @@ async def command(
298300

299301
return response_doc # type: ignore[return-value]
300302

303+
async def command_stream(
304+
conn: AsyncConnectionStream,
305+
dbname: str,
306+
spec: MutableMapping[str, Any],
307+
is_mongos: bool,
308+
read_preference: Optional[_ServerMode],
309+
codec_options: CodecOptions[_DocumentType],
310+
session: Optional[AsyncClientSession],
311+
client: Optional[AsyncMongoClient],
312+
check: bool = True,
313+
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
314+
address: Optional[_Address] = None,
315+
listeners: Optional[_EventListeners] = None,
316+
max_bson_size: Optional[int] = None,
317+
read_concern: Optional[ReadConcern] = None,
318+
parse_write_concern_error: bool = False,
319+
collation: Optional[_CollationIn] = None,
320+
compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
321+
use_op_msg: bool = False,
322+
unacknowledged: bool = False,
323+
user_fields: Optional[Mapping[str, Any]] = None,
324+
exhaust_allowed: bool = False,
325+
write_concern: Optional[WriteConcern] = None,
326+
) -> _DocumentType:
327+
"""Execute a command over the socket, or raise socket.error.
328+
329+
:param conn: a AsyncConnection instance
330+
:param dbname: name of the database on which to run the command
331+
:param spec: a command document as an ordered dict type, eg SON.
332+
:param is_mongos: are we connected to a mongos?
333+
:param read_preference: a read preference
334+
:param codec_options: a CodecOptions instance
335+
:param session: optional AsyncClientSession instance.
336+
:param client: optional AsyncMongoClient instance for updating $clusterTime.
337+
:param check: raise OperationFailure if there are errors
338+
:param allowable_errors: errors to ignore if `check` is True
339+
:param address: the (host, port) of `conn`
340+
:param listeners: An instance of :class:`~pymongo.monitoring.EventListeners`
341+
:param max_bson_size: The maximum encoded bson size for this server
342+
:param read_concern: The read concern for this command.
343+
:param parse_write_concern_error: Whether to parse the ``writeConcernError``
344+
field in the command response.
345+
:param collation: The collation for this command.
346+
:param compression_ctx: optional compression Context.
347+
:param use_op_msg: True if we should use OP_MSG.
348+
:param unacknowledged: True if this is an unacknowledged command.
349+
:param user_fields: Response fields that should be decoded
350+
using the TypeDecoders from codec_options, passed to
351+
bson._decode_all_selective.
352+
:param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
353+
"""
354+
# print("Running stream command!")
355+
name = next(iter(spec))
356+
ns = dbname + ".$cmd"
357+
speculative_hello = False
358+
359+
# Publish the original command document, perhaps with lsid and $clusterTime.
360+
orig = spec
361+
if is_mongos and not use_op_msg:
362+
assert read_preference is not None
363+
spec = message._maybe_add_read_preference(spec, read_preference)
364+
if read_concern and not (session and session.in_transaction):
365+
if read_concern.level:
366+
spec["readConcern"] = read_concern.document
367+
if session:
368+
session._update_read_concern(spec, conn)
369+
if collation is not None:
370+
spec["collation"] = collation
371+
372+
publish = listeners is not None and listeners.enabled_for_commands
373+
start = datetime.datetime.now()
374+
if publish:
375+
speculative_hello = _is_speculative_authenticate(name, spec)
376+
377+
if compression_ctx and name.lower() in _NO_COMPRESSION:
378+
compression_ctx = None
379+
380+
if client and client._encrypter and not client._encrypter._bypass_auto_encryption:
381+
spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options)
382+
383+
# Support CSOT
384+
if client:
385+
conn.apply_timeout(client, spec)
386+
_csot.apply_write_concern(spec, write_concern)
387+
388+
if use_op_msg:
389+
flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
390+
flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
391+
request_id, msg, size, max_doc_size = message._op_msg(
392+
flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
393+
)
394+
# If this is an unacknowledged write then make sure the encoded doc(s)
395+
# are small enough, otherwise rely on the server to return an error.
396+
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
397+
message._raise_document_too_large(name, size, max_bson_size)
398+
else:
399+
request_id, msg, size = message._query(
400+
0, ns, 0, -1, spec, None, codec_options, compression_ctx
401+
)
402+
403+
if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD:
404+
message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD)
405+
if client is not None:
406+
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
407+
_debug_log(
408+
_COMMAND_LOGGER,
409+
clientId=client._topology_settings._topology_id,
410+
message=_CommandStatusMessage.STARTED,
411+
command=spec,
412+
commandName=next(iter(spec)),
413+
databaseName=dbname,
414+
requestId=request_id,
415+
operationId=request_id,
416+
driverConnectionId=conn.id,
417+
serverConnectionId=conn.server_connection_id,
418+
serverHost=conn.address[0],
419+
serverPort=conn.address[1],
420+
serviceId=conn.service_id,
421+
)
422+
if publish:
423+
assert listeners is not None
424+
assert address is not None
425+
listeners.publish_command_start(
426+
orig,
427+
dbname,
428+
request_id,
429+
address,
430+
conn.server_connection_id,
431+
service_id=conn.service_id,
432+
)
433+
434+
try:
435+
await async_sendall_stream(conn.conn[1], msg)
436+
if use_op_msg and unacknowledged:
437+
# Unacknowledged, fake a successful command response.
438+
reply = None
439+
response_doc: _DocumentOut = {"ok": 1}
440+
else:
441+
reply = await receive_message_stream(conn.conn[0], request_id)
442+
conn.more_to_come = reply.more_to_come
443+
unpacked_docs = reply.unpack_response(
444+
codec_options=codec_options, user_fields=user_fields
445+
)
446+
447+
response_doc = unpacked_docs[0]
448+
if client:
449+
await client._process_response(response_doc, session)
450+
if check:
451+
helpers_shared._check_command_response(
452+
response_doc,
453+
conn.max_wire_version,
454+
allowable_errors,
455+
parse_write_concern_error=parse_write_concern_error,
456+
)
457+
except Exception as exc:
458+
duration = datetime.datetime.now() - start
459+
if isinstance(exc, (NotPrimaryError, OperationFailure)):
460+
failure: _DocumentOut = exc.details # type: ignore[assignment]
461+
else:
462+
failure = message._convert_exception(exc)
463+
if client is not None:
464+
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
465+
_debug_log(
466+
_COMMAND_LOGGER,
467+
clientId=client._topology_settings._topology_id,
468+
message=_CommandStatusMessage.FAILED,
469+
durationMS=duration,
470+
failure=failure,
471+
commandName=next(iter(spec)),
472+
databaseName=dbname,
473+
requestId=request_id,
474+
operationId=request_id,
475+
driverConnectionId=conn.id,
476+
serverConnectionId=conn.server_connection_id,
477+
serverHost=conn.address[0],
478+
serverPort=conn.address[1],
479+
serviceId=conn.service_id,
480+
isServerSideError=isinstance(exc, OperationFailure),
481+
)
482+
if publish:
483+
assert listeners is not None
484+
assert address is not None
485+
listeners.publish_command_failure(
486+
duration,
487+
failure,
488+
name,
489+
request_id,
490+
address,
491+
conn.server_connection_id,
492+
service_id=conn.service_id,
493+
database_name=dbname,
494+
)
495+
raise
496+
duration = datetime.datetime.now() - start
497+
if client is not None:
498+
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
499+
_debug_log(
500+
_COMMAND_LOGGER,
501+
clientId=client._topology_settings._topology_id,
502+
message=_CommandStatusMessage.SUCCEEDED,
503+
durationMS=duration,
504+
reply=response_doc,
505+
commandName=next(iter(spec)),
506+
databaseName=dbname,
507+
requestId=request_id,
508+
operationId=request_id,
509+
driverConnectionId=conn.id,
510+
serverConnectionId=conn.server_connection_id,
511+
serverHost=conn.address[0],
512+
serverPort=conn.address[1],
513+
serviceId=conn.service_id,
514+
speculative_authenticate="speculativeAuthenticate" in orig,
515+
)
516+
if publish:
517+
assert listeners is not None
518+
assert address is not None
519+
listeners.publish_command_success(
520+
duration,
521+
response_doc,
522+
name,
523+
request_id,
524+
address,
525+
conn.server_connection_id,
526+
service_id=conn.service_id,
527+
speculative_hello=speculative_hello,
528+
database_name=dbname,
529+
)
530+
531+
if client and client._encrypter and reply:
532+
decrypted = await client._encrypter.decrypt(reply.raw_command_response())
533+
response_doc = cast(
534+
"_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0]
535+
)
536+
537+
return response_doc # type: ignore[return-value]
538+
539+
301540

302541
async def receive_message(
303542
conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
@@ -341,3 +580,48 @@ async def receive_message(
341580
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
342581
) from None
343582
return unpack_reply(data)
583+
584+
async def receive_message_stream(
585+
conn: StreamReader, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
586+
) -> Union[_OpReply, _OpMsg]:
587+
"""Receive a raw BSON message or raise socket.error."""
588+
# if _csot.get_timeout():
589+
# deadline = _csot.get_deadline()
590+
# else:
591+
# timeout = conn.conn.gettimeout()
592+
# if timeout:
593+
# deadline = time.monotonic() + timeout
594+
# else:
595+
# deadline = None
596+
deadline = None
597+
# Ignore the response's request id.
598+
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data_stream(conn, 16, deadline))
599+
# No request_id for exhaust cursor "getMore".
600+
if request_id is not None:
601+
if request_id != response_to:
602+
raise ProtocolError(f"Got response id {response_to!r} but expected {request_id!r}")
603+
if length <= 16:
604+
raise ProtocolError(
605+
f"Message length ({length!r}) not longer than standard message header size (16)"
606+
)
607+
if length > max_message_size:
608+
raise ProtocolError(
609+
f"Message length ({length!r}) is larger than server max "
610+
f"message size ({max_message_size!r})"
611+
)
612+
if op_code == 2012:
613+
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
614+
await async_receive_data(conn, 9, deadline)
615+
)
616+
data = decompress(await async_receive_data_stream(conn, length - 25, deadline), compressor_id)
617+
else:
618+
data = await async_receive_data_stream(conn, length - 16, deadline)
619+
620+
try:
621+
unpack_reply = _UNPACK_REPLY[op_code]
622+
except KeyError:
623+
raise ProtocolError(
624+
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
625+
) from None
626+
return unpack_reply(data)
627+

0 commit comments

Comments
 (0)