
Commit 5b39a17

Improving s3 cache strategy.
1 parent: c0ae07a

7 files changed (+113 / -45 lines changed)


awswrangler/_config.py

Lines changed: 7 additions & 7 deletions
@@ -29,7 +29,7 @@ class _ConfigArg(NamedTuple):
     "database": _ConfigArg(dtype=str, nullable=True),
     "max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False),
     "max_cache_seconds": _ConfigArg(dtype=int, nullable=False),
-    "s3_read_ahead_size": _ConfigArg(dtype=int, nullable=False, enforced=True),
+    "s3_block_size": _ConfigArg(dtype=int, nullable=False, enforced=True),
 }
@@ -206,13 +206,13 @@ def max_cache_seconds(self, value: int) -> None:
         self._set_config_value(key="max_cache_seconds", value=value)
 
     @property
-    def s3_read_ahead_size(self) -> int:
-        """Property s3_read_ahead_size."""
-        return cast(int, self["s3_read_ahead_size"])
+    def s3_block_size(self) -> int:
+        """Property s3_block_size."""
+        return cast(int, self["s3_block_size"])
 
-    @s3_read_ahead_size.setter
-    def s3_read_ahead_size(self, value: int) -> None:
-        self._set_config_value(key="s3_read_ahead_size", value=value)
+    @s3_block_size.setter
+    def s3_block_size(self, value: int) -> None:
+        self._set_config_value(key="s3_block_size", value=value)
 
 
 def _inject_config_doc(doc: Optional[str], available_configs: Tuple[str, ...]) -> str:
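
Usage note: a minimal sketch of the renamed setting after this commit (assumes the conventional `import awswrangler as wr`; the 8 MB value is only illustrative):

import awswrangler as wr

# Global S3 cache block size, in bytes (was wr.config.s3_read_ahead_size before this commit).
wr.config.s3_block_size = 8 * 2 ** 20  # 8 MB

# ... wr.s3 reads performed now use 8 MB cache blocks ...

wr.config.reset()  # restore every configured value to its default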

awswrangler/s3/_fs.py

Lines changed: 73 additions & 19 deletions
@@ -4,20 +4,21 @@
 import io
 import itertools
 import logging
+import math
 import socket
 from contextlib import contextmanager
 from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Set, Tuple, Union, cast
 
 import boto3
-from botocore.exceptions import ClientError
+from botocore.exceptions import ClientError, ReadTimeoutError
 
 from awswrangler import _utils, exceptions
 from awswrangler._config import apply_configs
 from awswrangler.s3._describe import size_objects
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
-_S3_RETRYABLE_ERRORS: Tuple[Any, Any] = (socket.timeout, ConnectionError)
+_S3_RETRYABLE_ERRORS: Tuple[Any, Any, Any] = (socket.timeout, ConnectionError, ReadTimeoutError)
 
 _MIN_WRITE_BLOCK: int = 5_242_880  # 5 MB (5 * 2**20)
 _MIN_PARALLEL_READ_BLOCK: int = 5_242_880  # 5 MB (5 * 2**20)
@@ -178,14 +179,15 @@ class _S3Object:  # pylint: disable=too-many-instance-attributes
     def __init__(
         self,
         path: str,
-        s3_read_ahead_size: int,
+        s3_block_size: int,
         mode: str,
         use_threads: bool,
         s3_additional_kwargs: Optional[Dict[str, str]],
         boto3_session: Optional[boto3.Session],
         newline: Optional[str],
         encoding: Optional[str],
     ) -> None:
+        self.closed: bool = False
         self._use_threads = use_threads
         self._newline: str = "\n" if newline is None else newline
         self._encoding: str = "utf-8" if encoding is None else encoding
@@ -194,11 +196,13 @@ def __init__(
         if mode not in {"rb", "wb", "r", "w"}:
            raise NotImplementedError("File mode must be {'rb', 'wb', 'r', 'w'}, not %s" % mode)
         self._mode: str = "rb" if mode is None else mode
-        self._s3_read_ahead_size: int = s3_read_ahead_size
+        if s3_block_size < 2:
+            raise exceptions.InvalidArgumentValue("s3_block_size MUST > 1")
+        self._s3_block_size: int = s3_block_size
+        self._s3_half_block_size: int = s3_block_size // 2
         self._s3_additional_kwargs: Dict[str, str] = {} if s3_additional_kwargs is None else s3_additional_kwargs
         self._client: boto3.client = _utils.client(service_name="s3", session=self._boto3_session)
         self._loc: int = 0
-        self.closed: bool = False
 
         if self.readable() is True:
             self._cache: bytes = b""
@@ -209,6 +213,7 @@ def __init__(
                 raise exceptions.InvalidArgumentValue(f"S3 object w/o defined size: {path}")
             self._size: int = size
             _logger.debug("self._size: %s", self._size)
+            _logger.debug("self._s3_block_size: %s", self._s3_block_size)
         elif self.writable() is True:
             self._mpu: Dict[str, Any] = {}
             self._buffer: io.BytesIO = io.BytesIO()
@@ -289,16 +294,60 @@ def _fetch_range_proxy(self, start: int, end: int) -> bytes:
         )
 
     def _fetch(self, start: int, end: int) -> None:
-        if end > self._size:
-            end = self._size
+        end = self._size if end > self._size else end
+        start = 0 if start < 0 else start
+
+        if start >= self._start and end <= self._end:
+            return None  # Does not require download
 
-        if start < self._start or end > self._end:
+        if end - start >= self._s3_block_size:  # Fetching length greater than cache length
+            self._cache = self._fetch_range_proxy(start, end)
             self._start = start
-            if ((end - start) < self._s3_read_ahead_size) and (end < self._size):
-                self._end = start + self._s3_read_ahead_size
-            else:
-                self._end = end
-            self._cache = self._fetch_range_proxy(self._start, self._end)
+            self._end = end
+            return None
+
+        # Calculating block START and END positions
+        _logger.debug("Downloading: %s (start) / %s (end)", start, end)
+        mid: int = int(math.ceil((start + end) / 2))
+        new_block_start: int = mid - self._s3_half_block_size
+        new_block_end: int = mid + self._s3_half_block_size
+        _logger.debug("new_block_start: %s / new_block_end: %s / mid: %s", new_block_start, new_block_end, mid)
+        if new_block_start < 0 and new_block_end > self._size:  # both ends overflowing
+            new_block_start = 0
+            new_block_end = self._size
+        elif new_block_end > self._size:  # right overflow
+            new_block_start = new_block_start - (new_block_end - self._size)
+            new_block_start = 0 if new_block_start < 0 else new_block_start
+            new_block_end = self._size
+        elif new_block_start < 0:  # left overflow
+            new_block_end = new_block_end + (0 - new_block_start)
+            new_block_end = self._size if new_block_end > self._size else new_block_end
+            new_block_start = 0
+        _logger.debug(
+            "new_block_start: %s / new_block_end: %s / self._start: %s / self._end: %s",
+            new_block_start,
+            new_block_end,
+            self._start,
+            self._end,
+        )
+
+        # Calculating missing bytes in cache
+        if (new_block_start < self._start and new_block_end > self._end) or (
+            new_block_start > self._end and new_block_end < self._start
+        ):  # Full block download
+            self._cache = self._fetch_range_proxy(new_block_start, new_block_end)
+        elif new_block_end > self._end:
+            prune_diff: int = new_block_start - self._start
+            self._cache = self._cache[prune_diff:] + self._fetch_range_proxy(self._end, new_block_end)
+        elif new_block_start < self._start:
+            prune_diff = new_block_end - self._end
+            self._cache = self._cache[:-prune_diff] + self._fetch_range_proxy(new_block_start, self._start)
+        else:
+            raise RuntimeError("Wrangler's cache calculation error.")
+        self._start = new_block_start
+        self._end = new_block_end
+
+        return None
 
     def read(self, length: int = -1) -> Union[bytes, str]:
         """Return cached data and fetch on demand chunks."""
@@ -313,12 +362,11 @@ def read(self, length: int = -1) -> Union[bytes, str]:
         self._fetch(self._loc, self._loc + length)
         out: bytes = self._cache[self._loc - self._start : self._loc - self._start + length]
         self._loc += len(out)
-
         return out
 
     def readline(self, length: int = -1) -> Union[bytes, str]:
         """Read until the next line terminator."""
-        self._fetch(self._loc, self._loc + self._s3_read_ahead_size)
+        self._fetch(self._loc, self._loc + self._s3_block_size)
         while True:
             found: int = self._cache[self._loc - self._start :].find(self._newline.encode(encoding=self._encoding))
 
@@ -329,7 +377,7 @@ def readline(self, length: int = -1) -> Union[bytes, str]:
             if self._end >= self._size:
                 return self.read(length)
 
-            self._fetch(self._loc, self._end + self._s3_read_ahead_size)
+            self._fetch(self._loc, self._end + self._s3_half_block_size)
 
     def readlines(self) -> List[Union[bytes, str]]:
         """Return all lines as list."""
@@ -472,7 +520,7 @@ def open_s3_object(
     mode: str,
     use_threads: bool = False,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
-    s3_read_ahead_size: int = 4_194_304,  # 4 MB (4 * 2**20)
+    s3_block_size: int = 4_194_304,  # 4 MB (4 * 2**20)
     boto3_session: Optional[boto3.Session] = None,
     newline: Optional[str] = "\n",
     encoding: Optional[str] = "utf-8",
@@ -483,7 +531,7 @@ def open_s3_object(
     try:
         s3obj = _S3Object(
             path=path,
-            s3_read_ahead_size=s3_read_ahead_size,
+            s3_block_size=s3_block_size,
             mode=mode,
             use_threads=use_threads,
             s3_additional_kwargs=s3_additional_kwargs,
@@ -494,7 +542,13 @@ def open_s3_object(
         if "b" in mode:  # binary
             yield s3obj
         else:  # text
-            text_s3obj = io.TextIOWrapper(cast(BinaryIO, s3obj), encoding=encoding, newline=newline)
+            text_s3obj = io.TextIOWrapper(
+                buffer=cast(BinaryIO, s3obj),
+                encoding=encoding,
+                newline=newline,
+                line_buffering=False,
+                write_through=False,
+            )
             yield text_s3obj
     finally:
         if text_s3obj is not None and text_s3obj.closed is False:
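
Usage note: a short sketch of open_s3_object with the renamed parameter (the bucket/key and the 1 MB value are hypothetical):

from awswrangler.s3._fs import open_s3_object

# Binary read with a 1 MB cache block.
with open_s3_object("s3://my-bucket/data.csv", mode="rb", s3_block_size=1_048_576) as f:
    header = f.readline()

# Text mode: the binary object is wrapped in io.TextIOWrapper internally.
with open_s3_object("s3://my-bucket/data.csv", mode="r", encoding="utf-8", newline="\n") as f:
    first_line = f.readline()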

awswrangler/s3/_read_parquet.py

Lines changed: 5 additions & 5 deletions
@@ -40,7 +40,7 @@ def _read_parquet_metadata_file(
         path=path,
         mode="rb",
         use_threads=use_threads,
-        s3_read_ahead_size=1_048_576,  # 1 MB (1 * 2**20)
+        s3_block_size=1_048_576,  # 1 MB (1 * 2**20)
         s3_additional_kwargs=s3_additional_kwargs,
         boto3_session=boto3_session,
     ) as f:
@@ -257,7 +257,7 @@ def _read_parquet_chunked(
         path=path,
         mode="rb",
         use_threads=use_threads,
-        s3_read_ahead_size=10_485_760,  # 10 MB (10 * 2**20)
+        s3_block_size=10_485_760,  # 10 MB (10 * 2**20)
         s3_additional_kwargs=s3_additional_kwargs,
         boto3_session=boto3_session,
     ) as f:
@@ -319,7 +319,7 @@ def _read_parquet_file(
         path=path,
         mode="rb",
         use_threads=use_threads,
-        s3_read_ahead_size=134_217_728,  # 128 MB (128 * 2**20)
+        s3_block_size=134_217_728,  # 128 MB (128 * 2**20)
         s3_additional_kwargs=s3_additional_kwargs,
         boto3_session=boto3_session,
     ) as f:
@@ -339,7 +339,7 @@ def _count_row_groups(
         path=path,
         mode="rb",
         use_threads=use_threads,
-        s3_read_ahead_size=1_048_576,  # 1 MB (1 * 2**20)
+        s3_block_size=1_048_576,  # 1 MB (1 * 2**20)
         s3_additional_kwargs=s3_additional_kwargs,
         boto3_session=boto3_session,
     ) as f:
@@ -361,7 +361,7 @@ def _read_parquet_row_group(
         path=path,
         mode="rb",
         use_threads=use_threads,
-        s3_read_ahead_size=10_485_760,  # 10 MB (10 * 2**20)
+        s3_block_size=10_485_760,  # 10 MB (10 * 2**20)
         s3_additional_kwargs=s3_additional_kwargs,
         boto3_session=boto3_session,
     ) as f:

awswrangler/s3/_read_text.py

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ def _read_text_chunked(
     with open_s3_object(
         path=path,
         mode=mode,
-        s3_read_ahead_size=10_485_760,  # 10 MB (10 * 2**20)
+        s3_block_size=10_485_760,  # 10 MB (10 * 2**20)
         encoding=encoding,
         use_threads=use_threads,
         s3_additional_kwargs=s3_additional_kwargs,
@@ -78,7 +78,7 @@ def _read_text_file(
         path=path,
         mode=mode,
         use_threads=use_threads,
-        s3_read_ahead_size=134_217_728,  # 128 MB (128 * 2**20)
+        s3_block_size=134_217_728,  # 128 MB (128 * 2**20)
         encoding=encoding,
         s3_additional_kwargs=s3_additional_kwargs,
         newline=newline,

pytest.ini

Lines changed: 1 addition & 1 deletion
@@ -3,6 +3,6 @@ log_cli=False
 filterwarnings =
     ignore::DeprecationWarning
 addopts =
-    --log-cli-format "[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s"
+    --log-cli-format "[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s][%(thread)d] %(message)s"
     --verbose
     --capture=sys

tests/test_config.py

Lines changed: 2 additions & 2 deletions
@@ -24,11 +24,11 @@ def test_basics(path, glue_database, glue_table):
 
     # Testing configured s3 block size
     size = 1 * 2 ** 20  # 1 MB
-    wr.config.s3_read_ahead_size = size
+    wr.config.s3_block_size = size
     with open_s3_object(path, mode="wb") as s3obj:
         s3obj.write(b"foo")
     with open_s3_object(path, mode="rb") as s3obj:
-        assert s3obj._s3_read_ahead_size == size
+        assert s3obj._s3_block_size == size
 
     # Resetting all configs
     wr.config.reset()

tests/test_fs.py

Lines changed: 23 additions & 9 deletions
@@ -34,7 +34,7 @@ def test_read_full(path, mode, use_threads):
     bucket, key = wr._utils.parse_path(path)
     text = "AHDG*AWY&GD*A&WGd*AWgd87AGWD*GA*G*g*AGˆˆ&ÂDTW&ˆˆD&ÂTW7ˆˆTAWˆˆDAW&ˆˆAWGDIUHWOD#N"
     client_s3.put_object(Body=text, Bucket=bucket, Key=key)
-    with open_s3_object(path, mode=mode, s3_read_ahead_size=100, newline="\n", use_threads=use_threads) as s3obj:
+    with open_s3_object(path, mode=mode, s3_block_size=100, newline="\n", use_threads=use_threads) as s3obj:
         if mode == "r":
             assert s3obj.read() == text
         else:
@@ -52,7 +52,7 @@ def test_read_chunked(path, mode, block_size, use_threads):
     bucket, key = wr._utils.parse_path(path)
     text = "0123456789"
     client_s3.put_object(Body=text, Bucket=bucket, Key=key)
-    with open_s3_object(path, mode=mode, s3_read_ahead_size=block_size, newline="\n", use_threads=use_threads) as s3obj:
+    with open_s3_object(path, mode=mode, s3_block_size=block_size, newline="\n", use_threads=use_threads) as s3obj:
         if mode == "r":
             for i in range(3):
                 assert s3obj.read(1) == text[i]
@@ -67,22 +67,23 @@ def test_read_chunked(path, mode, block_size, use_threads):
 
 @pytest.mark.parametrize("use_threads", [True, False])
 @pytest.mark.parametrize("mode", ["r", "rb"])
-@pytest.mark.parametrize("block_size", [1, 2, 3, 10, 23, 48, 65, 100])
+@pytest.mark.parametrize("block_size", [2, 3, 10, 23, 48, 65, 100])
 def test_read_line(path, mode, block_size, use_threads):
     client_s3 = boto3.client("s3")
     path = f"{path}0.txt"
     bucket, key = wr._utils.parse_path(path)
     text = "0\n11\n22222\n33333333333333\n44444444444444444444444444444444444444444444\n55555"
     expected = ["0\n", "11\n", "22222\n", "33333333333333\n", "44444444444444444444444444444444444444444444\n", "55555"]
     client_s3.put_object(Body=text, Bucket=bucket, Key=key)
-    with open_s3_object(path, mode=mode, s3_read_ahead_size=block_size, newline="\n", use_threads=use_threads) as s3obj:
+    with open_s3_object(path, mode=mode, s3_block_size=block_size, newline="\n", use_threads=use_threads) as s3obj:
         for i, line in enumerate(s3obj):
             if mode == "r":
                 assert line == expected[i]
             else:
                 assert line == expected[i].encode("utf-8")
         s3obj.seek(0)
         lines = s3obj.readlines()
+        print(lines)
         if mode == "r":
             assert lines == expected
         else:
@@ -136,11 +137,7 @@ def test_additional_kwargs(path, kms_key_id, s3_additional_kwargs, use_threads):
     with open_s3_object(path, mode="w", s3_additional_kwargs=s3_additional_kwargs, use_threads=use_threads) as s3obj:
         s3obj.write("foo")
     with open_s3_object(
-        path,
-        mode="r",
-        s3_read_ahead_size=10_000_000,
-        s3_additional_kwargs=s3_additional_kwargs,
-        use_threads=use_threads,
+        path, mode="r", s3_block_size=10_000_000, s3_additional_kwargs=s3_additional_kwargs, use_threads=use_threads,
     ) as s3obj:
         assert s3obj.read() == "foo"
     desc = wr.s3.describe_objects([path])[path]
@@ -160,3 +157,20 @@ def test_pyarrow(path, glue_table, glue_database):
     ensure_data_types(df2, has_list=True)
     assert df2.shape == (3, 19)
     assert df.iint8.sum() == df2.iint8.sum()
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+@pytest.mark.parametrize("block_size", [2, 3, 5, 8, 9, 15])
+@pytest.mark.parametrize("text", ["012345678", "0123456789"])
+def test_cache(path, use_threads, block_size, text):
+    client_s3 = boto3.client("s3")
+    path = f"{path}0.txt"
+    bucket, key = wr._utils.parse_path(path)
+    client_s3.put_object(Body=text, Bucket=bucket, Key=key)
+    with open_s3_object(path, mode="rb", s3_block_size=block_size, use_threads=use_threads) as s3obj:
+        for i in range(len(text)):
+            value = s3obj.read(1)
+            print(value)
+            assert value == text[i].encode("utf-8")
+            assert len(s3obj._cache) in (block_size, block_size - 1, len(text))
+    assert s3obj._cache == b""
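
Note on the cache-length assertion above: the cached window is built as mid ± (block_size // 2), so with an odd block_size it spans 2 * (block_size // 2) bytes, which appears to be why block_size - 1 is also accepted:

half = 9 // 2      # 4, for an odd block_size of 9
print(2 * half)    # 8 == 9 - 1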
