diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index c200899dd1..9fd673693f 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -281,6 +281,7 @@ async def write_command( ) if bwc.publish: bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + await client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: duration = datetime.datetime.now() - bwc.start_time if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -308,6 +309,9 @@ async def write_command( if bwc.publish: bwc._fail(request_id, failure, duration) + # Process the response from the server. + if isinstance(exc, (NotPrimaryError, OperationFailure)): + await client._process_response(exc.details, bwc.session) # type: ignore[arg-type] raise finally: bwc.start_time = datetime.datetime.now() @@ -449,7 +453,6 @@ async def _execute_batch( else: request_id, msg, to_send = bwc.batch_command(cmd, ops) result = await self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type] - await client._process_response(result, bwc.session) # type: ignore[arg-type] return result, to_send # type: ignore[return-value] diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index b9ab6b876b..15a0369f41 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -283,6 +283,8 @@ async def write_command( ) if bwc.publish: bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + # Process the response from the server. + await self.client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: duration = datetime.datetime.now() - bwc.start_time if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -312,6 +314,11 @@ async def write_command( bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} + # Process the response from the server. + if isinstance(exc, OperationFailure): + await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] + else: + await self.client._process_response({}, bwc.session) # type: ignore[arg-type] finally: bwc.start_time = datetime.datetime.now() return reply # type: ignore[return-value] @@ -431,7 +438,6 @@ async def _execute_batch( result = await self.write_command( bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client ) # type: ignore[arg-type] - await self.client._process_response(result, bwc.session) # type: ignore[arg-type] return result, to_send_ops, to_send_ns # type: ignore[return-value] async def _process_results_cursor( diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 4da64c4a78..27fcff620c 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -281,6 +281,7 @@ def write_command( ) if bwc.publish: bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: duration = datetime.datetime.now() - bwc.start_time if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -308,6 +309,9 @@ def write_command( if bwc.publish: bwc._fail(request_id, failure, duration) + # Process the response from the server. + if isinstance(exc, (NotPrimaryError, OperationFailure)): + client._process_response(exc.details, bwc.session) # type: ignore[arg-type] raise finally: bwc.start_time = datetime.datetime.now() @@ -449,7 +453,6 @@ def _execute_batch( else: request_id, msg, to_send = bwc.batch_command(cmd, ops) result = self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type] - client._process_response(result, bwc.session) # type: ignore[arg-type] return result, to_send # type: ignore[return-value] diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 106e5dcbb3..23af231d16 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -283,6 +283,8 @@ def write_command( ) if bwc.publish: bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + # Process the response from the server. + self.client._process_response(reply, bwc.session) # type: ignore[arg-type] except Exception as exc: duration = datetime.datetime.now() - bwc.start_time if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -312,6 +314,11 @@ def write_command( bwc._fail(request_id, failure, duration) # Top-level error will be embedded in ClientBulkWriteException. reply = {"error": exc} + # Process the response from the server. + if isinstance(exc, OperationFailure): + self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] + else: + self.client._process_response({}, bwc.session) # type: ignore[arg-type] finally: bwc.start_time = datetime.datetime.now() return reply # type: ignore[return-value] @@ -429,7 +436,6 @@ def _execute_batch( """Executes a batch of bulkWrite server commands (ack).""" request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces) result = self.write_command(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type] - self.client._process_response(result, bwc.session) # type: ignore[arg-type] return result, to_send_ops, to_send_ns # type: ignore[return-value] def _process_results_cursor( diff --git a/test/mockupdb/test_cluster_time.py b/test/mockupdb/test_cluster_time.py index f3ab0a6c54..9794843175 100644 --- a/test/mockupdb/test_cluster_time.py +++ b/test/mockupdb/test_cluster_time.py @@ -29,21 +29,22 @@ from bson import Timestamp from pymongo import DeleteMany, InsertOne, MongoClient, UpdateOne +from pymongo.errors import OperationFailure pytestmark = pytest.mark.mockupdb class TestClusterTime(unittest.TestCase): - def cluster_time_conversation(self, callback, replies): + def cluster_time_conversation(self, callback, replies, max_wire_version=6): cluster_time = Timestamp(0, 0) server = MockupDB() - # First test all commands include $clusterTime with wire version 6. + # First test all commands include $clusterTime with max_wire_version. _ = server.autoresponds( "ismaster", { "minWireVersion": 0, - "maxWireVersion": 6, + "maxWireVersion": max_wire_version, "$clusterTime": {"clusterTime": cluster_time}, }, ) @@ -166,6 +167,30 @@ def test_monitor(self): request.reply(reply) client.close() + def test_collection_bulk_error(self): + def callback(client: MongoClient[dict]) -> None: + with self.assertRaises(OperationFailure): + client.db.collection.bulk_write([InsertOne({}), InsertOne({})]) + + self.cluster_time_conversation( + callback, + [{"ok": 0, "errmsg": "mock error"}], + ) + + def test_client_bulk_error(self): + def callback(client: MongoClient[dict]) -> None: + with self.assertRaises(OperationFailure): + client.bulk_write( + [ + InsertOne({}, namespace="db.collection"), + InsertOne({}, namespace="db.collection"), + ] + ) + + self.cluster_time_conversation( + callback, [{"ok": 0, "errmsg": "mock error"}], max_wire_version=25 + ) + if __name__ == "__main__": unittest.main()