Skip to content

Commit a4e00dc

Browse files
committed
Add arguments to set VARCHAR lengths for Redshift and Aurora tables
1 parent 99f5e29 commit a4e00dc

File tree

6 files changed

+167
-123
lines changed

6 files changed

+167
-123
lines changed

awswrangler/aurora.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ def load_table(dataframe: pd.DataFrame,
141141
mode: str = "append",
142142
preserve_index: bool = False,
143143
engine: str = "mysql",
144-
region: str = "us-east-1"):
144+
region: str = "us-east-1",
145+
varchar_default_length: int = 256,
146+
varchar_lengths: Optional[Dict[str, int]] = None) -> None:
145147
"""
146148
Load text/CSV files into an Aurora table using a manifest file.
147149
Creates the table if necessary.
@@ -158,6 +160,8 @@ def load_table(dataframe: pd.DataFrame,
158160
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
159161
:param engine: "mysql" or "postgres"
160162
:param region: AWS S3 bucket region (Required only for postgres engine)
163+
:param varchar_default_length: The size that will be set for all VARCHAR columns not specified with varchar_lengths
164+
:param varchar_lengths: Dict of VARCHAR lengths by column name. (e.g. {"col1": 10, "col5": 200})
161165
:return: None
162166
"""
163167
if "postgres" in engine.lower():
@@ -170,7 +174,9 @@ def load_table(dataframe: pd.DataFrame,
170174
mode=mode,
171175
preserve_index=preserve_index,
172176
region=region,
173-
columns=columns)
177+
columns=columns,
178+
varchar_default_length=varchar_default_length,
179+
varchar_lengths=varchar_lengths)
174180
elif "mysql" in engine.lower():
175181
Aurora.load_table_mysql(dataframe=dataframe,
176182
dataframe_type=dataframe_type,
@@ -181,7 +187,9 @@ def load_table(dataframe: pd.DataFrame,
181187
mode=mode,
182188
preserve_index=preserve_index,
183189
num_files=num_files,
184-
columns=columns)
190+
columns=columns,
191+
varchar_default_length=varchar_default_length,
192+
varchar_lengths=varchar_lengths)
185193
else:
186194
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
187195

@@ -195,7 +203,9 @@ def load_table_postgres(dataframe: pd.DataFrame,
195203
mode: str = "append",
196204
preserve_index: bool = False,
197205
region: str = "us-east-1",
198-
columns: Optional[List[str]] = None):
206+
columns: Optional[List[str]] = None,
207+
varchar_default_length: int = 256,
208+
varchar_lengths: Optional[Dict[str, int]] = None):
199209
"""
200210
Load text/CSV files into an Aurora table using a manifest file.
201211
Creates the table if necessary.
@@ -210,6 +220,8 @@ def load_table_postgres(dataframe: pd.DataFrame,
210220
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
211221
:param region: AWS S3 bucket region (Required only for postgres engine)
212222
:param columns: List of columns to load
223+
:param varchar_default_length: The size that will be set for all VARCHAR columns not specified with varchar_lengths
224+
:param varchar_lengths: Dict of VARCHAR lengths by column name. (e.g. {"col1": 10, "col5": 200})
213225
:return: None
214226
"""
215227
with connection.cursor() as cursor:
@@ -221,7 +233,9 @@ def load_table_postgres(dataframe: pd.DataFrame,
221233
table_name=table_name,
222234
preserve_index=preserve_index,
223235
engine="postgres",
224-
columns=columns)
236+
columns=columns,
237+
varchar_default_length=varchar_default_length,
238+
varchar_lengths=varchar_lengths)
225239
connection.commit()
226240
logger.debug("CREATE TABLE committed.")
227241
for path in load_paths:
@@ -266,7 +280,9 @@ def load_table_mysql(dataframe: pd.DataFrame,
266280
num_files: int,
267281
mode: str = "append",
268282
preserve_index: bool = False,
269-
columns: Optional[List[str]] = None):
283+
columns: Optional[List[str]] = None,
284+
varchar_default_length: int = 256,
285+
varchar_lengths: Optional[Dict[str, int]] = None):
270286
"""
271287
Load text/CSV files into an Aurora table using a manifest file.
272288
Creates the table if necessary.
@@ -281,6 +297,8 @@ def load_table_mysql(dataframe: pd.DataFrame,
281297
:param mode: append or overwrite
282298
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
283299
:param columns: List of columns to load
300+
:param varchar_default_length: The size that will be set for all VARCHAR columns not specified with varchar_lengths
301+
:param varchar_lengths: Dict of VARCHAR lengths by column name. (e.g. {"col1": 10, "col5": 200})
284302
:return: None
285303
"""
286304
with connection.cursor() as cursor:
@@ -292,7 +310,9 @@ def load_table_mysql(dataframe: pd.DataFrame,
292310
table_name=table_name,
293311
preserve_index=preserve_index,
294312
engine="mysql",
295-
columns=columns)
313+
columns=columns,
314+
varchar_default_length=varchar_default_length,
315+
varchar_lengths=varchar_lengths)
296316
sql = Aurora._get_load_sql(path=manifest_path,
297317
schema_name=schema_name,
298318
table_name=table_name,
@@ -368,7 +388,9 @@ def _create_table(cursor,
368388
table_name,
369389
preserve_index=False,
370390
engine: str = "mysql",
371-
columns: Optional[List[str]] = None):
391+
columns: Optional[List[str]] = None,
392+
varchar_default_length: int = 256,
393+
varchar_lengths: Optional[Dict[str, int]] = None) -> None:
372394
"""
373395
Creates Aurora table.
374396
@@ -380,6 +402,8 @@ def _create_table(cursor,
380402
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
381403
:param engine: "mysql" or "postgres"
382404
:param columns: List of columns to load
405+
:param varchar_default_length: The size that will be set for all VARCHAR columns not specified with varchar_lengths
406+
:param varchar_lengths: Dict of VARCHAR lengths by column name. (e.g. {"col1": 10, "col5": 200})
383407
:return: None
384408
"""
385409
sql: str = f"-- AWS DATA WRANGLER\n" \
@@ -397,7 +421,9 @@ def _create_table(cursor,
397421
dataframe_type=dataframe_type,
398422
preserve_index=preserve_index,
399423
engine=engine,
400-
columns=columns)
424+
columns=columns,
425+
varchar_default_length=varchar_default_length,
426+
varchar_lengths=varchar_lengths)
401427
cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
402428
sql = f"-- AWS DATA WRANGLER\n" f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n" f"{cols_str})"
403429
logger.debug(f"Create table query:\n{sql}")
@@ -408,7 +434,10 @@ def _get_schema(dataframe,
408434
dataframe_type: str,
409435
preserve_index: bool,
410436
engine: str = "mysql",
411-
columns: Optional[List[str]] = None) -> List[Tuple[str, str]]:
437+
columns: Optional[List[str]] = None,
438+
varchar_default_length: int = 256,
439+
varchar_lengths: Optional[Dict[str, int]] = None) -> List[Tuple[str, str]]:
440+
varchar_lengths = {} if varchar_lengths is None else varchar_lengths
412441
schema_built: List[Tuple[str, str]] = []
413442
if "postgres" in engine.lower():
414443
convert_func = data_types.pyarrow2postgres
@@ -421,7 +450,8 @@ def _get_schema(dataframe,
421450
dataframe=dataframe, preserve_index=preserve_index, indexes_position="right")
422451
for name, dtype in pyarrow_schema:
423452
if columns is None or name in columns:
424-
aurora_type: str = convert_func(dtype)
453+
varchar_len = varchar_lengths.get(name, varchar_default_length)
454+
aurora_type: str = convert_func(dtype=dtype, varchar_length=varchar_len)
425455
schema_built.append((name, aurora_type))
426456
else:
427457
raise InvalidDataframeType(f"{dataframe_type} is not a valid DataFrame type. Please use 'pandas'!")

awswrangler/data_types.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def athena2python(dtype: str) -> Optional[type]:
8181
raise UnsupportedType(f"Unsupported Athena type: {dtype}")
8282

8383

84-
def athena2redshift(dtype: str) -> str:
84+
def athena2redshift(dtype: str, varchar_length: int = 256) -> str:
8585
dtype = dtype.lower()
8686
if dtype == "smallint":
8787
return "SMALLINT"
@@ -96,7 +96,7 @@ def athena2redshift(dtype: str) -> str:
9696
elif dtype in ("boolean", "bool"):
9797
return "BOOL"
9898
elif dtype in ("string", "char", "varchar", "array", "row", "map"):
99-
return "VARCHAR(256)"
99+
return f"VARCHAR({varchar_length})"
100100
elif dtype == "timestamp":
101101
return "TIMESTAMP"
102102
elif dtype == "date":
@@ -125,7 +125,7 @@ def pandas2athena(dtype: str) -> str:
125125
raise UnsupportedType(f"Unsupported Pandas type: {dtype}")
126126

127127

128-
def pandas2redshift(dtype: str) -> str:
128+
def pandas2redshift(dtype: str, varchar_length: int = 256) -> str:
129129
dtype = dtype.lower()
130130
if dtype == "int32":
131131
return "INTEGER"
@@ -138,7 +138,7 @@ def pandas2redshift(dtype: str) -> str:
138138
elif dtype == "bool":
139139
return "BOOLEAN"
140140
elif dtype == "object" and isinstance(dtype, str):
141-
return "VARCHAR(256)"
141+
return f"VARCHAR({varchar_length})"
142142
elif dtype[:10] == "datetime64":
143143
return "TIMESTAMP"
144144
else:
@@ -177,7 +177,7 @@ def pyarrow2athena(dtype: pa.types) -> str:
177177
raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
178178

179179

180-
def pyarrow2redshift(dtype: pa.types) -> str:
180+
def pyarrow2redshift(dtype: pa.types, varchar_length: int = 256) -> str:
181181
dtype_str = str(dtype).lower()
182182
if dtype_str == "int16":
183183
return "SMALLINT"
@@ -192,7 +192,7 @@ def pyarrow2redshift(dtype: pa.types) -> str:
192192
elif dtype_str == "bool":
193193
return "BOOLEAN"
194194
elif dtype_str == "string":
195-
return "VARCHAR(256)"
195+
return f"VARCHAR({varchar_length})"
196196
elif dtype_str.startswith("timestamp"):
197197
return "TIMESTAMP"
198198
elif dtype_str.startswith("date"):
@@ -203,7 +203,7 @@ def pyarrow2redshift(dtype: pa.types) -> str:
203203
raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
204204

205205

206-
def pyarrow2postgres(dtype: pa.types) -> str:
206+
def pyarrow2postgres(dtype: pa.types, varchar_length: int = 256) -> str:
207207
dtype_str = str(dtype).lower()
208208
if dtype_str == "int16":
209209
return "SMALLINT"
@@ -218,7 +218,7 @@ def pyarrow2postgres(dtype: pa.types) -> str:
218218
elif dtype_str == "bool":
219219
return "BOOLEAN"
220220
elif dtype_str == "string":
221-
return "VARCHAR(256)"
221+
return f"VARCHAR({varchar_length})"
222222
elif dtype_str.startswith("timestamp"):
223223
return "TIMESTAMP"
224224
elif dtype_str.startswith("date"):
@@ -229,7 +229,7 @@ def pyarrow2postgres(dtype: pa.types) -> str:
229229
raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
230230

231231

232-
def pyarrow2mysql(dtype: pa.types) -> str:
232+
def pyarrow2mysql(dtype: pa.types, varchar_length: int = 256) -> str:
233233
dtype_str = str(dtype).lower()
234234
if dtype_str == "int16":
235235
return "SMALLINT"
@@ -244,7 +244,7 @@ def pyarrow2mysql(dtype: pa.types) -> str:
244244
elif dtype_str == "bool":
245245
return "BOOLEAN"
246246
elif dtype_str == "string":
247-
return "VARCHAR(256)"
247+
return f"VARCHAR({varchar_length})"
248248
elif dtype_str.startswith("timestamp"):
249249
return "TIMESTAMP"
250250
elif dtype_str.startswith("date"):
@@ -321,7 +321,7 @@ def redshift2pyarrow(dtype: str) -> str:
321321
raise UnsupportedType(f"Unsupported Redshift type: {dtype_str}")
322322

323323

324-
def spark2redshift(dtype: str) -> str:
324+
def spark2redshift(dtype: str, varchar_length: int = 256) -> str:
325325
dtype = dtype.lower()
326326
if dtype == "smallint":
327327
return "SMALLINT"
@@ -340,7 +340,7 @@ def spark2redshift(dtype: str) -> str:
340340
elif dtype == "date":
341341
return "DATE"
342342
elif dtype == "string":
343-
return "VARCHAR(256)"
343+
return f"VARCHAR({varchar_length})"
344344
elif dtype.startswith("decimal"):
345345
return dtype.replace(" ", "").upper()
346346
else:

0 commit comments

Comments
 (0)