Skip to content

Commit 6cca1ec

Browse files
Add a max_size argument to Pub / Sub Batch. (#3157)
1 parent 45c6a0a commit 6cca1ec

File tree

2 files changed

+98
-30
lines changed

2 files changed

+98
-30
lines changed

pubsub/google/cloud/pubsub/topic.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@
1414

1515
"""Define API Topics."""
1616

17+
import base64
18+
import json
1719
import time
1820

1921
from google.cloud._helpers import _datetime_to_rfc3339
2022
from google.cloud._helpers import _NOW
23+
from google.cloud._helpers import _to_bytes
2124
from google.cloud.exceptions import NotFound
2225
from google.cloud.pubsub._helpers import topic_name_from_path
2326
from google.cloud.pubsub.iam import Policy
@@ -255,7 +258,7 @@ def publish(self, message, client=None, **attrs):
255258
message_ids = api.topic_publish(self.full_name, [message_data])
256259
return message_ids[0]
257260

258-
def batch(self, client=None):
261+
def batch(self, client=None, **kwargs):
259262
"""Return a batch to use as a context manager.
260263
261264
Example:
@@ -275,11 +278,15 @@ def batch(self, client=None):
275278
:param client: the client to use. If not passed, falls back to the
276279
``client`` stored on the current topic.
277280
281+
:type kwargs: dict
282+
:param kwargs: Keyword arguments passed to the
283+
:class:`~google.cloud.pubsub.topic.Batch` constructor.
284+
278285
:rtype: :class:`Batch`
279286
:returns: A batch to use as a context manager.
280287
"""
281288
client = self._require_client(client)
282-
return Batch(self, client)
289+
return Batch(self, client, **kwargs)
283290

284291
def list_subscriptions(self, page_size=None, page_token=None, client=None):
285292
"""List subscriptions for the project associated with this client.
@@ -426,11 +433,16 @@ class Batch(object):
426433
before automatically committing. Defaults to infinity
427434
(off).
428435
:type max_messages: float
436+
437+
:param max_size: The maximum size that the serialized messages can be
438+
before automatically committing. Defaults to 9 MB
439+
(slightly less than the API limit).
440+
:type max_size: int
429441
"""
430442
_INFINITY = float('inf')
431443

432444
def __init__(self, topic, client, max_interval=_INFINITY,
433-
max_messages=_INFINITY):
445+
max_messages=_INFINITY, max_size=1024 * 1024 * 9):
434446
self.topic = topic
435447
self.messages = []
436448
self.message_ids = []
@@ -440,9 +452,12 @@ def __init__(self, topic, client, max_interval=_INFINITY,
440452
# is exceeded, then the .publish() method will imply a commit.
441453
self._max_interval = max_interval
442454
self._max_messages = max_messages
455+
self._max_size = max_size
443456

444-
# Set the initial starting timestamp (used against the interval).
457+
# Set the initial starting timestamp (used against the interval)
458+
# and initial size.
445459
self._start_timestamp = time.time()
460+
self._current_size = 0
446461

447462
def __enter__(self):
448463
return self
@@ -464,16 +479,24 @@ def publish(self, message, **attrs):
464479
:param attrs: key-value pairs to send as message attributes
465480
"""
466481
self.topic._timestamp_message(attrs)
467-
self.messages.append(
468-
{'data': message,
469-
'attributes': attrs})
482+
483+
# Append the message to the list of messages.
484+
item = {'attributes': attrs, 'data': message}
485+
self.messages.append(item)
486+
487+
# Determine the approximate size of the message, and increment
488+
# the current batch size appropriately.
489+
encoded = base64.b64encode(_to_bytes(message))
490+
encoded += base64.b64encode(
491+
json.dumps(attrs, ensure_ascii=False).encode('utf8'),
492+
)
493+
self._current_size += len(encoded)
470494

471495
# If too much time has elapsed since the first message
472496
# was added, autocommit.
473497
now = time.time()
474498
if now - self._start_timestamp > self._max_interval:
475499
self.commit()
476-
self._start_timestamp = now
477500
return
478501

479502
# If the number of messages on the list is greater than the
@@ -482,6 +505,11 @@ def publish(self, message, **attrs):
482505
self.commit()
483506
return
484507

508+
# If we have reached the max size, autocommit.
509+
if self._current_size >= self._max_size:
510+
self.commit()
511+
return
512+
485513
def commit(self, client=None):
486514
"""Send saved messages as a single API call.
487515
@@ -499,3 +527,5 @@ def commit(self, client=None):
499527
message_ids = api.topic_publish(self.topic.full_name, self.messages[:])
500528
self.message_ids.extend(message_ids)
501529
del self.messages[:]
530+
self._start_timestamp = time.time()
531+
self._current_size = 0

pubsub/unit_tests/test_topic.py

Lines changed: 60 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -779,10 +779,35 @@ def test_context_mgr_failure(self):
779779
self.assertEqual(list(batch.messages), [MESSAGE1, MESSAGE2])
780780
self.assertEqual(getattr(api, '_topic_published', self), self)
781781

782+
def test_batch_messages(self):
783+
# Establish that a batch actually batches messages in the expected
784+
# way.
785+
client = _Client(project='PROJECT')
786+
topic = _Topic(name='TOPIC')
787+
788+
# Track commits, but do not perform them.
789+
Batch = self._get_target_class()
790+
with mock.patch.object(Batch, 'commit') as commit:
791+
with self._make_one(topic, client=client) as batch:
792+
self.assertIsInstance(batch, Batch)
793+
794+
# Publish four messages and establish that the batch does
795+
# not commit.
796+
for i in range(0, 4):
797+
batch.publish('Batch message %d.' % (i,))
798+
commit.assert_not_called()
799+
800+
# Check the contents of the batch.
801+
self.assertEqual(batch.messages, [
802+
{'data': 'Batch message 0.', 'attributes': {}},
803+
{'data': 'Batch message 1.', 'attributes': {}},
804+
{'data': 'Batch message 2.', 'attributes': {}},
805+
{'data': 'Batch message 3.', 'attributes': {}},
806+
])
807+
782808
def test_message_count_autocommit(self):
783-
"""Establish that if the batch is assigned to take a maximum
784-
number of messages, that it commits when it reaches that maximum.
785-
"""
809+
# Establish that if the batch is assigned to take a maximum
810+
# number of messages, that it commits when it reaches that maximum.
786811
client = _Client(project='PROJECT')
787812
topic = _Topic(name='TOPIC')
788813

@@ -795,17 +820,11 @@ def test_message_count_autocommit(self):
795820
# Publish four messages and establish that the batch does
796821
# not commit.
797822
for i in range(0, 4):
798-
batch.publish({
799-
'attributes': {},
800-
'data': 'Batch message %d.' % (i,),
801-
})
823+
batch.publish('Batch message %d.' % (i,))
802824
commit.assert_not_called()
803825

804826
# Publish a fifth message and observe the commit.
805-
batch.publish({
806-
'attributes': {},
807-
'data': 'The final call to trigger a commit!',
808-
})
827+
batch.publish('The final call to trigger a commit!')
809828
commit.assert_called_once_with()
810829

811830
# There should be a second commit after the context manager
@@ -814,9 +833,8 @@ def test_message_count_autocommit(self):
814833

815834
@mock.patch('time.time')
816835
def test_message_time_autocommit(self, mock_time):
817-
"""Establish that if the batch is sufficiently old, that it commits
818-
the next time it receives a publish.
819-
"""
836+
# Establish that if the batch is sufficiently old, that it commits
837+
# the next time it receives a publish.
820838
client = _Client(project='PROJECT')
821839
topic = _Topic(name='TOPIC')
822840

@@ -830,20 +848,40 @@ def test_message_time_autocommit(self, mock_time):
830848
# Publish some messages and establish that the batch does
831849
# not commit.
832850
for i in range(0, 10):
833-
batch.publish({
834-
'attributes': {},
835-
'data': 'Batch message %d.' % (i,),
836-
})
851+
batch.publish('Batch message %d.' % (i,))
837852
commit.assert_not_called()
838853

839854
# Move time ahead so that this batch is too old.
840855
mock_time.return_value = 10.0
841856

842857
# Publish another message and observe the commit.
843-
batch.publish({
844-
'attributes': {},
845-
'data': 'The final call to trigger a commit!',
846-
})
858+
batch.publish('The final call to trigger a commit!')
859+
commit.assert_called_once_with()
860+
861+
# There should be a second commit after the context manager
862+
# exits.
863+
self.assertEqual(commit.call_count, 2)
864+
865+
def test_message_size_autocommit(self):
866+
# Establish that if the batch is sufficiently large, that it
867+
# auto-commits.
868+
client = _Client(project='PROJECT')
869+
topic = _Topic(name='TOPIC')
870+
871+
# Track commits, but do not perform them.
872+
Batch = self._get_target_class()
873+
with mock.patch.object(Batch, 'commit') as commit:
874+
with self._make_one(topic, client=client, max_size=100) as batch:
875+
self.assertIsInstance(batch, Batch)
876+
877+
# Publish a short (< 100 bytes) message and establish that
878+
# the batch does not commit.
879+
batch.publish(b'foo')
880+
commit.assert_not_called()
881+
882+
# Publish another message and observe the commit.
883+
batch.publish(u'The final call to trigger a commit, because '
884+
u'this message is sufficiently long.')
847885
commit.assert_called_once_with()
848886

849887
# There should be a second commit after the context manager

0 commit comments

Comments
 (0)