Commit 790f8cc

S3 Select (#722)
1 parent 54b8c4f commit 790f8cc

File tree: 3 files changed (+380, −0 lines)

awswrangler/s3/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 from awswrangler.s3._read_excel import read_excel  # noqa
 from awswrangler.s3._read_parquet import read_parquet, read_parquet_metadata, read_parquet_table  # noqa
 from awswrangler.s3._read_text import read_csv, read_fwf, read_json  # noqa
+from awswrangler.s3._select import select_query
 from awswrangler.s3._upload import upload  # noqa
 from awswrangler.s3._wait import wait_objects_exist, wait_objects_not_exist  # noqa
 from awswrangler.s3._write_excel import to_excel  # noqa
@@ -33,6 +34,7 @@
     "read_json",
     "wait_objects_exist",
     "wait_objects_not_exist",
+    "select_query",
     "store_parquet_metadata",
     "to_parquet",
     "to_csv",

awswrangler/s3/_select.py

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
"""Amazon S3 Select Module (PRIVATE)."""

import concurrent.futures
import itertools
import json
import logging
import pprint
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import boto3
import pandas as pd

from awswrangler import _utils, exceptions
from awswrangler.s3._describe import size_objects

_logger: logging.Logger = logging.getLogger(__name__)

_RANGE_CHUNK_SIZE: int = int(1024 * 1024)


def _gen_scan_range(obj_size: int) -> Iterator[Tuple[int, int]]:
    for i in range(0, obj_size, _RANGE_CHUNK_SIZE):
        yield (i, i + min(_RANGE_CHUNK_SIZE, obj_size - i))


def _select_object_content(
    args: Dict[str, Any],
    client_s3: boto3.client,
    scan_range: Optional[Tuple[int, int]] = None,
) -> pd.DataFrame:
    if scan_range:
        response = client_s3.select_object_content(**args, ScanRange={"Start": scan_range[0], "End": scan_range[1]})
    else:
        response = client_s3.select_object_content(**args)

    dfs: List[pd.DataFrame] = []
    partial_record: str = ""
    for event in response["Payload"]:
        if "Records" in event:
            records = event["Records"]["Payload"].decode(encoding="utf-8", errors="ignore").split("\n")
            records[0] = partial_record + records[0]
            # Record end can either be a partial record or a return char
            partial_record = records.pop()
            dfs.append(
                pd.DataFrame(
                    [json.loads(record) for record in records],
                )
            )
    if not dfs:
        return pd.DataFrame()
    return pd.concat(dfs, ignore_index=True)


def _paginate_stream(
    args: Dict[str, Any], path: str, use_threads: Union[bool, int], boto3_session: Optional[boto3.Session]
) -> pd.DataFrame:
    obj_size: int = size_objects(  # type: ignore
        path=[path],
        use_threads=False,
        boto3_session=boto3_session,
    ).get(path)
    if obj_size is None:
        raise exceptions.InvalidArgumentValue(f"S3 object w/o defined size: {path}")

    dfs: List[pd.DataFrame] = []
    client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session)

    if use_threads is False:
        dfs = list(
            _select_object_content(
                args=args,
                client_s3=client_s3,
                scan_range=scan_range,
            )
            for scan_range in _gen_scan_range(obj_size=obj_size)
        )
    else:
        cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
        with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor:
            dfs = list(
                executor.map(
                    _select_object_content,
                    itertools.repeat(args),
                    itertools.repeat(client_s3),
                    _gen_scan_range(obj_size=obj_size),
                )
            )
    return pd.concat(dfs, ignore_index=True)


def select_query(
    sql: str,
    path: str,
    input_serialization: str,
    input_serialization_params: Dict[str, Union[bool, str]],
    compression: Optional[str] = None,
    use_threads: Union[bool, int] = False,
    boto3_session: Optional[boto3.Session] = None,
    s3_additional_kwargs: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
    r"""Filter contents of an Amazon S3 object based on SQL statement.

    Note: Scan ranges are only supported for uncompressed CSV/JSON objects: CSV without quoted delimiters
    and JSON in LINES mode only. When these conditions are not met, scanning cannot be split across threads,
    leading to lower performance.

    Parameters
    ----------
    sql: str
        SQL statement used to query the object.
    path: str
        S3 path to the object (e.g. s3://bucket/key).
    input_serialization: str
        Format of the S3 object queried.
        Valid values: "CSV", "JSON", or "Parquet". Case sensitive.
    input_serialization_params: Dict[str, Union[bool, str]]
        Dictionary describing the serialization of the S3 object.
    compression: Optional[str]
        Compression type of the S3 object.
        Valid values: None, "gzip", or "bzip2". gzip and bzip2 are only valid for CSV and JSON objects.
    use_threads : Union[bool, int]
        True to enable concurrent requests, False to disable multiple threads.
        If enabled os.cpu_count() is used as the max number of threads.
        If integer is provided, specified number is used.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session is used if none is provided.
    s3_additional_kwargs : Optional[Dict[str, Any]]
        Forwarded to botocore requests.
        Valid values: "SSECustomerAlgorithm", "SSECustomerKey", "ExpectedBucketOwner".
        e.g. s3_additional_kwargs={'SSECustomerAlgorithm': 'md5'}

    Returns
    -------
    pandas.DataFrame
        Pandas DataFrame with results from query.

    Examples
    --------
    Reading a gzip compressed JSON document

    >>> import awswrangler as wr
    >>> df = wr.s3.select_query(
    ...     sql='SELECT * FROM s3object[*][*]',
    ...     path='s3://bucket/key.json.gzip',
    ...     input_serialization='JSON',
    ...     input_serialization_params={
    ...         'Type': 'Document',
    ...     },
    ...     compression="gzip",
    ... )

    Reading an entire CSV object using threads

    >>> import awswrangler as wr
    >>> df = wr.s3.select_query(
    ...     sql='SELECT * FROM s3object',
    ...     path='s3://bucket/key.csv',
    ...     input_serialization='CSV',
    ...     input_serialization_params={
    ...         'FileHeaderInfo': 'Use',
    ...         'RecordDelimiter': '\r\n'
    ...     },
    ...     use_threads=True,
    ... )

    Reading a single column from Parquet object with pushdown filter

    >>> import awswrangler as wr
    >>> df = wr.s3.select_query(
    ...     sql='SELECT s.\"id\" FROM s3object s where s.\"id\" = 1.0',
    ...     path='s3://bucket/key.snappy.parquet',
    ...     input_serialization='Parquet',
    ...     input_serialization_params={},
    ... )
    """
    if path.endswith("/"):
        raise exceptions.InvalidArgumentValue("<path> argument should be an S3 key, not a prefix.")
    if input_serialization not in ["CSV", "JSON", "Parquet"]:
        raise exceptions.InvalidArgumentValue("<input_serialization> argument must be 'CSV', 'JSON' or 'Parquet'")
    if compression not in [None, "gzip", "bzip2"]:
        raise exceptions.InvalidCompression(f"Invalid {compression} compression, please use None, 'gzip' or 'bzip2'.")
    if compression and (input_serialization not in ["CSV", "JSON"]):
        raise exceptions.InvalidArgumentCombination(
            "'gzip' or 'bzip2' are only valid for input 'CSV' or 'JSON' objects."
        )
    bucket, key = _utils.parse_path(path)

    args: Dict[str, Any] = {
        "Bucket": bucket,
        "Key": key,
        "Expression": sql,
        "ExpressionType": "SQL",
        "RequestProgress": {"Enabled": False},
        "InputSerialization": {
            input_serialization: input_serialization_params,
            "CompressionType": compression.upper() if compression else "NONE",
        },
        "OutputSerialization": {
            "JSON": {},
        },
    }
    if s3_additional_kwargs:
        args.update(s3_additional_kwargs)
    _logger.debug("args:\n%s", pprint.pformat(args))

    if any(
        [
            compression,
            input_serialization_params.get("AllowQuotedRecordDelimiter"),
            input_serialization_params.get("Type") == "Document",
        ]
    ):  # Scan range is only supported for uncompressed CSV/JSON, CSV (without quoted delimiters)
        # and JSON objects (in LINES mode only)
        _logger.debug("Scan ranges are not supported given provided input.")
        client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session)
        return _select_object_content(args=args, client_s3=client_s3)

    return _paginate_stream(args=args, path=path, use_threads=use_threads, boto3_session=boto3_session)
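
For reference, a small sketch of what the scan-range generator above produces for a hypothetical 2.5 MB object (the chunk size mirrors _RANGE_CHUNK_SIZE; the object size is illustrative only). Each (start, end) pair is forwarded to select_object_content as a ScanRange, which is how _paginate_stream splits a single object across threads:

>>> chunk = 1024 * 1024  # same 1 MiB step as _RANGE_CHUNK_SIZE
>>> obj_size = 2_500_000
>>> [(i, i + min(chunk, obj_size - i)) for i in range(0, obj_size, chunk)]
[(0, 1048576), (1048576, 2097152), (2097152, 2500000)]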

tests/test_s3_select.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
import logging

import pandas as pd
import pytest

import awswrangler as wr

logging.getLogger("awswrangler").setLevel(logging.DEBUG)


@pytest.mark.parametrize("use_threads", [True, False, 2])
def test_full_table(path, use_threads):
    df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]})

    # Parquet
    file_path = f"{path}test_parquet_file.snappy.parquet"
    wr.s3.to_parquet(df, file_path, compression="snappy")
    df2 = wr.s3.select_query(
        sql="select * from s3object",
        path=file_path,
        input_serialization="Parquet",
        input_serialization_params={},
        use_threads=use_threads,
    )
    assert df.equals(df2)

    # CSV
    file_path = f"{path}test_csv_file.csv"
    wr.s3.to_csv(df, file_path, index=False)
    df3 = wr.s3.select_query(
        sql="select * from s3object",
        path=file_path,
        input_serialization="CSV",
        input_serialization_params={"FileHeaderInfo": "Use", "RecordDelimiter": "\n"},
        use_threads=use_threads,
    )
    assert len(df.index) == len(df3.index)
    assert list(df.columns) == list(df3.columns)
    assert df.shape == df3.shape

    # JSON
    file_path = f"{path}test_json_file.json"
    wr.s3.to_json(df, file_path, orient="records")
    df4 = wr.s3.select_query(
        sql="select * from s3object[*][*]",
        path=file_path,
        input_serialization="JSON",
        input_serialization_params={"Type": "Document"},
        use_threads=use_threads,
    )
    assert df.equals(df4)


@pytest.mark.parametrize("use_threads", [True, False, 2])
def test_push_down(path, use_threads):
    df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]})

    file_path = f"{path}test_parquet_file.snappy.parquet"
    wr.s3.to_parquet(df, file_path, compression="snappy")
    df2 = wr.s3.select_query(
        sql='select * from s3object s where s."c0" = 1',
        path=file_path,
        input_serialization="Parquet",
        input_serialization_params={},
        use_threads=use_threads,
    )
    assert df2.shape == (1, 3)
    assert df2.c0.sum() == 1

    file_path = f"{path}test_parquet_file.gzip.parquet"
    wr.s3.to_parquet(df, file_path, compression="gzip")
    df2 = wr.s3.select_query(
        sql='select * from s3object s where s."c0" = 99',
        path=file_path,
        input_serialization="Parquet",
        input_serialization_params={},
        use_threads=use_threads,
    )
    assert df2.shape == (0, 0)

    file_path = f"{path}test_csv_file.csv"
    wr.s3.to_csv(df, file_path, header=False, index=False)
    df3 = wr.s3.select_query(
        sql='select s."_1" from s3object s limit 2',
        path=file_path,
        input_serialization="CSV",
        input_serialization_params={"FileHeaderInfo": "None", "RecordDelimiter": "\n"},
        use_threads=use_threads,
    )
    assert df3.shape == (2, 1)

    file_path = f"{path}test_json_file.json"
    wr.s3.to_json(df, file_path, orient="records")
    df4 = wr.s3.select_query(
        sql="select count(*) from s3object[*][*]",
        path=file_path,
        input_serialization="JSON",
        input_serialization_params={"Type": "Document"},
        use_threads=use_threads,
    )
    assert df4.shape == (1, 1)
    assert df4._1.sum() == 3


@pytest.mark.parametrize("compression", ["gzip", "bz2"])
def test_compression(path, compression):
    df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]})

    # CSV
    file_path = f"{path}test_csv_file.csv"
    wr.s3.to_csv(df, file_path, index=False, compression=compression)
    df2 = wr.s3.select_query(
        sql="select * from s3object",
        path=file_path,
        input_serialization="CSV",
        input_serialization_params={"FileHeaderInfo": "Use", "RecordDelimiter": "\n"},
        compression="bzip2" if compression == "bz2" else compression,
        use_threads=False,
    )
    assert len(df.index) == len(df2.index)
    assert list(df.columns) == list(df2.columns)
    assert df.shape == df2.shape

    # JSON
    file_path = f"{path}test_json_file.json"
    wr.s3.to_json(df, file_path, orient="records", compression=compression)
    df3 = wr.s3.select_query(
        sql="select * from s3object[*][*]",
        path=file_path,
        input_serialization="JSON",
        input_serialization_params={"Type": "Document"},
        compression="bzip2" if compression == "bz2" else compression,
        use_threads=False,
    )
    assert df.equals(df3)


@pytest.mark.parametrize(
    "s3_additional_kwargs",
    [None, {"ServerSideEncryption": "AES256"}, {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": None}],
)
def test_encryption(path, kms_key_id, s3_additional_kwargs):
    if s3_additional_kwargs is not None and "SSEKMSKeyId" in s3_additional_kwargs:
        s3_additional_kwargs["SSEKMSKeyId"] = kms_key_id

    df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"], "c2": [4.0, 5.0, 6.0]})
    file_path = f"{path}test_parquet_file.snappy.parquet"
    wr.s3.to_parquet(
        df,
        file_path,
        compression="snappy",
        s3_additional_kwargs=s3_additional_kwargs,
    )
    df2 = wr.s3.select_query(
        sql="select * from s3object",
        path=file_path,
        input_serialization="Parquet",
        input_serialization_params={},
        use_threads=False,
    )
    assert df.equals(df2)
