Skip to content

Commit d241cca

Browse files
authored
fix: Setting TCP keepalive values (#116)
1 parent 2afec24 commit d241cca

File tree

3 files changed

+55
-1
lines changed

3 files changed

+55
-1
lines changed

src/firebolt/async_db/connection.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

33
from json import JSONDecodeError
4+
from socket import IPPROTO_TCP, SO_KEEPALIVE, SOL_SOCKET, TCP_KEEPIDLE
45
from types import TracebackType
56
from typing import Callable, List, Optional, Type
67

7-
from httpx import HTTPStatusError, RequestError, Timeout
8+
from httpcore.backends.auto import AutoBackend
9+
from httpcore.backends.base import AsyncNetworkStream
10+
from httpx import AsyncHTTPTransport, HTTPStatusError, RequestError, Timeout
811

912
from firebolt.async_db.cursor import BaseCursor, Cursor
1013
from firebolt.client import DEFAULT_API_URL, AsyncClient
@@ -17,6 +20,8 @@
1720
from firebolt.common.util import fix_url_schema
1821

1922
# Timeout, in seconds, passed to httpx.Timeout for the client (the read timeout is overridden to None there)
DEFAULT_TIMEOUT_SECONDS: int = 5
23+
KEEPALIVE_FLAG: int = 1  # value set for SO_KEEPALIVE; non-zero turns TCP keepalive on
24+
KEEPIDLE_RATE: int = 60  # seconds a connection may stay idle before keepalive probes start (TCP_KEEPIDLE)
2025

2126

2227
async def _resolve_engine_url(
@@ -131,6 +136,36 @@ async def connect_inner(
131136
return connect_inner
132137

133138

139+
class OverriddenHttpBackend(AutoBackend):
    """
    Network backend that enables TCP keep-alive on every connection it opens.

    Short-term workaround for the TCP keep-alive issue described here:
    https://docs.aws.amazon.com/elasticloadbalancing/latest/network/network-load-balancers.html#connection-idle-timeout
    httpx creates a connection right before executing a request, so the
    backend itself is overridden in order to set the socket KEEPALIVE and
    KEEPIDLE options as soon as the TCP stream exists.
    """

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: Optional[float] = None,
        local_address: Optional[str] = None,
    ) -> AsyncNetworkStream:
        """
        Open a TCP stream through the parent backend and configure keep-alive.

        Args:
            host: Host name to connect to, forwarded to the parent backend.
            port: Port to connect to, forwarded to the parent backend.
            timeout: Optional connect timeout, forwarded unchanged.
            local_address: Optional local address, forwarded unchanged.

        Returns:
            The stream returned by the parent ``connect_tcp`` call, with
            SO_KEEPALIVE and TCP_KEEPIDLE set on its underlying socket.
        """
        tcp_stream = await super().connect_tcp(
            host, port, timeout=timeout, local_address=local_address
        )
        raw_socket = tcp_stream.get_extra_info("socket")

        # Turn keepalive probing on for this connection.
        raw_socket.setsockopt(SOL_SOCKET, SO_KEEPALIVE, KEEPALIVE_FLAG)

        # Start probing after KEEPIDLE_RATE seconds without traffic.
        # NOTE(review): the socket module does not export TCP_KEEPIDLE on
        # every platform (the Python docs point macOS at TCP_KEEPALIVE
        # instead) - confirm which platforms must be supported.
        raw_socket.setsockopt(IPPROTO_TCP, TCP_KEEPIDLE, KEEPIDLE_RATE)

        return tcp_stream
167+
168+
134169
class BaseConnection:
135170
client_class: type
136171
cursor_class: type
@@ -151,11 +186,14 @@ def __init__(
151186
password: str,
152187
api_endpoint: str = DEFAULT_API_URL,
153188
):
189+
transport = AsyncHTTPTransport()
190+
transport._pool._network_backend = OverriddenHttpBackend()
154191
self._client = AsyncClient(
155192
auth=(username, password),
156193
base_url=engine_url,
157194
api_endpoint=api_endpoint,
158195
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None),
196+
transport=transport,
159197
)
160198
self.api_endpoint = api_endpoint
161199
self.engine_url = engine_url

tests/integration/dbapi/async/test_queries_async.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ async def test_select(
6464
data, all_types_query_response, "Invalid data returned by fetchmany"
6565
)
6666

67+
# The AWS load balancer drops TCP connections idle for 350 seconds; run a longer query to make sure keepalive holds the connection open
68+
await c.execute(
69+
"SELECT sleepEachRow(1) from numbers(360)",
70+
set_parameters={"advanced_mode": "1", "use_standard_sql": "0"},
71+
)
72+
data = await c.fetchall()
73+
assert len(data) == 360, "Invalid data size returned by fetchall"
74+
6775

6876
@mark.asyncio
6977
async def test_drop_create(

tests/integration/dbapi/sync/test_queries.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ def test_select(
6161
data, all_types_query_response, "Invalid data returned by fetchmany"
6262
)
6363

64+
# The AWS load balancer drops TCP connections idle for 350 seconds; run a longer query to make sure keepalive holds the connection open
65+
c.execute(
66+
"SELECT sleepEachRow(1) from numbers(360)",
67+
set_parameters={"advanced_mode": "1", "use_standard_sql": "0"},
68+
)
69+
data = c.fetchall()
70+
assert len(data) == 360, "Invalid data size returned by fetchall"
71+
6472

6573
def test_drop_create(
6674
connection: Connection, create_drop_description: List[Column]

0 commit comments

Comments
 (0)