|
1 | | -from typing import Dict, List, Union, Optional |
| 1 | +from typing import Dict, List, Union, Optional, Any |
2 | 2 | import json |
3 | 3 | import logging |
4 | 4 |
|
@@ -346,3 +346,49 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c |
346 | 346 | else: |
347 | 347 | raise InvalidDataframeType(dataframe_type) |
348 | 348 | return schema_built |
| 349 | + |
| 350 | + @staticmethod |
| 351 | + def to_parquet(sql: str, |
| 352 | + path: str, |
| 353 | + iam_role: str, |
| 354 | + redshift_conn: Any, |
| 355 | + partition_cols: Optional[List] = None) -> List[str]: |
| 356 | + """ |
| 357 | + Write a query result as parquet files on S3 |
| 358 | +
|
| 359 | + :param sql: SQL Query |
| 360 | + :param path: AWS S3 path to write the data (e.g. s3://...) |
| 361 | + :param iam_role: AWS IAM role with the related permissions |
| 362 | + :param redshift_conn: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection()) |
| 363 | + :param partition_cols: Specifies the partition keys for the unload operation. |
| 364 | + """ |
| 365 | + sql = sql.replace("'", "\'").replace(";", "") # escaping single quote |
| 366 | + path = path if path[-1] == "/" else path + "/" |
| 367 | + cursor: Any = redshift_conn.cursor() |
| 368 | + partition_str: str = "" |
| 369 | + if partition_cols is not None: |
| 370 | + partition_str = f"PARTITION BY ({','.join([x for x in partition_cols])})\n" |
| 371 | + query: str = f"-- AWS DATA WRANGLER\n" \ |
| 372 | + f"UNLOAD ('{sql}')\n" \ |
| 373 | + f"TO '{path}'\n" \ |
| 374 | + f"IAM_ROLE '{iam_role}'\n" \ |
| 375 | + f"ALLOWOVERWRITE\n" \ |
| 376 | + f"PARALLEL ON\n" \ |
| 377 | + f"ENCRYPTED \n" \ |
| 378 | + f"{partition_str}" \ |
| 379 | + f"FORMAT PARQUET;" |
| 380 | + logger.debug(f"query:\n{query}") |
| 381 | + cursor.execute(query) |
| 382 | + query = "-- AWS DATA WRANGLER\nSELECT pg_last_query_id() AS query_id" |
| 383 | + logger.debug(f"query:\n{query}") |
| 384 | + cursor.execute(query) |
| 385 | + query_id = cursor.fetchall()[0][0] |
| 386 | + query = f"-- AWS DATA WRANGLER\n" \ |
| 387 | + f"SELECT path FROM STL_UNLOAD_LOG WHERE query={query_id};" |
| 388 | + logger.debug(f"query:\n{query}") |
| 389 | + cursor.execute(query) |
| 390 | + paths: List[str] = [row[0].replace(" ", "") for row in cursor.fetchall()] |
| 391 | + logger.debug(f"paths: {paths}") |
| 392 | + redshift_conn.commit() |
| 393 | + cursor.close() |
| 394 | + return paths |
0 commit comments