1+ import collections
2+ from io import StringIO
13import os
2- import logging
3- import requests
4- from stitchclient .buffer import Buffer
4+ import time
55
6- from io import StringIO
6+ import requests
77from transit .writer import Writer
8- from transit .reader import Reader
9-
10- logger = logging .getLogger (__name__ )
118
# Batches are capped at 4 MiB of encoded Transit/JSON, and buffered data
# is flushed at least once a minute.
DEFAULT_MAX_BATCH_SIZE_BYTES = 4194304
DEFAULT_BATCH_DELAY_SECONDS = 60.0

# Hard upper bound on the number of records in one batch, independent of
# their encoded size.
MAX_MESSAGES_PER_BATCH = 20000

DEFAULT_STITCH_URL = 'https://api.stitchdata.com/v2/import/push'
1513
class MessageTooLargeException(Exception):
    """Raised when a single record, on its own, encodes to more bytes
    than the configured maximum batch size and therefore can never be
    sent."""
1616
17- class Client (object ):
def encode_transit(records):
    """Serialize *records* as Transit/JSON and return the UTF-8 bytes."""
    sink = StringIO()
    try:
        transit_writer = Writer(sink, "json")
        transit_writer.write(records)
        return sink.getvalue().encode('utf8')
    finally:
        sink.close()
23+
24+
def partition_batch(entries, max_batch_size_bytes):
    '''Split BufferEntry objects into one or more sendable batches.

    Returns a list of (encoded_body, callback_args) tuples, where each
    encoded_body is a Transit/JSON payload no larger than
    max_batch_size_bytes. The working window is halved whenever its
    encoding comes out too large; a MessageTooLargeException is raised
    if even a single record exceeds the limit.
    '''
    batches = []
    lo, hi = 0, len(entries)

    while lo < hi:
        window = entries[lo:hi]
        body = encode_transit([entry.value for entry in window])

        if len(body) > max_batch_size_bytes:
            if hi - lo == 1:
                # A lone record already blows the limit; splitting
                # further is impossible.
                raise MessageTooLargeException(
                    ('A single message is larger then the maximum batch size. ' +
                     'Message size: {}. Max batch size: {}')
                    .format(len(body), max_batch_size_bytes))
            # Too big: shrink the window to its first half and retry.
            hi = lo + (hi - lo) // 2
            continue

        batches.append((body, [entry.callback_arg for entry in window]))

        if hi == len(entries):
            # The window reached the end of the input; all done.
            break

        # Slide forward by the window size that just fit.
        lo, hi = hi, min(hi + len(window), len(entries))

    return batches
63+
# One buffered record plus the opaque value handed back to the delivery
# callback once that record has been sent.
BufferEntry = collections.namedtuple('BufferEntry', ['value', 'callback_arg'])

# Bookkeeping for one delivered batch; used to estimate the average
# encoded bytes per record.
BatchStatsEntry = collections.namedtuple('BatchStatsEntry',
                                         ['num_records', 'num_bytes'])
70+
71+ class Client (object ):
2072
2173 def __init__ (self ,
2274 client_id ,
@@ -25,20 +77,43 @@ def __init__(self,
2577 key_names = None ,
2678 callback_function = None ,
2779 stitch_url = DEFAULT_STITCH_URL ,
28- batch_size_bytes = DEFAULT_BATCH_SIZE_BYTES ,
29- batch_delay_millis = DEFAULT_BATCH_DELAY_MILLIS ):
80+ max_batch_size_bytes = DEFAULT_MAX_BATCH_SIZE_BYTES ,
81+ batch_delay_seconds = DEFAULT_BATCH_DELAY_SECONDS ):
3082
3183 assert isinstance (client_id , int ), 'client_id is not an integer: {}' .format (client_id ) # nopep8
3284
85+ self .max_messages_per_batch = MAX_MESSAGES_PER_BATCH
3386 self .client_id = client_id
3487 self .token = token
3588 self .table_name = table_name
3689 self .key_names = key_names
3790 self .stitch_url = stitch_url
38- self .batch_size_bytes = batch_size_bytes
39- self .batch_delay_millis = batch_delay_millis
91+ self .max_batch_size_bytes = max_batch_size_bytes
92+ self .batch_delay_seconds = batch_delay_seconds
4093 self .callback_function = callback_function
4194
95+ self ._buffer = []
96+
97+ # Stats we update as we send records
98+ self .time_last_batch_sent = time .time ()
99+ self .batch_stats = collections .deque (maxlen = 100 )
100+
101+ # We'll try using a big batch size to start out
102+ self .target_messages_per_batch = self .max_messages_per_batch
103+
104+ def _add_message (self , message , callback_arg ):
105+ self ._buffer .append (BufferEntry (value = message ,
106+ callback_arg = callback_arg ))
107+
108+ def moving_average_bytes_per_record (self ):
109+ num_records = 0
110+ num_bytes = 0
111+ for stats in self .batch_stats :
112+ num_records += stats .num_records
113+ num_bytes += stats .num_bytes
114+
115+ return num_bytes // num_records
116+
42117 def push (self , message , callback_arg = None ):
43118 """message should be a dict recognized by the Stitch Import API.
44119
@@ -51,60 +126,69 @@ def push(self, message, callback_arg=None):
51126 message ['client_id' ] = self .client_id
52127 message .setdefault ('table_name' , self .table_name )
53128
54- with StringIO () as s :
55- writer = Writer (s , "json" )
56- writer .write (message )
57- self ._buffer .put (s .getvalue (), callback_arg )
129+ self ._add_message (message , callback_arg )
58130
59- batch = self ._buffer .take (
60- self .batch_size_bytes , self .batch_delay_millis )
61- if batch is not None :
131+ batch = self ._take_batch (self .target_messages_per_batch )
132+ if batch :
62133 self ._send_batch (batch )
63134
64- def _serialize_entries (self , entries ):
65- deserialized_entries = []
66- for entry in entries :
67- reader = Reader ("json" )
68- deserialized_entries .append (reader .read (StringIO (entry .value )))
69135
70- with StringIO () as s :
71- writer = Writer (s , "json" )
72- writer .write (deserialized_entries )
73- return s .getvalue ()
136+ def _take_batch (self , min_records ):
137+ '''If we have enough data to build a batch, returns all the data in the
138+ buffer and then clears the buffer.'''
139+
140+ if not self ._buffer :
141+ return []
142+
143+ enough_messages = len (self ._buffer ) >= min_records
144+ enough_time = time .time () - self .time_last_batch_sent >= self .batch_delay_seconds
145+ ready = enough_messages or enough_time
146+
147+ if not ready :
148+ return []
149+
150+ result = list (self ._buffer )
151+ self ._buffer .clear ()
152+ return result
153+
154+ def _send_batch (self , batch ):
155+ for body , callback_args in partition_batch (batch , self .max_batch_size_bytes ):
156+ self ._send (body , callback_args )
157+
158+ self .target_messages_per_batch = min (self .max_messages_per_batch ,
159+ 0.8 * (self .max_batch_size_bytes / self .moving_average_bytes_per_record ()))
160+
74161
75162 def _stitch_request (self , body ):
76163 headers = {'Authorization' : 'Bearer {}' .format (self .token ),
77164 'Content-Type' : 'application/transit+json' }
78165 return requests .post (self .stitch_url , headers = headers , data = body )
79166
80- def _send_batch (self , batch ):
81- logger .debug ("Sending batch of %s entries" , len (batch ))
82- body = self ._serialize_entries (batch ).encode ('utf8' )
167+
168+ def _send (self , body , callback_args ):
83169 response = self ._stitch_request (body )
84170
85171 if response .status_code < 300 :
86172 if self .callback_function is not None :
87- self .callback_function ([ x . callback_arg for x in batch ] )
173+ self .callback_function (callback_args )
88174 else :
89175 raise RuntimeError ("Error sending data to the Stitch API. {0.status_code} - {0.content}" # nopep8
90176 .format (response ))
177+ self .time_last_batch_sent = time .time ()
178+ self .batch_stats .append (BatchStatsEntry (len (callback_args ), len (body )))
91179
92180 def flush (self ):
93- while True :
94- batch = self ._buffer .take (0 , 0 )
95- if batch is None :
96- return
97-
98- self ._send_batch (batch )
181+ batch = self ._take_batch (0 )
182+ self ._send_batch (batch )
99183
100184 def __enter__ (self ):
101185 return self
102186
103187 def __exit__ (self , exception_type , exception_value , traceback ):
104188 self .flush ()
105189
190+
106191if __name__ == "__main__" :
107- logging .basicConfig (level = logging .DEBUG )
108192
109193 with Client (int (os .environ ['STITCH_CLIENT_ID' ]),
110194 os .environ ['STITCH_TOKEN' ],
0 commit comments