Skip to content

Commit 48c4dfc

Browse files
authored
Merge pull request #23 from awslabs/redshift-timeout
Add mechanism to make Redshift handle bad connections
2 parents 51ee97c + a0b1dd0 commit 48c4dfc

File tree

2 files changed

+99
-10
lines changed

2 files changed

+99
-10
lines changed

awswrangler/redshift.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,75 @@ def __init__(self, session):
3333
self._session = session
3434

3535
@staticmethod
36-
def generate_connection(database, host, port, user, password):
37-
conn = pg8000.connect(
38-
database=database,
39-
host=host,
40-
port=int(port),
41-
user=user,
42-
password=password,
43-
ssl=False,
44-
)
36+
def _validate_connection(database,
37+
host,
38+
port,
39+
user,
40+
password,
41+
tcp_keepalive=True,
42+
application_name="aws-data-wrangler-validation",
43+
validation_timeout=5):
44+
try:
45+
conn = pg8000.connect(database=database,
46+
host=host,
47+
port=int(port),
48+
user=user,
49+
password=password,
50+
ssl=True,
51+
application_name=application_name,
52+
tcp_keepalive=tcp_keepalive,
53+
timeout=validation_timeout)
54+
conn.close()
55+
except pg8000.core.InterfaceError as e:
56+
raise e
57+
58+
@staticmethod
59+
def generate_connection(database,
60+
host,
61+
port,
62+
user,
63+
password,
64+
tcp_keepalive=True,
65+
application_name="aws-data-wrangler",
66+
connection_timeout=1_200_000,
67+
statement_timeout=1_200_000,
68+
validation_timeout=5):
69+
"""
70+
Generates a valid connection object to be passed to the load_table method
71+
72+
:param database: The name of the database instance to connect with.
73+
:param host: The hostname of the Redshift server to connect with.
74+
:param port: The TCP/IP port of the Redshift server instance.
75+
:param user: The username to connect to the Redshift server with.
76+
:param password: The user password to connect to the server with.
77+
:param tcp_keepalive: If True then use TCP keepalive
78+
:param application_name: Application name
79+
:param connection_timeout: Connection Timeout
80+
:param statement_timeout: Redshift statements timeout
81+
:param validation_timeout: Timeout to try to validate the connection
82+
:return: pg8000 connection
83+
"""
84+
Redshift._validate_connection(database=database,
85+
host=host,
86+
port=port,
87+
user=user,
88+
password=password,
89+
tcp_keepalive=tcp_keepalive,
90+
application_name=application_name,
91+
validation_timeout=validation_timeout)
92+
if isinstance(type(port), str) or isinstance(type(port), float):
93+
port = int(port)
94+
conn = pg8000.connect(database=database,
95+
host=host,
96+
port=int(port),
97+
user=user,
98+
password=password,
99+
ssl=True,
100+
application_name=application_name,
101+
tcp_keepalive=tcp_keepalive,
102+
timeout=connection_timeout)
45103
cursor = conn.cursor()
46-
cursor.execute("set statement_timeout = 1200000")
104+
cursor.execute(f"set statement_timeout = {statement_timeout}")
47105
conn.commit()
48106
cursor.close()
49107
return conn

testing/test_awswrangler/test_redshift.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import boto3
66
import pandas
77
from pyspark.sql import SparkSession
8+
import pg8000
89

910
from awswrangler import Session, Redshift
1011
from awswrangler.exceptions import InvalidRedshiftDiststyle, InvalidRedshiftDistkey, InvalidRedshiftSortstyle, InvalidRedshiftSortkey
@@ -267,3 +268,33 @@ def test_write_load_manifest(session, bucket):
267268
assert manifest.get("entries")[0].get("url") == object_path
268269
assert manifest.get("entries")[0].get("mandatory")
269270
assert manifest.get("entries")[0].get("meta").get("content_length") == 2247
271+
272+
273+
def test_connection_timeout(redshift_parameters):
274+
with pytest.raises(pg8000.core.InterfaceError):
275+
Redshift.generate_connection(
276+
database="test",
277+
host=redshift_parameters.get("RedshiftAddress"),
278+
port=12345,
279+
user="test",
280+
password=redshift_parameters.get("RedshiftPassword"),
281+
)
282+
283+
284+
def test_connection_with_different_port_types(redshift_parameters):
285+
conn = Redshift.generate_connection(
286+
database="test",
287+
host=redshift_parameters.get("RedshiftAddress"),
288+
port=str(redshift_parameters.get("RedshiftPort")),
289+
user="test",
290+
password=redshift_parameters.get("RedshiftPassword"),
291+
)
292+
conn.close()
293+
conn = Redshift.generate_connection(
294+
database="test",
295+
host=redshift_parameters.get("RedshiftAddress"),
296+
port=float(redshift_parameters.get("RedshiftPort")),
297+
user="test",
298+
password=redshift_parameters.get("RedshiftPassword"),
299+
)
300+
conn.close()

0 commit comments

Comments
 (0)