11from __future__ import annotations
22
33from json import JSONDecodeError
4+ from socket import IPPROTO_TCP , SO_KEEPALIVE , SOL_SOCKET , TCP_KEEPIDLE
45from types import TracebackType
56from typing import Callable , List , Optional , Type
67
7- from httpx import HTTPStatusError , RequestError , Timeout
8+ from httpcore .backends .auto import AutoBackend
9+ from httpcore .backends .base import AsyncNetworkStream
10+ from httpx import AsyncHTTPTransport , HTTPStatusError , RequestError , Timeout
811
912from firebolt .async_db .cursor import BaseCursor , Cursor
1013from firebolt .client import DEFAULT_API_URL , AsyncClient
1720from firebolt .common .util import fix_url_schema
1821
1922DEFAULT_TIMEOUT_SECONDS : int = 5
23+ KEEPALIVE_FLAG : int = 1
24+ KEEPIDLE_RATE : int = 60 # seconds
2025
2126
2227async def _resolve_engine_url (
@@ -131,6 +136,36 @@ async def connect_inner(
131136 return connect_inner
132137
133138
139+ class OverriddenHttpBackend (AutoBackend ):
140+ """
141+ This class is a short-term solution for TCP keep-alive issue:
142+ https://docs.aws.amazon.com/elasticloadbalancing/latest/network/network-load-balancers.html#connection-idle-timeout
143+ Since httpx creates a connection right before executing a request
144+ backend has to be overridden in order to set the socket KEEPALIVE
145+ and KEEPIDLE settings.
146+ """
147+
148+ async def connect_tcp (
149+ self ,
150+ host : str ,
151+ port : int ,
152+ timeout : Optional [float ] = None ,
153+ local_address : Optional [str ] = None ,
154+ ) -> AsyncNetworkStream :
155+ stream = await super ().connect_tcp (
156+ host , port , timeout = timeout , local_address = local_address
157+ )
158+ # Enable keepalive
159+ stream .get_extra_info ("socket" ).setsockopt (
160+ SOL_SOCKET , SO_KEEPALIVE , KEEPALIVE_FLAG
161+ )
162+ # Set keepalive to 60 seconds
163+ stream .get_extra_info ("socket" ).setsockopt (
164+ IPPROTO_TCP , TCP_KEEPIDLE , KEEPIDLE_RATE
165+ )
166+ return stream
167+
168+
134169class BaseConnection :
135170 client_class : type
136171 cursor_class : type
@@ -151,11 +186,14 @@ def __init__(
151186 password : str ,
152187 api_endpoint : str = DEFAULT_API_URL ,
153188 ):
189+ transport = AsyncHTTPTransport ()
190+ transport ._pool ._network_backend = OverriddenHttpBackend ()
154191 self ._client = AsyncClient (
155192 auth = (username , password ),
156193 base_url = engine_url ,
157194 api_endpoint = api_endpoint ,
158195 timeout = Timeout (DEFAULT_TIMEOUT_SECONDS , read = None ),
196+ transport = transport ,
159197 )
160198 self .api_endpoint = api_endpoint
161199 self .engine_url = engine_url
0 commit comments