Commit 6172b4c
Add sql server support (#498)
1 parent 89deacf commit 6172b4c

25 files changed: +864 -85
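
At a glance: the commit wires a new sqlserver module into the package (backed by pyodbc), adds a pyarrow-to-SQL-Server type mapping, teaches the generic database helpers to parse SQL Server JDBC URLs, and fixes a handful of latent bugs found along the way (exceptions constructed but never raised, an undecoded KMS plaintext). A minimal usage sketch of the new module, inferred from its mysql/postgresql siblings rather than shown verbatim in this diff (connection name, schema, and table are placeholders):

    import pandas as pd
    import awswrangler as wr

    # Assumes a Glue Catalog connection named "MY_SQL_SERVER_CONNECTION" whose
    # JDBC URL looks like jdbc:sqlserver://host:1433;databaseName=mydb.
    con = wr.sqlserver.connect("MY_SQL_SERVER_CONNECTION")
    wr.sqlserver.to_sql(
        df=pd.DataFrame({"id": [1, 2]}),
        con=con,
        table="my_table",
        schema="dbo",
    )
    df = wr.sqlserver.read_sql_query("SELECT * FROM dbo.my_table", con=con)
    con.close()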

.pylintrc

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 # A comma-separated list of package or module names from where C extensions may
 # be loaded. Extensions are loading into the active Python interpreter and may
 # run arbitrary code.
-extension-pkg-whitelist=pyarrow.lib
+extension-pkg-whitelist=pyarrow.lib,pyodbc
 
 # Specify a score threshold to be exceeded before program exits with error.
 fail-under=10
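
pylint only loads whitelisted C extensions into its interpreter, and the new sqlserver module evidently drives SQL Server through the pyodbc C extension, hence the extra entry. For reference, a bare-bones pyodbc connection looks roughly like this (driver name, host, and credentials are placeholders):

    import pyodbc

    con = pyodbc.connect(
        "DRIVER={ODBC Driver 17 for SQL Server};"
        "SERVER=host.example.com,1433;"
        "DATABASE=mydb;UID=user;PWD=secret"
    )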

README.md

Lines changed: 2 additions & 1 deletion
@@ -105,7 +105,7 @@ FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 3
 - [004 - Parquet Datasets](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/004%20-%20Parquet%20Datasets.ipynb)
 - [005 - Glue Catalog](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/005%20-%20Glue%20Catalog.ipynb)
 - [006 - Amazon Athena](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/006%20-%20Amazon%20Athena.ipynb)
-- [007 - Databases (Redshift, MySQL and PostgreSQL)](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/007%20-%20Redshift%2C%20MySQL%2C%20PostgreSQL.ipynb)
+- [007 - Databases (Redshift, MySQL, PostgreSQL and SQL Server)](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/007%20-%20Redshift%2C%20MySQL%2C%20PostgreSQL%2C%20SQL%20Server.ipynb)
 - [008 - Redshift - Copy & Unload.ipynb](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/008%20-%20Redshift%20-%20Copy%20%26%20Unload.ipynb)
 - [009 - Redshift - Append, Overwrite and Upsert](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/009%20-%20Redshift%20-%20Append%2C%20Overwrite%2C%20Upsert.ipynb)
 - [010 - Parquet Crawler](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/010%20-%20Parquet%20Crawler.ipynb)
@@ -134,6 +134,7 @@ FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 3
 - [Amazon Redshift](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-redshift)
 - [PostgreSQL](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#postgresql)
 - [MySQL](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#mysql)
+- [SQL Server](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#sqlserver)
 - [DynamoDB](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#dynamodb)
 - [Amazon Timestream](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-timestream)
 - [Amazon EMR](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-emr)

awswrangler/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
     redshift,
     s3,
     secretsmanager,
+    sqlserver,
     sts,
     timestream,
 )
@@ -42,6 +43,7 @@
     "mysql",
     "postgresql",
     "secretsmanager",
+    "sqlserver",
     "config",
     "timestream",
     "__description__",

awswrangler/_config.py

Lines changed: 3 additions & 1 deletion
@@ -154,7 +154,9 @@ def _apply_type(name: str, value: Any, dtype: Type[Union[str, bool, int]], nulla
     if _Config._is_null(value=value):
         if nullable is True:
             return None
-        exceptions.InvalidArgumentValue(f"{name} configuration does not accept a null value. Please pass {dtype}.")
+        raise exceptions.InvalidArgumentValue(
+            f"{name} configuration does not accept a null value. Please pass {dtype}."
+        )
     try:
         return dtype(value) if isinstance(value, dtype) is False else value
     except ValueError as ex:
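
The removed line built an InvalidArgumentValue and immediately discarded it: constructing an exception has no control-flow effect unless it is raised, so a null value for a non-nullable setting silently fell through to the dtype(value) conversion below. The same missing raise is fixed in the connect() functions of mysql.py, postgresql.py, and redshift.py further down. A stripped-down illustration of the bug pattern (names are illustrative, not the library's):

    class InvalidArgumentValue(Exception):
        pass

    def apply_type_before(value):
        if value is None:
            # Bug: instantiates the exception and throws it away; execution continues.
            InvalidArgumentValue("does not accept a null value")
        return int(value)  # fails later with an unrelated TypeError

    def apply_type_after(value):
        if value is None:
            raise InvalidArgumentValue("does not accept a null value")  # aborts here
        return int(value)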

awswrangler/_data_types.py

Lines changed: 35 additions & 0 deletions
@@ -166,6 +166,41 @@ def pyarrow2postgresql( # pylint: disable=too-many-branches,too-many-return-sta
     raise exceptions.UnsupportedType(f"Unsupported PostgreSQL type: {dtype}")
 
 
+def pyarrow2sqlserver(  # pylint: disable=too-many-branches,too-many-return-statements
+    dtype: pa.DataType, string_type: str
+) -> str:
+    """Pyarrow to Microsoft SQL Server data types conversion."""
+    if pa.types.is_int8(dtype):
+        return "SMALLINT"
+    if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
+        return "SMALLINT"
+    if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
+        return "INT"
+    if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
+        return "BIGINT"
+    if pa.types.is_uint64(dtype):
+        raise exceptions.UnsupportedType("There is no support for uint64, please consider int64 or uint32.")
+    if pa.types.is_float32(dtype):
+        return "FLOAT(24)"
+    if pa.types.is_float64(dtype):
+        return "FLOAT"
+    if pa.types.is_boolean(dtype):
+        return "BIT"
+    if pa.types.is_string(dtype):
+        return string_type
+    if pa.types.is_timestamp(dtype):
+        return "DATETIME2"
+    if pa.types.is_date(dtype):
+        return "DATE"
+    if pa.types.is_decimal(dtype):
+        return f"DECIMAL({dtype.precision},{dtype.scale})"
+    if pa.types.is_dictionary(dtype):
+        return pyarrow2sqlserver(dtype=dtype.value_type, string_type=string_type)
+    if pa.types.is_binary(dtype):
+        return "VARBINARY"
+    raise exceptions.UnsupportedType(f"Unsupported SQL Server type: {dtype}")
+
+
 def pyarrow2timestream(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branches,too-many-return-statements
     """Pyarrow to Amazon Timestream data types conversion."""
     if pa.types.is_int8(dtype):
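
A quick sanity check of the new mapping (pyarrow2sqlserver is a private helper, imported here only for demonstration; string_type is chosen by the caller, and "VARCHAR(MAX)" below is just an example value):

    import pyarrow as pa

    from awswrangler._data_types import pyarrow2sqlserver

    pyarrow2sqlserver(pa.int32(), string_type="VARCHAR(MAX)")            # "INT"
    pyarrow2sqlserver(pa.bool_(), string_type="VARCHAR(MAX)")            # "BIT"
    pyarrow2sqlserver(pa.timestamp("ns"), string_type="VARCHAR(MAX)")    # "DATETIME2"
    pyarrow2sqlserver(pa.decimal128(10, 2), string_type="VARCHAR(MAX)")  # "DECIMAL(10,2)"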

awswrangler/_databases.py

Lines changed: 59 additions & 26 deletions
@@ -36,7 +36,11 @@ def _get_connection_attributes_from_catalog(
     details: Dict[str, Any] = get_connection(name=connection, catalog_id=catalog_id, boto3_session=boto3_session)[
         "ConnectionProperties"
     ]
-    port, database = details["JDBC_CONNECTION_URL"].split(":")[3].split("/")
+    if ";databaseName=" in details["JDBC_CONNECTION_URL"]:
+        database_sep = ";databaseName="
+    else:
+        database_sep = "/"
+    port, database = details["JDBC_CONNECTION_URL"].split(":")[3].split(database_sep)
     return ConnectionAttributes(
         kind=details["JDBC_CONNECTION_URL"].split(":")[1].lower(),
         user=details["USERNAME"],
@@ -136,19 +140,48 @@ def _records2df(
     return df
 
 
-def _iterate_cursor(
-    cursor: Any,
+def _get_cols_names(cursor_description: Any) -> List[str]:
+    cols_names = [col[0].decode("utf-8") if isinstance(col[0], bytes) else col[0] for col in cursor_description]
+    _logger.debug("cols_names: %s", cols_names)
+
+    return cols_names
+
+
+def _iterate_results(
+    con: Any,
+    cursor_args: List[Any],
     chunksize: int,
-    cols_names: List[str],
-    index: Optional[Union[str, List[str]]],
+    index_col: Optional[Union[str, List[str]]],
     safe: bool,
     dtype: Optional[Dict[str, pa.DataType]],
 ) -> Iterator[pd.DataFrame]:
-    while True:
-        records = cursor.fetchmany(chunksize)
-        if not records:
-            break
-        yield _records2df(records=records, cols_names=cols_names, index=index, safe=safe, dtype=dtype)
+    with con.cursor() as cursor:
+        cursor.execute(*cursor_args)
+        cols_names = _get_cols_names(cursor.description)
+        while True:
+            records = cursor.fetchmany(chunksize)
+            if not records:
+                break
+            yield _records2df(records=records, cols_names=cols_names, index=index_col, safe=safe, dtype=dtype)
+
+
+def _fetch_all_results(
+    con: Any,
+    cursor_args: List[Any],
+    index_col: Optional[Union[str, List[str]]] = None,
+    dtype: Optional[Dict[str, pa.DataType]] = None,
+    safe: bool = True,
+) -> pd.DataFrame:
+    with con.cursor() as cursor:
+        cursor.execute(*cursor_args)
+        cols_names = _get_cols_names(cursor.description)
+        return _records2df(
+            records=cast(List[Tuple[Any]], cursor.fetchall()),
+            cols_names=cols_names,
+            index=index_col,
+            dtype=dtype,
+            safe=safe,
+        )
 
 
 def read_sql_query(
@@ -163,23 +196,23 @@ def read_sql_query(
     """Read SQL Query (generic)."""
     args = _convert_params(sql, params)
     try:
-        with con.cursor() as cursor:
-            cursor.execute(*args)
-            cols_names: List[str] = [
-                col[0].decode("utf-8") if isinstance(col[0], bytes) else col[0] for col in cursor.description
-            ]
-            _logger.debug("cols_names: %s", cols_names)
-            if chunksize is None:
-                return _records2df(
-                    records=cast(List[Tuple[Any]], cursor.fetchall()),
-                    cols_names=cols_names,
-                    index=index_col,
-                    dtype=dtype,
-                    safe=safe,
-                )
-            return _iterate_cursor(
-                cursor=cursor, chunksize=chunksize, cols_names=cols_names, index=index_col, dtype=dtype, safe=safe
+        if chunksize is None:
+            return _fetch_all_results(
+                con=con,
+                cursor_args=args,
+                index_col=index_col,
+                dtype=dtype,
+                safe=safe,
             )
+
+        return _iterate_results(
+            con=con,
+            cursor_args=args,
+            chunksize=chunksize,
+            index_col=index_col,
+            dtype=dtype,
+            safe=safe,
+        )
     except Exception as ex:
         con.rollback()
         _logger.error(ex)
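
Two separate improvements here. First, SQL Server JDBC URLs carry the database name as a ;databaseName= property rather than a /path segment, so the parser now picks its separator accordingly. A toy illustration with hypothetical URLs:

    pg_url = "jdbc:postgresql://host.example.com:5432/mydb"
    ms_url = "jdbc:sqlserver://host.example.com:1433;databaseName=mydb"

    for url in (pg_url, ms_url):
        sep = ";databaseName=" if ";databaseName=" in url else "/"
        port, database = url.split(":")[3].split(sep)
        print(port, database)  # -> 5432 mydb, then 1433 mydb

Second, the chunked-read path is restructured: the old code returned the _iterate_cursor generator from inside a with con.cursor() block, so the cursor was closed as soon as read_sql_query returned, before the caller had consumed a single chunk. _iterate_results is itself a generator that opens the cursor, executes the query, and yields DataFrames, keeping the cursor alive for the whole iteration; _fetch_all_results covers the chunksize=None case.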

awswrangler/catalog/_get.py

Lines changed: 1 addition & 1 deletion
@@ -524,7 +524,7 @@ def get_connection(
     client_kms = _utils.client(service_name="kms", session=boto3_session)
     pwd = client_kms.decrypt(CiphertextBlob=base64.b64decode(res["ConnectionProperties"]["ENCRYPTED_PASSWORD"]))[
         "Plaintext"
-    ]
+    ].decode("utf-8")
     res["ConnectionProperties"]["PASSWORD"] = pwd
     return cast(Dict[str, Any], res)
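
KMS decrypt() returns the plaintext as raw bytes, so without the .decode("utf-8") the Glue connection password was handed to the database drivers as bytes rather than str. A minimal sketch of the call shape (boto3 region/credentials assumed configured; encrypted_password is a hypothetical base64 string):

    import base64

    import boto3

    kms = boto3.client("kms")
    blob = base64.b64decode(encrypted_password)
    plaintext = kms.decrypt(CiphertextBlob=blob)["Plaintext"]  # bytes
    password = plaintext.decode("utf-8")  # str, as the fix above now does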

awswrangler/mysql.py

Lines changed: 4 additions & 7 deletions
@@ -127,7 +127,7 @@ def connect(
         connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
     )
     if attrs.kind != "mysql":
-        exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)")
+        raise exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)")
     return pymysql.connect(
         user=attrs.user,
         database=attrs.database,
@@ -156,8 +156,7 @@
     sql : str
         SQL query.
     con : pymysql.connections.Connection
-        Use pymysql.connect() to use "
-        "credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
+        Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
     index_col : Union[str, List[str]], optional
         Column(s) to set as index(MultiIndex).
     params : Union[List, Tuple, Dict], optional
@@ -214,8 +213,7 @@
     table : str
         Table name.
     con : pymysql.connections.Connection
-        Use pymysql.connect() to use "
-        "credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
+        Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
     schema : str, optional
         Name of SQL schema in database to query.
         Uses default schema if None.
@@ -276,8 +274,7 @@
     df : pandas.DataFrame
         Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
     con : pymysql.connections.Connection
-        Use pymysql.connect() to use "
-        "credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
+        Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
     table : str
         Table name
     schema : str

awswrangler/postgresql.py

Lines changed: 7 additions & 8 deletions
@@ -131,7 +131,9 @@ def connect(
         connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
     )
     if attrs.kind != "postgresql":
-        exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a postgresql connection.)")
+        raise exceptions.InvalidDatabaseType(
+            f"Invalid connection type ({attrs.kind}. It must be a postgresql connection.)"
+        )
     return pg8000.connect(
         user=attrs.user,
         database=attrs.database,
@@ -160,8 +162,7 @@
     sql : str
         SQL query.
     con : pg8000.Connection
-        Use pg8000.connect() to use "
-        "credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
+        Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
     index_col : Union[str, List[str]], optional
         Column(s) to set as index(MultiIndex).
     params : Union[List, Tuple, Dict], optional
@@ -218,8 +219,7 @@
     table : str
         Table name.
     con : pg8000.Connection
-        Use pg8000.connect() to use "
-        "credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
+        Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
     schema : str, optional
         Name of SQL schema in database to query (if database flavor supports this).
         Uses default schema if None (default).
@@ -280,8 +280,7 @@
     df : pandas.DataFrame
         Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
     con : pg8000.Connection
-        Use pg8000.connect() to use "
-        "credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
+        Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
     table : str
         Table name
     schema : str
@@ -310,7 +309,7 @@
     >>> import awswrangler as wr
     >>> con = wr.postgresql.connect("MY_GLUE_CONNECTION")
     >>> wr.postgresql.to_sql(
-    ...     df=df
+    ...     df=df,
     ...     table="my_table",
     ...     schema="public",
     ...     con=con

awswrangler/redshift.py

Lines changed: 3 additions & 1 deletion
@@ -386,7 +386,9 @@ def connect(
         connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
     )
     if attrs.kind != "redshift":
-        exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a redshift connection.)")
+        raise exceptions.InvalidDatabaseType(
+            f"Invalid connection type ({attrs.kind}. It must be a redshift connection.)"
+        )
     return redshift_connector.connect(
         user=attrs.user,
         database=attrs.database,
