|
"""Amazon S3 Select Module (PRIVATE)."""

import concurrent.futures
import itertools
import json
import logging
import pprint
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import boto3
import pandas as pd

from awswrangler import _utils, exceptions
from awswrangler.s3._describe import size_objects

_logger: logging.Logger = logging.getLogger(__name__)

_RANGE_CHUNK_SIZE: int = int(1024 * 1024)

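# Each scan range is a contiguous (start, end) byte window of at most
# _RANGE_CHUNK_SIZE (1 MiB), so a single object can be queried in parallel chunks.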
def _gen_scan_range(obj_size: int) -> Iterator[Tuple[int, int]]:
    for i in range(0, obj_size, _RANGE_CHUNK_SIZE):
        yield (i, i + min(_RANGE_CHUNK_SIZE, obj_size - i))

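# A record may be split across two "Records" events in the response stream, so the
# trailing partial record of each payload is carried over and prepended to the next.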
def _select_object_content(
    args: Dict[str, Any],
    client_s3: boto3.client,
    scan_range: Optional[Tuple[int, int]] = None,
) -> pd.DataFrame:
    if scan_range:
        response = client_s3.select_object_content(**args, ScanRange={"Start": scan_range[0], "End": scan_range[1]})
    else:
        response = client_s3.select_object_content(**args)

    dfs: List[pd.DataFrame] = []
    partial_record: str = ""
    for event in response["Payload"]:
        if "Records" in event:
            records = event["Records"]["Payload"].decode(encoding="utf-8", errors="ignore").split("\n")
            records[0] = partial_record + records[0]
            # Record end can either be a partial record or a return char
            partial_record = records.pop()
            dfs.append(
                pd.DataFrame(
                    [json.loads(record) for record in records],
                )
            )
    if not dfs:
        return pd.DataFrame()
    return pd.concat(dfs, ignore_index=True)

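# The object is split into scan ranges and one select_object_content call is issued
# per range, either sequentially or on a thread pool sized by ensure_cpu_count.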
def _paginate_stream(
    args: Dict[str, Any], path: str, use_threads: Union[bool, int], boto3_session: Optional[boto3.Session]
) -> pd.DataFrame:
    obj_size: int = size_objects(  # type: ignore
        path=[path],
        use_threads=False,
        boto3_session=boto3_session,
    ).get(path)
    if obj_size is None:
        raise exceptions.InvalidArgumentValue(f"S3 object w/o defined size: {path}")

    dfs: List[pd.DataFrame] = []
    client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session)

    if use_threads is False:
        # Sequential path: query each scan range one at a time
        dfs = list(
            _select_object_content(
                args=args,
                client_s3=client_s3,
                scan_range=scan_range,
            )
            for scan_range in _gen_scan_range(obj_size=obj_size)
        )
    else:
        # Concurrent path: one request per scan range, capped at the resolved CPU count
        cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
        with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor:
            dfs = list(
                executor.map(
                    _select_object_content,
                    itertools.repeat(args),
                    itertools.repeat(client_s3),
                    _gen_scan_range(obj_size=obj_size),
                )
            )
    return pd.concat(dfs, ignore_index=True)

def select_query(
    sql: str,
    path: str,
    input_serialization: str,
    input_serialization_params: Dict[str, Union[bool, str]],
    compression: Optional[str] = None,
    use_threads: Union[bool, int] = False,
    boto3_session: Optional[boto3.Session] = None,
    s3_additional_kwargs: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
    r"""Filter contents of an Amazon S3 object based on SQL statement.

    Note: Scan ranges are only supported for uncompressed CSV (without quoted record delimiters)
    and JSON (in LINES mode) objects. When these conditions are not met, the scan cannot be
    split across threads, leading to lower performance.

    Parameters
    ----------
    sql: str
        SQL statement used to query the object.
    path: str
        S3 path to the object (e.g. s3://bucket/key).
    input_serialization: str
        Format of the S3 object queried.
        Valid values: "CSV", "JSON", or "Parquet". Case sensitive.
    input_serialization_params: Dict[str, Union[bool, str]]
        Dictionary describing the serialization of the S3 object.
    compression: Optional[str]
        Compression type of the S3 object.
        Valid values: None, "gzip", or "bzip2". gzip and bzip2 are only valid for CSV and JSON objects.
    use_threads : Union[bool, int]
        True to enable concurrent requests, False to disable multiple threads.
        If enabled, os.cpu_count() is used as the maximum number of threads.
        If an integer is provided, that number is used instead.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session is used if none is provided.
    s3_additional_kwargs : Optional[Dict[str, Any]]
        Forwarded to botocore requests.
        Valid values: "SSECustomerAlgorithm", "SSECustomerKey", "ExpectedBucketOwner".
        e.g. s3_additional_kwargs={'SSECustomerAlgorithm': 'AES256'}

    Returns
    -------
    pandas.DataFrame
        Pandas DataFrame with results from query.

    Examples
    --------
    Reading a gzip compressed JSON document

    >>> import awswrangler as wr
    >>> df = wr.s3.select_query(
    ...     sql='SELECT * FROM s3object[*][*]',
    ...     path='s3://bucket/key.json.gzip',
    ...     input_serialization='JSON',
    ...     input_serialization_params={
    ...         'Type': 'Document',
    ...     },
    ...     compression="gzip",
    ... )

    Reading an entire CSV object using threads

    >>> import awswrangler as wr
    >>> df = wr.s3.select_query(
    ...     sql='SELECT * FROM s3object',
    ...     path='s3://bucket/key.csv',
    ...     input_serialization='CSV',
    ...     input_serialization_params={
    ...         'FileHeaderInfo': 'Use',
    ...         'RecordDelimiter': '\r\n'
    ...     },
    ...     use_threads=True,
    ... )

    Reading a single column from Parquet object with pushdown filter

    >>> import awswrangler as wr
    >>> df = wr.s3.select_query(
    ...     sql='SELECT s."id" FROM s3object s where s."id" = 1.0',
    ...     path='s3://bucket/key.snappy.parquet',
    ...     input_serialization='Parquet',
    ... )
    """
    if path.endswith("/"):
        raise exceptions.InvalidArgumentValue("<path> argument should be an S3 key, not a prefix.")
    if input_serialization not in ["CSV", "JSON", "Parquet"]:
        raise exceptions.InvalidArgumentValue("<input_serialization> argument must be 'CSV', 'JSON' or 'Parquet'")
    if compression not in [None, "gzip", "bzip2"]:
        raise exceptions.InvalidCompression(f"Invalid {compression} compression, please use None, 'gzip' or 'bzip2'.")
    if compression and (input_serialization not in ["CSV", "JSON"]):
        raise exceptions.InvalidArgumentCombination(
            "'gzip' or 'bzip2' are only valid for input 'CSV' or 'JSON' objects."
        )
    bucket, key = _utils.parse_path(path)

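    # Output is always serialized as JSON lines so each record can be parsed with json.loads.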
    args: Dict[str, Any] = {
        "Bucket": bucket,
        "Key": key,
        "Expression": sql,
        "ExpressionType": "SQL",
        "RequestProgress": {"Enabled": False},
        "InputSerialization": {
            input_serialization: input_serialization_params,
            "CompressionType": compression.upper() if compression else "NONE",
        },
        "OutputSerialization": {
            "JSON": {},
        },
    }
    if s3_additional_kwargs:
        args.update(s3_additional_kwargs)
    _logger.debug("args:\n%s", pprint.pformat(args))

    # Scan ranges are only supported for uncompressed CSV (without quoted delimiters)
    # and JSON (in LINES mode) objects, so a single request is made otherwise.
    if any(
        [
            compression,
            input_serialization_params.get("AllowQuotedRecordDelimiter"),
            input_serialization_params.get("Type") == "Document",
        ]
    ):
        _logger.debug("Scan ranges are not supported for the provided input.")
        client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session)
        return _select_object_content(args=args, client_s3=client_s3)

    return _paginate_stream(args=args, path=path, use_threads=use_threads, boto3_session=boto3_session)