Commit d1e1013

Only encode once (#5)
* Only encode once
* WIP
* Working on tests
* Working on tests
* checkpoint
* wrapping up
* max 20000 records
* Move test file
* Remove unused file
* pylint
1 parent b2edbf8 commit d1e1013

4 files changed: 319 additions & 205 deletions

stitchclient/buffer.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

stitchclient/client.py

Lines changed: 126 additions & 42 deletions
@@ -1,22 +1,74 @@
+import collections
+from io import StringIO
 import os
-import logging
-import requests
-from stitchclient.buffer import Buffer
+import time
 
-from io import StringIO
+import requests
 from transit.writer import Writer
-from transit.reader import Reader
-
-logger = logging.getLogger(__name__)
 
-DEFAULT_BATCH_SIZE_BYTES = 4194304
-DEFAULT_BATCH_DELAY_MILLIS = 60000
+DEFAULT_MAX_BATCH_SIZE_BYTES = 4194304
+DEFAULT_BATCH_DELAY_SECONDS = 60.0
+MAX_MESSAGES_PER_BATCH = 20000
 DEFAULT_STITCH_URL = 'https://api.stitchdata.com/v2/import/push'
 
+class MessageTooLargeException(Exception):
+    pass
 
-class Client(object):
+def encode_transit(records):
+    '''Returns the records serialized as Transit/json in utf8'''
+    with StringIO() as buf:
+        writer = Writer(buf, "json")
+        writer.write(records)
+        return buf.getvalue().encode('utf8')
+
+
+def partition_batch(entries, max_batch_size_bytes):
+
+    start = 0
+    end = len(entries)
+    result = []
+    while start < end:
+
+        partitioned_entries = entries[start : end]
+        records = [e.value for e in partitioned_entries]
+        encoded = encode_transit(records)
 
-    _buffer = Buffer()
+        if len(encoded) <= max_batch_size_bytes:
+            result.append((encoded, [e.callback_arg for e in partitioned_entries]))
+
+            # If end is less than length of entries we're not done yet.
+            # Advance start to end, and advance end by the number of
+            # records we just put in the batch.
+            if end < len(entries):
+                start = end
+                end = min(end + len(records), len(entries))
+
+            # If end is at the end of the input entries, we're done.
+            else:
+                break
+
+        # The size of the encoded records in our range is too large. If we
+        # have more than one record in our range, cut the range in half
+        # and try again.
+        elif end - start > 1:
+            end = start + (end - start) // 2
+
+        else:
+            raise MessageTooLargeException(
+                ('A single message is larger then the maximum batch size. ' +
+                 'Message size: {}. Max batch size: {}')
+                .format(len(encoded), max_batch_size_bytes))
+
+    return result
+
+BufferEntry = collections.namedtuple(
+    'BufferEntry',
+    ['value', 'callback_arg'])
+
+BatchStatsEntry = collections.namedtuple(
+    'BatchStatsEntry', ['num_records', 'num_bytes'])
+
+class Client(object):
 
     def __init__(self,
                  client_id,
@@ -25,20 +77,43 @@ def __init__(self,
                  key_names=None,
                  callback_function=None,
                  stitch_url=DEFAULT_STITCH_URL,
-                 batch_size_bytes=DEFAULT_BATCH_SIZE_BYTES,
-                 batch_delay_millis=DEFAULT_BATCH_DELAY_MILLIS):
+                 max_batch_size_bytes=DEFAULT_MAX_BATCH_SIZE_BYTES,
+                 batch_delay_seconds=DEFAULT_BATCH_DELAY_SECONDS):
 
         assert isinstance(client_id, int), 'client_id is not an integer: {}'.format(client_id) # nopep8
 
+        self.max_messages_per_batch = MAX_MESSAGES_PER_BATCH
         self.client_id = client_id
         self.token = token
         self.table_name = table_name
         self.key_names = key_names
         self.stitch_url = stitch_url
-        self.batch_size_bytes = batch_size_bytes
-        self.batch_delay_millis = batch_delay_millis
+        self.max_batch_size_bytes = max_batch_size_bytes
+        self.batch_delay_seconds = batch_delay_seconds
         self.callback_function = callback_function
 
+        self._buffer = []
+
+        # Stats we update as we send records
+        self.time_last_batch_sent = time.time()
+        self.batch_stats = collections.deque(maxlen=100)
+
+        # We'll try using a big batch size to start out
+        self.target_messages_per_batch = self.max_messages_per_batch
+
+    def _add_message(self, message, callback_arg):
+        self._buffer.append(BufferEntry(value=message,
+                                        callback_arg=callback_arg))
+
+    def moving_average_bytes_per_record(self):
+        num_records = 0
+        num_bytes = 0
+        for stats in self.batch_stats:
+            num_records += stats.num_records
+            num_bytes += stats.num_bytes
+
+        return num_bytes // num_records
+
     def push(self, message, callback_arg=None):
         """message should be a dict recognized by the Stitch Import API.
 
@@ -51,60 +126,69 @@ def push(self, message, callback_arg=None):
         message['client_id'] = self.client_id
         message.setdefault('table_name', self.table_name)
 
-        with StringIO() as s:
-            writer = Writer(s, "json")
-            writer.write(message)
-            self._buffer.put(s.getvalue(), callback_arg)
+        self._add_message(message, callback_arg)
 
-        batch = self._buffer.take(
-            self.batch_size_bytes, self.batch_delay_millis)
-        if batch is not None:
+        batch = self._take_batch(self.target_messages_per_batch)
+        if batch:
             self._send_batch(batch)
 
-    def _serialize_entries(self, entries):
-        deserialized_entries = []
-        for entry in entries:
-            reader = Reader("json")
-            deserialized_entries.append(reader.read(StringIO(entry.value)))
 
-        with StringIO() as s:
-            writer = Writer(s, "json")
-            writer.write(deserialized_entries)
-            return s.getvalue()
+    def _take_batch(self, min_records):
+        '''If we have enough data to build a batch, returns all the data in the
+        buffer and then clears the buffer.'''
+
+        if not self._buffer:
+            return []
+
+        enough_messages = len(self._buffer) >= min_records
+        enough_time = time.time() - self.time_last_batch_sent >= self.batch_delay_seconds
+        ready = enough_messages or enough_time
+
+        if not ready:
+            return []
+
+        result = list(self._buffer)
+        self._buffer.clear()
+        return result
+
+    def _send_batch(self, batch):
+        for body, callback_args in partition_batch(batch, self.max_batch_size_bytes):
+            self._send(body, callback_args)
+
+        self.target_messages_per_batch = min(self.max_messages_per_batch,
+            0.8 * (self.max_batch_size_bytes / self.moving_average_bytes_per_record()))
+
 
     def _stitch_request(self, body):
         headers = {'Authorization': 'Bearer {}'.format(self.token),
                    'Content-Type': 'application/transit+json'}
         return requests.post(self.stitch_url, headers=headers, data=body)
 
-    def _send_batch(self, batch):
-        logger.debug("Sending batch of %s entries", len(batch))
-        body = self._serialize_entries(batch).encode('utf8')
+
+    def _send(self, body, callback_args):
         response = self._stitch_request(body)
 
         if response.status_code < 300:
            if self.callback_function is not None:
-                self.callback_function([x.callback_arg for x in batch])
+                self.callback_function(callback_args)
         else:
            raise RuntimeError("Error sending data to the Stitch API. {0.status_code} - {0.content}" # nopep8
                               .format(response))
+        self.time_last_batch_sent = time.time()
+        self.batch_stats.append(BatchStatsEntry(len(callback_args), len(body)))
 
     def flush(self):
-        while True:
-            batch = self._buffer.take(0, 0)
-            if batch is None:
-                return
-
-            self._send_batch(batch)
+        batch = self._take_batch(0)
+        self._send_batch(batch)
 
     def __enter__(self):
         return self
 
     def __exit__(self, exception_type, exception_value, traceback):
        self.flush()
 
+
 if __name__ == "__main__":
-    logging.basicConfig(level=logging.DEBUG)
 
     with Client(int(os.environ['STITCH_CLIENT_ID']),
                 os.environ['STITCH_TOKEN'],
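
The core of the change: before this commit, push() wrote each message to Transit as it arrived, and the batch sender then read those strings back and re-serialized the whole batch, so every record was encoded twice. Now raw dicts are buffered and encode_transit runs once per batch attempt, with partition_batch halving its window whenever an encoded payload exceeds the byte cap. A self-contained sketch of that halving strategy, with json standing in for transit so it runs without the package installed; the names here are illustrative, not from the commit:

    import collections
    import json

    BufferEntry = collections.namedtuple('BufferEntry', ['value', 'callback_arg'])

    def encode(records):
        # Stand-in for encode_transit: serialize the whole window at once.
        return json.dumps(records).encode('utf8')

    def partition(entries, max_bytes):
        start, end, result = 0, len(entries), []
        while start < end:
            chunk = entries[start:end]
            encoded = encode([e.value for e in chunk])
            if len(encoded) <= max_bytes:
                # Window fits: emit it and slide forward by the same width.
                result.append((encoded, [e.callback_arg for e in chunk]))
                if end == len(entries):
                    break
                start, end = end, min(end + len(chunk), len(entries))
            elif end - start > 1:
                # Too big: halve the window and re-encode.
                end = start + (end - start) // 2
            else:
                # Mirrors MessageTooLargeException for a single oversized record.
                raise ValueError('single message exceeds max batch size')
        return result

    entries = [BufferEntry({'i': i, 'pad': 'x' * 50}, i) for i in range(10)]
    for body, callback_args in partition(entries, max_bytes=200):
        print(len(body), callback_args)

Run as-is, this splits ten ~70-byte records into five two-record bodies, each under the 200-byte cap.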

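The batch size is also adaptive: every _send appends a BatchStatsEntry to a deque capped at 100 entries, and _send_batch re-derives target_messages_per_batch from the moving average of bytes per record, aiming the next batch at roughly 80% of the byte cap. A worked example, where the 350-byte average is an assumed figure, not one from this commit:

    # Hypothetical numbers illustrating the target-size update in _send_batch.
    MAX_BATCH_SIZE_BYTES = 4194304    # DEFAULT_MAX_BATCH_SIZE_BYTES
    MAX_MESSAGES_PER_BATCH = 20000
    avg_bytes_per_record = 350        # assumed moving average, not measured

    target = min(MAX_MESSAGES_PER_BATCH,
                 0.8 * (MAX_BATCH_SIZE_BYTES / avg_bytes_per_record))
    print(round(target))  # 9587: undershooting the 4 MiB cap means
                          # partition_batch usually accepts the first encoding
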
tests/test_buffer.py

Lines changed: 0 additions & 101 deletions
This file was deleted.
