 import io
 import itertools
 import logging
+import math
 import socket
 from contextlib import contextmanager
 from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Set, Tuple, Union, cast
 
 import boto3
-from botocore.exceptions import ClientError
+from botocore.exceptions import ClientError, ReadTimeoutError
 
 from awswrangler import _utils, exceptions
 from awswrangler._config import apply_configs
 from awswrangler.s3._describe import size_objects
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
-_S3_RETRYABLE_ERRORS: Tuple[Any, Any] = (socket.timeout, ConnectionError)
+_S3_RETRYABLE_ERRORS: Tuple[Any, Any, Any] = (socket.timeout, ConnectionError, ReadTimeoutError)
 
 _MIN_WRITE_BLOCK: int = 5_242_880  # 5 MB (5 * 2**20)
 _MIN_PARALLEL_READ_BLOCK: int = 5_242_880  # 5 MB (5 * 2**20)
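
Adding botocore's ReadTimeoutError to _S3_RETRYABLE_ERRORS means stalled GET requests are treated as transient, alongside socket timeouts and connection resets. As a rough illustration of how such a tuple is typically consumed (a hedged sketch with an invented helper name and backoff policy, not this module's actual retry code):

import socket
import time

from botocore.exceptions import ReadTimeoutError

_S3_RETRYABLE_ERRORS = (socket.timeout, ConnectionError, ReadTimeoutError)


def call_with_retries(func, *args, attempts=3, **kwargs):
    # Illustrative helper: retry transient S3 errors with a simple linear backoff.
    for attempt in range(attempts):
        try:
            return func(*args, **kwargs)
        except _S3_RETRYABLE_ERRORS:
            if attempt == attempts - 1:
                raise
            time.sleep(attempt + 1)  # back off before the next try
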
@@ -178,14 +179,15 @@ class _S3Object: # pylint: disable=too-many-instance-attributes
     def __init__(
         self,
         path: str,
-        s3_read_ahead_size: int,
+        s3_block_size: int,
         mode: str,
         use_threads: bool,
         s3_additional_kwargs: Optional[Dict[str, str]],
         boto3_session: Optional[boto3.Session],
         newline: Optional[str],
         encoding: Optional[str],
     ) -> None:
+        self.closed: bool = False
         self._use_threads = use_threads
         self._newline: str = "\n" if newline is None else newline
         self._encoding: str = "utf-8" if encoding is None else encoding
@@ -194,11 +196,13 @@ def __init__(
         if mode not in {"rb", "wb", "r", "w"}:
             raise NotImplementedError("File mode must be {'rb', 'wb', 'r', 'w'}, not %s" % mode)
         self._mode: str = "rb" if mode is None else mode
-        self._s3_read_ahead_size: int = s3_read_ahead_size
+        if s3_block_size < 2:
+            raise exceptions.InvalidArgumentValue("s3_block_size MUST > 1")
+        self._s3_block_size: int = s3_block_size
+        self._s3_half_block_size: int = s3_block_size // 2
         self._s3_additional_kwargs: Dict[str, str] = {} if s3_additional_kwargs is None else s3_additional_kwargs
         self._client: boto3.client = _utils.client(service_name="s3", session=self._boto3_session)
         self._loc: int = 0
-        self.closed: bool = False
 
         if self.readable() is True:
             self._cache: bytes = b""
@@ -209,6 +213,7 @@ def __init__(
                 raise exceptions.InvalidArgumentValue(f"S3 object w/o defined size: {path}")
             self._size: int = size
             _logger.debug("self._size: %s", self._size)
+            _logger.debug("self._s3_block_size: %s", self._s3_block_size)
         elif self.writable() is True:
             self._mpu: Dict[str, Any] = {}
             self._buffer: io.BytesIO = io.BytesIO()
@@ -289,16 +294,60 @@ def _fetch_range_proxy(self, start: int, end: int) -> bytes:
         )
 
     def _fetch(self, start: int, end: int) -> None:
-        if end > self._size:
-            end = self._size
+        end = self._size if end > self._size else end
+        start = 0 if start < 0 else start
+
+        if start >= self._start and end <= self._end:
+            return None  # Does not require download
 
-        if start < self._start or end > self._end:
+        if end - start >= self._s3_block_size:  # Fetching length greater than cache length
+            self._cache = self._fetch_range_proxy(start, end)
             self._start = start
-            if ((end - start) < self._s3_read_ahead_size) and (end < self._size):
-                self._end = start + self._s3_read_ahead_size
-            else:
-                self._end = end
-            self._cache = self._fetch_range_proxy(self._start, self._end)
+            self._end = end
+            return None
+
+        # Calculating block START and END positions
+        _logger.debug("Downloading: %s (start) / %s (end)", start, end)
+        mid: int = int(math.ceil((start + end) / 2))
+        new_block_start: int = mid - self._s3_half_block_size
+        new_block_end: int = mid + self._s3_half_block_size
+        _logger.debug("new_block_start: %s / new_block_end: %s / mid: %s", new_block_start, new_block_end, mid)
+        if new_block_start < 0 and new_block_end > self._size:  # both ends overflowing
+            new_block_start = 0
+            new_block_end = self._size
+        elif new_block_end > self._size:  # right overflow
+            new_block_start = new_block_start - (new_block_end - self._size)
+            new_block_start = 0 if new_block_start < 0 else new_block_start
+            new_block_end = self._size
+        elif new_block_start < 0:  # left overflow
+            new_block_end = new_block_end + (0 - new_block_start)
+            new_block_end = self._size if new_block_end > self._size else new_block_end
+            new_block_start = 0
+        _logger.debug(
+            "new_block_start: %s / new_block_end: %s/ self._start: %s / self._end: %s",
+            new_block_start,
+            new_block_end,
+            self._start,
+            self._end,
+        )
+
+        # Calculating missing bytes in cache
+        if (new_block_start < self._start and new_block_end > self._end) or (
+            new_block_start > self._end and new_block_end < self._start
+        ):  # Full block download
+            self._cache = self._fetch_range_proxy(new_block_start, new_block_end)
+        elif new_block_end > self._end:
+            prune_diff: int = new_block_start - self._start
+            self._cache = self._cache[prune_diff:] + self._fetch_range_proxy(self._end, new_block_end)
+        elif new_block_start < self._start:
+            prune_diff = new_block_end - self._end
+            self._cache = self._cache[:-prune_diff] + self._fetch_range_proxy(new_block_start, self._start)
+        else:
+            raise RuntimeError("Wrangler's cache calculation error.")
+        self._start = new_block_start
+        self._end = new_block_end
+
+        return None
 
     def read(self, length: int = -1) -> Union[bytes, str]:
         """Return cached data and fetch on demand chunks."""
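
In plain terms, the reworked _fetch above replaces the fixed read-ahead with a sliding cache block: a request already inside [self._start, self._end) is served from cache, an oversized request bypasses the cache entirely, and anything else is centred inside a window of s3_block_size bytes that is clamped to the object boundaries before only the missing bytes are downloaded. A small standalone sketch of that window arithmetic (plain integers, no S3 calls; the helper name is invented for illustration):

from typing import Tuple
import math


def block_window(start: int, end: int, size: int, block_size: int) -> Tuple[int, int]:
    # Centre a window of `block_size` bytes on the requested range [start, end)
    # and clamp it to an object of `size` bytes, mirroring the arithmetic in _fetch.
    half = block_size // 2
    mid = int(math.ceil((start + end) / 2))
    block_start, block_end = mid - half, mid + half
    if block_start < 0 and block_end > size:  # both ends overflow: take the whole object
        return 0, size
    if block_end > size:  # right overflow: shift the window left
        return max(0, block_start - (block_end - size)), size
    if block_start < 0:  # left overflow: shift the window right
        return 0, min(size, block_end - block_start)
    return block_start, block_end


# A 100-byte read near the start of a 10 MiB object with the default 4 MiB block
# still downloads a full 4 MiB window anchored at offset 0:
print(block_window(1_000_000, 1_000_100, 10 * 2**20, 4_194_304))  # (0, 4194304)
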
@@ -313,12 +362,11 @@ def read(self, length: int = -1) -> Union[bytes, str]:
         self._fetch(self._loc, self._loc + length)
         out: bytes = self._cache[self._loc - self._start : self._loc - self._start + length]
         self._loc += len(out)
-
         return out
 
     def readline(self, length: int = -1) -> Union[bytes, str]:
         """Read until the next line terminator."""
-        self._fetch(self._loc, self._loc + self._s3_read_ahead_size)
+        self._fetch(self._loc, self._loc + self._s3_block_size)
         while True:
             found: int = self._cache[self._loc - self._start :].find(self._newline.encode(encoding=self._encoding))
 
@@ -329,7 +377,7 @@ def readline(self, length: int = -1) -> Union[bytes, str]:
             if self._end >= self._size:
                 return self.read(length)
 
-            self._fetch(self._loc, self._end + self._s3_read_ahead_size)
+            self._fetch(self._loc, self._end + self._s3_half_block_size)
 
     def readlines(self) -> List[Union[bytes, str]]:
         """Return all lines as list."""
@@ -472,7 +520,7 @@ def open_s3_object(
     mode: str,
     use_threads: bool = False,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
-    s3_read_ahead_size: int = 4_194_304,  # 4 MB (4 * 2**20)
+    s3_block_size: int = 4_194_304,  # 4 MB (4 * 2**20)
     boto3_session: Optional[boto3.Session] = None,
     newline: Optional[str] = "\n",
     encoding: Optional[str] = "utf-8",
@@ -483,7 +531,7 @@ def open_s3_object(
     try:
         s3obj = _S3Object(
             path=path,
-            s3_read_ahead_size=s3_read_ahead_size,
+            s3_block_size=s3_block_size,
             mode=mode,
             use_threads=use_threads,
             s3_additional_kwargs=s3_additional_kwargs,
@@ -494,7 +542,13 @@ def open_s3_object(
         if "b" in mode:  # binary
             yield s3obj
         else:  # text
-            text_s3obj = io.TextIOWrapper(cast(BinaryIO, s3obj), encoding=encoding, newline=newline)
+            text_s3obj = io.TextIOWrapper(
+                buffer=cast(BinaryIO, s3obj),
+                encoding=encoding,
+                newline=newline,
+                line_buffering=False,
+                write_through=False,
+            )
             yield text_s3obj
     finally:
         if text_s3obj is not None and text_s3obj.closed is False:
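
For callers of this internal helper, the practical change is that the read cache is now tuned through s3_block_size rather than s3_read_ahead_size. A rough usage sketch, assuming the module path awswrangler.s3._fs and a placeholder bucket/key:

import boto3

from awswrangler.s3._fs import open_s3_object

session = boto3.Session()

# Hypothetical object; request an 8 MiB cache block instead of the 4 MiB default.
with open_s3_object(
    path="s3://my-bucket/my-key.csv",  # placeholder path
    mode="r",
    s3_block_size=8_388_608,
    boto3_session=session,
) as f:
    first_line = f.readline()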