Skip to content

Commit a92d7cc

Browse files
authored
Add tests (#3)
* Add tests * Move buffer into its own module * Move test * Pep8 * Pep8
1 parent 7a2cf9e commit a92d7cc

File tree

4 files changed

+176
-53
lines changed

4 files changed

+176
-53
lines changed

stitchclient/buffer.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import time
2+
import collections
3+
4+
from collections import deque
5+
6+
BufferEntry = collections.namedtuple(
    'BufferEntry',
    'timestamp value callback_arg')

# Largest allowed size, in UTF-8 bytes, of one batch including the
# surrounding "[" and "]" and the commas between records.
MAX_BATCH_SIZE_BYTES = 4194304
# Queue length at which take() releases a batch regardless of the caller's
# byte and delay thresholds.
MAX_MESSAGES_PER_BATCH = 10000


class Buffer(object):
    """FIFO queue of string records that are handed out in bounded batches.

    Records are added with put() and removed, oldest first, with take().
    All sizes are measured on the UTF-8 encoding of the record.
    """

    def __init__(self):
        self._queue = deque()
        # Sum of the encoded sizes of the records currently queued.
        self._available_bytes = 0

    def put(self, value, callback_arg):
        """Append one record to the queue.

        Args:
            value: the record, as a string.
            callback_arg: arbitrary value stored untouched on the entry.

        Raises:
            ValueError: if the encoded record could not fit in a batch on
                its own.
        """
        # We need two extra bytes for the [ and ] wrapping the record.
        max_len = MAX_BATCH_SIZE_BYTES - 2

        # Measure encoded bytes, not characters: the limit is a byte limit
        # and take() also works in bytes, so a character count would let
        # multi-byte records through that can never be dequeued.
        value_size = len(value.encode("utf8"))
        if value_size > max_len:
            raise ValueError(
                "Can't accept a record larger than {} bytes".format(max_len))

        # Timestamp is in milliseconds since the epoch.
        self._queue.append(BufferEntry(timestamp=time.time() * 1000,
                                       value=value,
                                       callback_arg=callback_arg))
        self._available_bytes += value_size

    def take(self, batch_size_bytes, batch_delay_millis):
        """Remove and return the next batch, if one is ready.

        A batch is ready when the queued bytes reach batch_size_bytes, the
        queue holds at least MAX_MESSAGES_PER_BATCH records, or the oldest
        record has waited at least batch_delay_millis.

        Args:
            batch_size_bytes: queued-byte threshold that triggers a batch.
            batch_delay_millis: age of the oldest record, in milliseconds,
                that triggers a batch.

        Returns:
            None if the queue is empty or no trigger is met. Otherwise a
            list of BufferEntry, oldest first, holding as many records as
            fit in MAX_BATCH_SIZE_BYTES; the rest stay queued. The list is
            never empty because put() rejects records that cannot fit in a
            batch alone.
        """
        if not self._queue:
            return None

        now = time.time() * 1000
        oldest = self._queue[0].timestamp
        enough_bytes = self._available_bytes >= batch_size_bytes
        enough_messages = len(self._queue) >= MAX_MESSAGES_PER_BATCH
        enough_time = now - oldest >= batch_delay_millis

        if not (enough_bytes or enough_messages or enough_time):
            return None

        entries = []
        # Start at two for the enclosing brackets.
        size = 2

        while self._queue:
            entry_size = len(self._queue[0].value.encode("utf8"))

            # size already includes one comma per entry taken so far, so
            # size + entry_size is the exact batch size with this entry
            # added. The bound is inclusive so that a record of the largest
            # size put() accepts can still leave the queue.
            if size + entry_size > MAX_BATCH_SIZE_BYTES:
                break

            entries.append(self._queue.popleft())
            # add one for the comma that will be needed to link entries
            # together
            size += entry_size + 1
            self._available_bytes -= entry_size

        return entries

stitchclient/client.py

Lines changed: 13 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
import os
2-
import time
32
import logging
4-
import collections
5-
63
import requests
4+
from .buffer import Buffer
75

8-
from collections import deque
96
from io import StringIO
107
from transit.writer import Writer
118
from transit.reader import Reader
@@ -16,48 +13,9 @@
1613
DEFAULT_BATCH_DELAY_MILLIS = 60000
1714
DEFAULT_STITCH_URL = 'https://api.stitchdata.com/v2/import/push'
1815

19-
BufferEntry = collections.namedtuple('BufferEntry', 'timestamp value callback_arg')
2016

2117
class Client(object):
2218

23-
class Buffer(object):
24-
25-
MAX_BATCH_SIZE_BYTES = 4194304
26-
MAX_MESSAGES_PER_BATCH = 10000
27-
28-
_queue = deque()
29-
_available_bytes = 0
30-
31-
def put(self, value, callback_arg):
32-
self._queue.append(BufferEntry(timestamp=time.time()*1000, value=value, callback_arg=callback_arg))
33-
self._available_bytes += len(value.encode("utf8"))
34-
35-
def take(self, batch_size_bytes, batch_delay_millis):
36-
if len(self._queue) == 0:
37-
return None
38-
39-
ready = self._available_bytes >= batch_size_bytes or \
40-
len(self._queue) >= self.MAX_MESSAGES_PER_BATCH or \
41-
time.time()*1000 - self._queue[0].timestamp >= batch_delay_millis
42-
43-
if not ready:
44-
return None
45-
46-
entries = []
47-
size = 2
48-
49-
while len(self._queue) > 0 and \
50-
size + len(self._queue[0].value.encode("utf8")) < self.MAX_BATCH_SIZE_BYTES:
51-
entry = self._queue.popleft()
52-
53-
# add one for the comma that will be needed to link entries together
54-
entry_size = len(entry.value.encode("utf8"))
55-
size += entry_size + 1
56-
self._available_bytes -= entry_size
57-
entries.append(entry)
58-
59-
return entries
60-
6119
_buffer = Buffer()
6220

6321
def __init__(self,
@@ -70,7 +28,8 @@ def __init__(self,
7028
batch_size_bytes=DEFAULT_BATCH_SIZE_BYTES,
7129
batch_delay_millis=DEFAULT_BATCH_DELAY_MILLIS):
7230

73-
assert isinstance(client_id, int), 'client_id is not an integer: {}'.format(client_id)
31+
assert(isinstance(client_id, int),
32+
'client_id is not an integer: {}'.format(client_id))
7433

7534
self.client_id = client_id
7635
self.token = token
@@ -81,7 +40,7 @@ def __init__(self,
8140
self.batch_delay_millis = batch_delay_millis
8241
self.callback_function = callback_function
8342

84-
def push(self, message, callback_arg = None):
43+
def push(self, message, callback_arg=None):
8544
"""
8645
message must be a dict with at least these keys:
8746
action, table_name, key_names, sequence, data
@@ -91,10 +50,8 @@ def push(self, message, callback_arg = None):
9150

9251
if message['action'] == 'upsert':
9352
message.setdefault('key_names', self.key_names)
94-
elif message['action'] == 'switch_view':
95-
pass
9653
else:
97-
raise ValueError('Message action property must be one of: "upsert", "switch_view"')
54+
raise ValueError('Message action property must be "upsert"')
9855

9956
message['client_id'] = self.client_id
10057
message.setdefault('table_name', self.table_name)
@@ -104,7 +61,8 @@ def push(self, message, callback_arg = None):
10461
writer.write(message)
10562
self._buffer.put(s.getvalue(), callback_arg)
10663

107-
batch = self._buffer.take(self.batch_size_bytes, self.batch_delay_millis)
64+
batch = self._buffer.take(
65+
self.batch_size_bytes, self.batch_delay_millis)
10866
if batch is not None:
10967
self._send_batch(batch)
11068

@@ -133,12 +91,12 @@ def _send_batch(self, batch):
13391
if self.callback_function is not None:
13492
self.callback_function([x.callback_arg for x in batch])
13593
else:
136-
raise RuntimeError("Error sending data to the Stitch API. {0.status_code} - {0.content}"
94+
raise RuntimeError("Error sending data to the Stitch API. {0.status_code} - {0.content}" # nopep8
13795
.format(response))
13896

13997
def flush(self):
14098
while True:
141-
batch = self._buffer.take(0,0)
99+
batch = self._buffer.take(0, 0)
142100
if batch is None:
143101
return
144102

@@ -153,8 +111,10 @@ def __exit__(self, exception_type, exception_value, traceback):
153111
if __name__ == "__main__":
154112
logging.basicConfig(level=logging.DEBUG)
155113

156-
with Client(int(os.environ['STITCH_CLIENT_ID']), os.environ['STITCH_TOKEN'], callback_function=print) as c:
157-
for i in range(1,10):
114+
with Client(int(os.environ['STITCH_CLIENT_ID']),
115+
os.environ['STITCH_TOKEN'],
116+
callback_function=print) as c:
117+
for i in range(1, 10):
158118
c.push({'action': 'upsert',
159119
'table_name': 'test_table',
160120
'key_names': ['id'],

tests/__init__.py

Whitespace-only changes.

tests/test_buffer.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import unittest
2+
import stitchclient.client
3+
import time
4+
from stitchclient.buffer import MAX_BATCH_SIZE_BYTES
5+
6+
class TestTargetStitch(unittest.TestCase):
    """Placeholder for tests of the client's push path."""

    # Fixture values intended for constructing a client in test_push.
    client_id = 1
    token = 'asdf'
    table_name = 'users'
    key_names = ['id']

    # Reported as skipped rather than passing vacuously, so the missing
    # coverage stays visible in test output.
    @unittest.skip("push test not implemented yet")
    def test_push(self):
        # TODO: construct a client from the fixture values above, e.g.
        #   stitchclient.Client(
        #       client_id=self.client_id,
        #       token=self.token,
        #       table_name=self.table_name,
        #       key_names=self.key_names,
        #       batch_size_bytes=1000,
        #       batch_delay_millis=100000)
        # and assert on what push() sends.
        pass
22+
23+
24+
# Record fixtures: one far below the batch limit, one sized so that two fit
# in a batch but three do not, and one larger than a whole batch.
tiny_record = 'apple'
big_record = 'a' * 1500000
huge_record = 'a' * 5000000


class TestBuffer(unittest.TestCase):
    """Tests for the batching rules of Buffer.put() and Buffer.take()."""

    def test_single_record_available_immediately(self):
        buf = stitchclient.client.Buffer()
        buf.put(tiny_record, None)
        # Zero thresholds make any queued record eligible at once.
        self.assertEqual(buf.take(0, 0)[0].value, tiny_record)

    def test_withhold_until_bytes_available(self):
        buf = stitchclient.client.Buffer()
        # Two and a half records: only the third put reaches the threshold.
        batch_size_bytes = int(len(tiny_record) * 2.5)
        # Large enough that the delay trigger cannot fire during the test.
        batch_delay_millis = 1000000000

        for _ in range(2):
            buf.put(tiny_record, None)
            self.assertIsNone(buf.take(batch_size_bytes, batch_delay_millis))

        buf.put(tiny_record, None)
        res = buf.take(batch_size_bytes, batch_delay_millis)
        self.assertEqual([x.value for x in res], ['apple', 'apple', 'apple'])

    def test_buffer_empty_after_batch(self):
        buf = stitchclient.client.Buffer()
        for _ in range(3):
            buf.put(tiny_record, None)

        self.assertIsNotNone(buf.take(0, 0))
        self.assertIsNone(buf.take(0, 0))

    def test_does_not_exceed_max_batch_size(self):
        buf = stitchclient.client.Buffer()
        for _ in range(3):
            buf.put(big_record, None)

        b1 = buf.take(0, 0)
        b2 = buf.take(0, 0)
        b3 = buf.take(0, 0)

        # Measure encoded bytes, the unit the batch limit is expressed in.
        b1len = sum(len(x.value.encode("utf8")) for x in b1)
        b2len = sum(len(x.value.encode("utf8")) for x in b2)

        self.assertLess(b1len, MAX_BATCH_SIZE_BYTES)
        self.assertLess(b2len, b1len)
        self.assertIsNone(b3)

    def test_cant_put_record_larger_than_max_message_size(self):
        buf = stitchclient.client.Buffer()
        with self.assertRaises(ValueError):
            buf.put(huge_record, None)

    def test_trigger_batch_at_10k_messages(self):
        buf = stitchclient.client.Buffer()
        # Byte and delay thresholds are set high enough that only the
        # message-count trigger can release the batch.
        for _ in range(9999):
            buf.put(tiny_record, None)
            self.assertIsNone(buf.take(MAX_BATCH_SIZE_BYTES, 60000))

        buf.put(tiny_record, None)
        self.assertIsNotNone(buf.take(MAX_BATCH_SIZE_BYTES, 60000))

0 commit comments

Comments
 (0)