Skip to content

Commit 82fe7f9

Browse files
feat: Fir 10511 feature request support resour (#128)
* add auth_token settings field * add auth.from_token method * allow passing access_token to resource manager * allow connecting with access token * add tests for access_token connection * extend unit tests * add ConfigurationError
1 parent 2309045 commit 82fe7f9

File tree

13 files changed

+334
-109
lines changed

13 files changed

+334
-109
lines changed

src/firebolt/async_db/connection.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from httpcore.backends.auto import AutoBackend
99
from httpcore.backends.base import AsyncNetworkStream
1010
from httpx import AsyncHTTPTransport, HTTPStatusError, RequestError, Timeout
11+
from httpx._types import AuthTypes
1112

1213
from firebolt.async_db.cursor import BaseCursor, Cursor
13-
from firebolt.client import DEFAULT_API_URL, AsyncClient
14+
from firebolt.client import DEFAULT_API_URL, AsyncClient, Auth
1415
from firebolt.common.exception import (
16+
ConfigurationError,
1517
ConnectionClosedError,
1618
FireboltEngineError,
1719
InterfaceError,
@@ -26,13 +28,12 @@
2628

2729
async def _resolve_engine_url(
2830
engine_name: str,
29-
username: str,
30-
password: str,
31+
auth: AuthTypes,
3132
api_endpoint: str,
3233
account_name: Optional[str] = None,
3334
) -> str:
3435
async with AsyncClient(
35-
auth=(username, password),
36+
auth=auth,
3637
base_url=api_endpoint,
3738
account_name=account_name,
3839
api_endpoint=api_endpoint,
@@ -64,11 +65,43 @@ async def _resolve_engine_url(
6465
raise InterfaceError(f"Unable to retrieve engine endpoint: {e}.")
6566

6667

68+
def _validate_engine_name_and_url(
69+
engine_name: Optional[str], engine_url: Optional[str]
70+
) -> None:
71+
if engine_name and engine_url:
72+
raise ConfigurationError(
73+
"Both engine_name and engine_url are provided. Provide only one to connect."
74+
)
75+
if not engine_name and not engine_url:
76+
raise ConfigurationError(
77+
"Neither engine_name nor engine_url is provided. Provide one to connect."
78+
)
79+
80+
81+
def _get_auth(
82+
username: Optional[str], password: Optional[str], access_token: Optional[str]
83+
) -> AuthTypes:
84+
if not access_token:
85+
if not username or not password:
86+
raise ConfigurationError(
87+
"Neither username/password nor access_token are provided. Provide one"
88+
" to authenticate"
89+
)
90+
return (username, password)
91+
elif username or password:
92+
raise ConfigurationError(
93+
"Either username/password and access_token are provided. Provide only one"
94+
" to authenticate"
95+
)
96+
return Auth.from_token(access_token)
97+
98+
6799
def async_connect_factory(connection_class: Type) -> Callable:
68100
async def connect_inner(
69101
database: str = None,
70-
username: str = None,
71-
password: str = None,
102+
username: Optional[str] = None,
103+
password: Optional[str] = None,
104+
access_token: Optional[str] = None,
72105
engine_name: Optional[str] = None,
73106
engine_url: Optional[str] = None,
74107
account_name: Optional[str] = None,
@@ -90,48 +123,31 @@ async def connect_inner(
90123
Either `engine_name` or `engine_url` should be provided, but not both.
91124
92125
"""
93-
94-
if engine_name and engine_url:
95-
raise InterfaceError(
96-
"Both engine_name and engine_url are provided. "
97-
"Provide only one to connect."
98-
)
99-
if not engine_name and not engine_url:
100-
raise InterfaceError(
101-
"Neither engine_name nor engine_url is provided. "
102-
"Provide one to connect."
103-
)
104-
105-
api_endpoint = fix_url_schema(api_endpoint)
106126
# These parameters are optional in function signature
107127
# but are required to connect.
108128
# PEP 249 recommends making them kwargs.
109-
for param, name in (
110-
(database, "database"),
111-
(username, "username"),
112-
(password, "password"),
113-
):
114-
if not param:
115-
raise InterfaceError(f"{name} is required to connect.")
129+
if not database:
130+
raise ConfigurationError("database name is required to connect.")
131+
132+
_validate_engine_name_and_url(engine_name, engine_url)
133+
auth = _get_auth(username, password, access_token)
134+
api_endpoint = fix_url_schema(api_endpoint)
116135

117136
# Mypy checks, this should never happen
118137
assert database is not None
119-
assert username is not None
120-
assert password is not None
121138

122139
if engine_name:
123140
engine_url = await _resolve_engine_url(
124141
engine_name=engine_name,
125-
username=username,
126-
password=password,
142+
auth=auth,
127143
account_name=account_name,
128144
api_endpoint=api_endpoint,
129145
)
130146

131147
assert engine_url is not None
132148

133149
engine_url = fix_url_schema(engine_url)
134-
return connection_class(engine_url, database, username, password, api_endpoint)
150+
return connection_class(engine_url, database, auth, api_endpoint)
135151

136152
return connect_inner
137153

@@ -187,15 +203,16 @@ class BaseConnection:
187203
def __init__(
188204
self,
189205
engine_url: str,
190-
database: str, # TODO: Get by engine name
191-
username: str,
192-
password: str,
206+
database: str,
207+
auth: AuthTypes,
193208
api_endpoint: str = DEFAULT_API_URL,
194209
):
210+
# Override tcp keepalive settings for connection
195211
transport = AsyncHTTPTransport()
196212
transport._pool._network_backend = OverriddenHttpBackend()
213+
197214
self._client = AsyncClient(
198-
auth=(username, password),
215+
auth=auth,
199216
base_url=engine_url,
200217
api_endpoint=api_endpoint,
201218
timeout=Timeout(DEFAULT_TIMEOUT_SECONDS, read=None),

src/firebolt/client/auth.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ class Auth(HttpxAuth):
2727

2828
requires_response_body = True
2929

30+
@staticmethod
31+
def from_token(token: str) -> "Auth":
32+
a = Auth("", "")
33+
a._token = token
34+
return a
35+
3036
def __init__(
3137
self, username: str, password: str, api_endpoint: str = DEFAULT_API_URL
3238
):

src/firebolt/common/exception.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, method_name: str):
1717
def __str__(self) -> str:
1818
return (
1919
f"Unable to call {self.method_name}: "
20-
f"Engine must to be attached to a database first."
20+
"Engine must to be attached to a database first."
2121
)
2222

2323

@@ -44,7 +44,7 @@ def __init__(self, method_name: str):
4444
def __str__(self) -> str:
4545
return (
4646
f"Unable to call {self.method_name}: "
47-
f"Engine must not be in starting or stopping state."
47+
"Engine must not be in starting or stopping state."
4848
)
4949

5050

@@ -159,3 +159,9 @@ class NotSupportedError(DatabaseError):
159159
Exception raised when the database encounters an internal error,
160160
e.g. the cursor is not valid anymore, the transaction is out of sync, etc.
161161
"""
162+
163+
164+
class ConfigurationError(InterfaceError):
165+
"""
166+
Exception raised when provided configuration is not correct
167+
"""

src/firebolt/common/settings.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
1-
from pydantic import BaseSettings, Field, SecretStr
1+
from pydantic import BaseSettings, Field, SecretStr, root_validator
22

33

44
class Settings(BaseSettings):
5+
# Authorization
6+
user: str = Field(None, env="FIREBOLT_USER")
7+
password: SecretStr = Field(None, env="FIREBOLT_PASSWORD")
8+
# Or
9+
access_token: str = Field(None, env="FIREBOLT_AUTH_TOKEN")
10+
511
account_name: str = Field(None, env="FIREBOLT_ACCOUNT")
612
server: str = Field(..., env="FIREBOLT_SERVER")
7-
user: str = Field(..., env="FIREBOLT_USER")
8-
password: SecretStr = Field(..., env="FIREBOLT_PASSWORD")
913
default_region: str = Field(..., env="FIREBOLT_DEFAULT_REGION")
1014

1115
class Config:
1216
env_file = ".env"
17+
18+
@root_validator
19+
def mutual_exclusive_with_creds(cls, values: dict) -> dict:
20+
if values["user"] or values["password"]:
21+
if values["access_token"]:
22+
raise ValueError("Provide only one of user/password or access_token")
23+
elif not values["access_token"]:
24+
raise ValueError("Provide either user/password or access_token")
25+
return values

src/firebolt/service/manager.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
from typing import Optional
22

33
from httpx import Timeout
4+
from httpx._types import AuthTypes
45

5-
from firebolt.client import Client, log_request, log_response, raise_on_4xx_5xx
6+
from firebolt.client import (
7+
Auth,
8+
Client,
9+
log_request,
10+
log_response,
11+
raise_on_4xx_5xx,
12+
)
613
from firebolt.common import Settings
714
from firebolt.service.provider import get_provider_id
815

@@ -25,13 +32,17 @@ class ResourceManager:
2532
- instance types (AWS instance types which engines can use)
2633
"""
2734

28-
def __init__(
29-
self, settings: Optional[Settings] = None, account_name: Optional[str] = None
30-
):
35+
def __init__(self, settings: Optional[Settings] = None):
3136
self.settings = settings or Settings()
3237

38+
auth: AuthTypes = None
39+
if self.settings.access_token:
40+
auth = Auth.from_token(self.settings.access_token)
41+
else:
42+
auth = (self.settings.user, self.settings.password.get_secret_value())
43+
3344
self.client = Client(
34-
auth=(self.settings.user, self.settings.password.get_secret_value()),
45+
auth=auth,
3546
base_url=f"https://{ self.settings.server}",
3647
account_name=self.settings.account_name,
3748
api_endpoint=self.settings.server,

tests/unit/async_db/test_connection.py

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from firebolt.async_db import Connection, connect
88
from firebolt.async_db._types import ColType
99
from firebolt.common.exception import (
10+
ConfigurationError,
1011
ConnectionClosedError,
1112
FireboltEngineError,
12-
InterfaceError,
1313
)
1414
from firebolt.common.settings import Settings
1515
from firebolt.common.urls import ACCOUNT_ENGINE_BY_NAME_URL
@@ -90,18 +90,49 @@ async def test_cursor_initialized(
9090

9191
@mark.asyncio
9292
async def test_connect_empty_parameters():
93-
params = ("database", "username", "password")
94-
kwargs = {"engine_url": "engine_url", **{p: p for p in params}}
93+
with raises(ConfigurationError):
94+
async with await connect(engine_url="engine_url"):
95+
pass
96+
97+
98+
@mark.asyncio
99+
async def test_connect_access_token(
100+
settings: Settings,
101+
db_name: str,
102+
httpx_mock: HTTPXMock,
103+
auth_callback: Callable,
104+
auth_url: str,
105+
check_token_callback: Callable,
106+
query_url: str,
107+
python_query_data: List[List[ColType]],
108+
access_token: str,
109+
):
110+
httpx_mock.add_callback(check_token_callback, url=query_url)
111+
async with (
112+
await connect(
113+
engine_url=settings.server,
114+
database=db_name,
115+
access_token=access_token,
116+
account_name="a",
117+
api_endpoint=settings.server,
118+
)
119+
) as connection:
120+
cursor = connection.cursor()
121+
assert await cursor.execute("select*") == -1
122+
123+
with raises(ConfigurationError):
124+
async with await connect(engine_url="engine_url", database="database"):
125+
pass
95126

96-
for param in params:
97-
with raises(InterfaceError) as exc_info:
98-
kwargs = {
99-
"engine_url": "engine_url",
100-
**{p: p for p in params if p != param},
101-
}
102-
async with await connect(**kwargs):
103-
pass
104-
assert str(exc_info.value) == f"{param} is required to connect."
127+
with raises(ConfigurationError):
128+
async with await connect(
129+
engine_url="engine_url",
130+
database="database",
131+
username="username",
132+
password="password",
133+
access_token="access_token",
134+
):
135+
pass
105136

106137

107138
@mark.asyncio
@@ -125,7 +156,7 @@ async def test_connect_engine_name(
125156
):
126157
"""connect properly handles engine_name"""
127158

128-
with raises(InterfaceError) as exc_info:
159+
with raises(ConfigurationError):
129160
async with await connect(
130161
engine_url="engine_url",
131162
engine_name="engine_name",
@@ -135,21 +166,15 @@ async def test_connect_engine_name(
135166
account_name="account_name",
136167
):
137168
pass
138-
assert str(exc_info.value).startswith(
139-
"Both engine_name and engine_url are provided."
140-
)
141169

142-
with raises(InterfaceError) as exc_info:
170+
with raises(ConfigurationError):
143171
async with await connect(
144172
database="db",
145173
username="username",
146174
password="password",
147175
account_name="account",
148176
):
149177
pass
150-
assert str(exc_info.value).startswith(
151-
"Neither engine_name nor engine_url is provided."
152-
)
153178

154179
httpx_mock.add_callback(auth_callback, url=auth_url)
155180
httpx_mock.add_callback(query_callback, url=query_url)
@@ -166,7 +191,7 @@ async def test_connect_engine_name(
166191
status_code=codes.NOT_FOUND,
167192
)
168193

169-
with raises(FireboltEngineError) as exc_info:
194+
with raises(FireboltEngineError):
170195
async with await connect(
171196
database="db",
172197
username="username",

0 commit comments

Comments
 (0)