6 | 6 | from pyspark.sql.functions import floor, rand |
7 | 7 | from pyspark.sql.types import TimestampType |
8 | 8 | |
9 | | -from awswrangler.exceptions import MissingBatchDetected |
| 9 | +from awswrangler.exceptions import MissingBatchDetected, UnsupportedFileFormat |
10 | 10 | |
11 | 11 | logger = logging.getLogger(__name__) |
12 | 12 | |
@@ -142,3 +142,70 @@ def write(pandas_dataframe): |
142 | 142 | ) |
143 | 143 | dataframe.unpersist() |
144 | 144 | self._session.s3.delete_objects(path=path) |
| 145 | + |
| 146 | + def create_glue_table(self, |
| 147 | + database, |
| 148 | + path, |
| 149 | + dataframe, |
| 150 | + file_format, |
| 151 | + compression, |
| 152 | + table=None, |
| 153 | + serde=None, |
| 154 | + sep=",", |
| 155 | + partition_by=None, |
| 156 | + load_partitions=True, |
| 157 | + replace_if_exists=True): |
| 158 | + """ |
| 159 | + Create a Glue metadata table pointing to a dataset stored on AWS S3. |
| 160 | + |
| 161 | + :param dataframe: PySpark Dataframe |
| 162 | + :param file_format: File format (E.g. "parquet", "csv") |
| 163 | + :param partition_by: Columns used for partitioning |
| 164 | + :param path: AWS S3 path |
| 165 | + :param compression: Compression (e.g. gzip, snappy, lzo) |
| 166 | + :param sep: Separator token for CSV formats (e.g. ",", ";", "|") |
| 167 | + :param serde: Serializer/Deserializer (e.g. "OpenCSVSerDe", "LazySimpleSerDe") |
| 168 | + :param database: Glue database name |
| 169 | + :param table: Glue table name. If not passed, extracted from the path |
| 170 | + :param load_partitions: Load partitions after the table creation |
| 171 | + :param replace_if_exists: Drop and recreate the table if it already exists |
| 172 | + :return: None |
| 173 | + """ |
| 174 | + file_format = file_format.lower() |
| 175 | + if file_format not in ["parquet", "csv"]: |
| 176 | + raise UnsupportedFileFormat(file_format) |
| 177 | + table = table if table else self._session.glue.parse_table_name(path) |
| 178 | + table = table.lower().replace(".", "_") |
| 179 | + logger.debug(f"table: {table}") |
| 180 | + full_schema = dataframe.dtypes |
| 181 | + if partition_by is None: |
| 182 | + partition_by = [] |
| 183 | + schema = [x for x in full_schema if x[0] not in partition_by] |
| 184 | + partitions_schema_tmp = { |
| 185 | + x[0]: x[1] |
| 186 | + for x in full_schema if x[0] in partition_by |
| 187 | + } |
| 188 | + partitions_schema = [(x, partitions_schema_tmp[x]) |
| 189 | + for x in partition_by] |
| 190 | + logger.debug(f"schema: {schema}") |
| 191 | + logger.debug(f"partitions_schema: {partitions_schema}") |
| 192 | + if replace_if_exists: |
| 193 | + self._session.glue.delete_table_if_exists(database=database, |
| 194 | + table=table) |
| 195 | + extra_args = {} |
| 196 | + if file_format == "csv": |
| 197 | + extra_args["sep"] = sep |
| 198 | + if serde is None: |
| 199 | + serde = "OpenCSVSerDe" |
| 200 | + extra_args["serde"] = serde |
| 201 | + self._session.glue.create_table( |
| 202 | + database=database, |
| 203 | + table=table, |
| 204 | + schema=schema, |
| 205 | + partition_cols_schema=partitions_schema, |
| 206 | + path=path, |
| 207 | + file_format=file_format, |
| 208 | + compression=compression, |
| 209 | + extra_args=extra_args) |
| 210 | + if load_partitions: |
| 211 | + self._session.athena.repair_table(database=database, table=table) |
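
For reference, a minimal usage sketch of the new method (not part of the commit): it assumes an awswrangler Session wired to an existing SparkSession and a dataset already written to S3; the bucket, database, and column names below are illustrative only.

    import awswrangler
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # Assumption: the Session constructor accepts an existing SparkSession
    session = awswrangler.Session(spark_session=spark)

    # DataFrame whose schema (including partition columns) drives the Glue table definition
    dataframe = spark.read.parquet("s3://my-bucket/my-dataset/")

    session.spark.create_glue_table(
        database="my_database",
        path="s3://my-bucket/my-dataset/",
        dataframe=dataframe,
        file_format="parquet",
        compression="snappy",
        partition_by=["year", "month"],
        load_partitions=True,      # runs an Athena table repair to register existing partitions
        replace_if_exists=True,    # drop any pre-existing table with the same name first
    )

For a CSV dataset, file_format="csv" would be used together with sep and serde (defaulting to "OpenCSVSerDe"); any other format raises UnsupportedFileFormat.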