Skip to content

Commit 0779fdf

Browse files
fix: Fir 11539 setting non existing account n (#175)
* validate account_name in connect * add unit tests * add test for resource manager invalid account name * resolve comments * extend comment * improve comment
1 parent f81e677 commit 0779fdf

File tree

11 files changed

+187
-55
lines changed

11 files changed

+187
-55
lines changed

src/firebolt/async_db/connection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,18 @@ async def connect_inner(
222222
account_name=account_name,
223223
api_endpoint=api_endpoint,
224224
)
225+
elif account_name:
226+
# In above if branches account name is validated since it's used to
227+
# resolve or get an engine url.
228+
# We need to manually validate account_name if none of the above
229+
# cases are triggered.
230+
async with AsyncClient(
231+
auth=auth,
232+
base_url=api_endpoint,
233+
account_name=account_name,
234+
api_endpoint=api_endpoint,
235+
) as client:
236+
await client.account_id
225237

226238
assert engine_url is not None
227239

tests/unit/async_db/test_connection.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from re import Pattern
12
from typing import Callable, List
23

34
from httpx import codes
@@ -7,9 +8,10 @@
78

89
from firebolt.async_db import Connection, connect
910
from firebolt.async_db._types import ColType
10-
from firebolt.client.auth import Token, UsernamePassword
11+
from firebolt.client.auth import Auth, Token, UsernamePassword
1112
from firebolt.common.settings import Settings
1213
from firebolt.utils.exception import (
14+
AccountNotFoundError,
1315
ConfigurationError,
1416
ConnectionClosedError,
1517
FireboltEngineError,
@@ -71,7 +73,6 @@ async def test_cursor_initialized(
7173
database=db_name,
7274
username="u",
7375
password="p",
74-
account_name="a",
7576
api_endpoint=settings.server,
7677
)
7778
) as connection:
@@ -116,7 +117,6 @@ async def test_connect_access_token(
116117
engine_url=settings.server,
117118
database=db_name,
118119
access_token=access_token,
119-
account_name="a",
120120
api_endpoint=settings.server,
121121
)
122122
) as connection:
@@ -147,7 +147,7 @@ async def test_connect_engine_name(
147147
auth_url: str,
148148
query_callback: Callable,
149149
query_url: str,
150-
account_id_url: str,
150+
account_id_url: Pattern,
151151
account_id_callback: Callable,
152152
engine_id: str,
153153
get_engine_url: str,
@@ -223,7 +223,7 @@ async def test_connect_default_engine(
223223
auth_url: str,
224224
query_callback: Callable,
225225
query_url: str,
226-
account_id_url: str,
226+
account_id_url: Pattern,
227227
account_id_callback: Callable,
228228
engine_id: str,
229229
get_engine_url: str,
@@ -353,3 +353,37 @@ async def test_connect_with_auth(
353353
api_endpoint=settings.server,
354354
) as connection:
355355
await connection.cursor().execute("select*")
356+
357+
358+
@mark.asyncio
359+
async def test_connect_account_name(
360+
httpx_mock: HTTPXMock,
361+
auth: Auth,
362+
settings: Settings,
363+
db_name: str,
364+
auth_url: str,
365+
check_credentials_callback: Callable,
366+
account_id_url: Pattern,
367+
account_id_callback: Callable,
368+
):
369+
httpx_mock.add_callback(check_credentials_callback, url=auth_url)
370+
httpx_mock.add_callback(account_id_callback, url=account_id_url)
371+
372+
with raises(AccountNotFoundError):
373+
async with await connect(
374+
auth=auth,
375+
database=db_name,
376+
engine_url=settings.server,
377+
account_name="invalid",
378+
api_endpoint=settings.server,
379+
):
380+
pass
381+
382+
async with await connect(
383+
auth=auth,
384+
database=db_name,
385+
engine_url=settings.server,
386+
account_name=settings.account_name,
387+
api_endpoint=settings.server,
388+
):
389+
pass

tests/unit/client/test_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from re import Pattern
12
from typing import Callable
23

34
from httpx import codes
@@ -93,7 +94,7 @@ def test_client_account_id(
9394
test_username: str,
9495
test_password: str,
9596
account_id: str,
96-
account_id_url: str,
97+
account_id_url: Pattern,
9798
account_id_callback: Callable,
9899
auth_url: str,
99100
auth_callback: Callable,

tests/unit/client/test_client_async.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from re import Pattern
12
from typing import Callable
23

34
from httpx import codes
@@ -104,7 +105,7 @@ async def test_client_account_id(
104105
test_username: str,
105106
test_password: str,
106107
account_id: str,
107-
account_id_url: str,
108+
account_id_url: Pattern,
108109
account_id_callback: Callable,
109110
auth_url: str,
110111
auth_callback: Callable,

tests/unit/conftest.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
from json import loads
2+
from re import Pattern, compile
23
from typing import Callable, List
34

45
import httpx
5-
from httpx import Response
6+
from httpx import Request, Response
67
from pydantic import SecretStr
78
from pyfakefs.fake_filesystem_unittest import Patcher
89
from pytest import fixture
910

11+
from firebolt.client.auth import Auth, UsernamePassword
1012
from firebolt.common.settings import Settings
1113
from firebolt.model.provider import Provider
1214
from firebolt.model.region import Region, RegionKey
1315
from firebolt.utils.exception import (
16+
AccountNotFoundError,
1417
DatabaseError,
1518
DataError,
1619
Error,
@@ -51,6 +54,16 @@ def global_fake_fs(request) -> None:
5154
yield
5255

5356

57+
@fixture
58+
def username() -> str:
59+
60+
61+
62+
@fixture
63+
def password() -> str:
64+
return "*****"
65+
66+
5467
@fixture
5568
def server() -> str:
5669
return "api.mock.firebolt.io"
@@ -107,16 +120,21 @@ def mock_regions(region_1, region_2) -> List[Region]:
107120

108121

109122
@fixture
110-
def settings(server, region_1) -> Settings:
123+
def settings(server: str, region_1: str, username: str, password: str) -> Settings:
111124
return Settings(
112125
server=server,
113-
114-
password=SecretStr("*****"),
126+
user=username,
127+
password=SecretStr(password),
115128
default_region=region_1.name,
116129
account_name=None,
117130
)
118131

119132

133+
@fixture
134+
def auth(username: str, password: str) -> Auth:
135+
return UsernamePassword(username, password)
136+
137+
120138
@fixture
121139
def auth_callback(auth_url: str) -> Callable:
122140
def do_mock(
@@ -148,30 +166,30 @@ def db_description() -> str:
148166

149167

150168
@fixture
151-
def account_id_url(settings: Settings) -> str:
152-
if not settings.account_name: # if None or ''
153-
return f"https://{settings.server}{ACCOUNT_URL}"
154-
else:
155-
return (
156-
f"https://{settings.server}{ACCOUNT_BY_NAME_URL}"
157-
f"?account_name={settings.account_name}"
158-
)
169+
def account_id_url(settings: Settings) -> Pattern:
170+
base = f"https://{settings.server}{ACCOUNT_BY_NAME_URL}?account_name="
171+
default_base = f"https://{settings.server}{ACCOUNT_URL}"
172+
base = base.replace("/", "\\/").replace("?", "\\?")
173+
default_base = default_base.replace("/", "\\/").replace("?", "\\?")
174+
return compile(f"(?:{base}.*|{default_base})")
159175

160176

161177
@fixture
162178
def account_id_callback(
163-
account_id: str, account_id_url: str, settings: Settings
179+
account_id: str,
180+
settings: Settings,
164181
) -> Callable:
165182
def do_mock(
166-
request: httpx.Request = None,
183+
request: Request,
167184
**kwargs,
168185
) -> Response:
169-
assert request.url == account_id_url
170-
if account_id_url.endswith(ACCOUNT_URL): # account_name shouldn't be specified.
186+
if "account_name" not in request.url.params:
171187
return Response(
172188
status_code=httpx.codes.OK, json={"account": {"id": account_id}}
173189
)
174190
# In this case, an account_name *should* be specified.
191+
if request.url.params["account_name"] != settings.account_name:
192+
raise AccountNotFoundError(request.url.params["account_name"])
175193
return Response(status_code=httpx.codes.OK, json={"account_id": account_id})
176194

177195
return do_mock
@@ -194,7 +212,7 @@ def get_engine_callback(
194212
get_engine_url: str, engine_id: str, settings: Settings
195213
) -> Callable:
196214
def do_mock(
197-
request: httpx.Request = None,
215+
request: Request = None,
198216
**kwargs,
199217
) -> Response:
200218
assert request.url == get_engine_url
@@ -230,7 +248,7 @@ def get_providers_url(settings: Settings, account_id: str, engine_id: str) -> st
230248
@fixture
231249
def get_providers_callback(get_providers_url: str, provider: Provider) -> Callable:
232250
def do_mock(
233-
request: httpx.Request = None,
251+
request: Request = None,
234252
**kwargs,
235253
) -> Response:
236254
assert request.url == get_providers_url
@@ -269,7 +287,7 @@ def database_by_name_url(settings: Settings, account_id: str, db_name: str) -> s
269287
@fixture
270288
def database_by_name_callback(account_id: str, database_id: str) -> str:
271289
def do_mock(
272-
request: httpx.Request = None,
290+
request: Request = None,
273291
**kwargs,
274292
) -> Response:
275293
return Response(
@@ -312,7 +330,7 @@ def db_api_exceptions():
312330

313331
@fixture
314332
def check_token_callback(access_token: str) -> Callable:
315-
def check_token(request: httpx.Request = None, **kwargs) -> Response:
333+
def check_token(request: Request = None, **kwargs) -> Response:
316334
prefix = "Bearer "
317335
assert request, "empty request"
318336
assert "authorization" in request.headers, "missing authorization header"
@@ -329,7 +347,7 @@ def check_token(request: httpx.Request = None, **kwargs) -> Response:
329347
@fixture
330348
def check_credentials_callback(settings: Settings, access_token: str) -> Callable:
331349
def check_credentials(
332-
request: httpx.Request = None,
350+
request: Request = None,
333351
**kwargs,
334352
) -> Response:
335353
assert request, "empty request"

tests/unit/db/test_connection.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from re import Pattern
12
from typing import Callable, List
23

34
from httpx import codes
@@ -6,10 +7,14 @@
67
from pytest_httpx import HTTPXMock
78

89
from firebolt.async_db._types import ColType
9-
from firebolt.client.auth import Token, UsernamePassword
10+
from firebolt.client.auth import Auth, Token, UsernamePassword
1011
from firebolt.common.settings import Settings
1112
from firebolt.db import Connection, connect
12-
from firebolt.utils.exception import ConfigurationError, ConnectionClosedError
13+
from firebolt.utils.exception import (
14+
AccountNotFoundError,
15+
ConfigurationError,
16+
ConnectionClosedError,
17+
)
1318
from firebolt.utils.token_storage import TokenSecureStorage
1419
from firebolt.utils.urls import ACCOUNT_ENGINE_BY_NAME_URL
1520

@@ -104,7 +109,6 @@ def test_connect_access_token(
104109
engine_url=settings.server,
105110
database=db_name,
106111
access_token=access_token,
107-
account_name="a",
108112
api_endpoint=settings.server,
109113
)
110114
) as connection:
@@ -134,7 +138,7 @@ def test_connect_engine_name(
134138
auth_url: str,
135139
query_callback: Callable,
136140
query_url: str,
137-
account_id_url: str,
141+
account_id_url: Pattern,
138142
account_id_callback: Callable,
139143
engine_id: str,
140144
get_engine_url: str,
@@ -190,7 +194,7 @@ def test_connect_default_engine(
190194
auth_url: str,
191195
query_callback: Callable,
192196
query_url: str,
193-
account_id_url: str,
197+
account_id_url: Pattern,
194198
account_id_callback: Callable,
195199
engine_id: str,
196200
get_engine_url: str,
@@ -323,3 +327,36 @@ def test_connect_with_auth(
323327
api_endpoint=settings.server,
324328
) as connection:
325329
connection.cursor().execute("select*")
330+
331+
332+
def test_connect_account_name(
333+
httpx_mock: HTTPXMock,
334+
auth: Auth,
335+
settings: Settings,
336+
db_name: str,
337+
auth_url: str,
338+
check_credentials_callback: Callable,
339+
account_id_url: Pattern,
340+
account_id_callback: Callable,
341+
):
342+
httpx_mock.add_callback(check_credentials_callback, url=auth_url)
343+
httpx_mock.add_callback(account_id_callback, url=account_id_url)
344+
345+
with raises(AccountNotFoundError):
346+
with connect(
347+
auth=auth,
348+
database=db_name,
349+
engine_url=settings.server,
350+
account_name="invalid",
351+
api_endpoint=settings.server,
352+
):
353+
pass
354+
355+
with connect(
356+
auth=auth,
357+
database=db_name,
358+
engine_url=settings.server,
359+
account_name=settings.account_name,
360+
api_endpoint=settings.server,
361+
):
362+
pass

0 commit comments

Comments
 (0)