Skip to content

Commit ccff26c

Browse files
authored
No longer hardcode provider (#21)
1 parent c1d93dc commit ccff26c

File tree

10 files changed

+83
-10
lines changed

10 files changed

+83
-10
lines changed

src/firebolt/common/constants.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/firebolt/model/instance_type.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33

44
from pydantic import Field
55

6-
from firebolt.common.constants import AWS_PROVIDER_ID
76
from firebolt.model import FireboltBaseModel
87

98

109
class InstanceTypeKey(FireboltBaseModel, frozen=True): # type: ignore
11-
provider_id: str = AWS_PROVIDER_ID
10+
provider_id: str
1211
region_id: str
1312
instance_type_id: str
1413

src/firebolt/model/provider.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from datetime import datetime
2+
from typing import Optional
3+
4+
from pydantic import Field
5+
6+
from firebolt.model import FireboltBaseModel
7+
8+
9+
class Provider(FireboltBaseModel, frozen=True): # type: ignore
10+
provider_id: str = Field(alias="id")
11+
name: str
12+
13+
# optional
14+
create_time: Optional[datetime]
15+
display_name: Optional[str]
16+
last_update_time: Optional[datetime]

src/firebolt/model/region.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33

44
from pydantic import Field
55

6-
from firebolt.common.constants import AWS_PROVIDER_ID
76
from firebolt.model import FireboltBaseModel
87

98

109
class RegionKey(FireboltBaseModel, frozen=True): # type: ignore
11-
provider_id: str = AWS_PROVIDER_ID
10+
provider_id: str
1211
region_id: str
1312

1413

src/firebolt/service/manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from firebolt.client import Client, log_request, log_response, raise_on_4xx_5xx
44
from firebolt.common import Settings
5+
from firebolt.service.provider import get_provider_id
56

67

78
class ResourceManager:
@@ -47,6 +48,7 @@ def _init_services(self, default_region_name: str) -> None:
4748
resource_manager=self, default_region_name=default_region_name
4849
)
4950
self.instance_types = InstanceTypeService(resource_manager=self)
51+
self.provider_id = get_provider_id(client=self.client)
5052

5153
# Firebolt Resources
5254
self.databases = DatabaseService(resource_manager=self)

src/firebolt/service/provider.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from firebolt.client import Client
2+
from firebolt.model.provider import Provider
3+
4+
5+
def get_provider_id(client: Client) -> str:
6+
"""Get the AWS provider_id."""
7+
response = client.get(url="/compute/v1/providers")
8+
providers = [Provider.parse_obj(i["node"]) for i in response.json()["edges"]]
9+
return providers[0].provider_id

src/firebolt/service/region.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,8 @@ def get_by_key(self, region_key: RegionKey) -> Region:
5757

5858
def get_by_id(self, region_id: str) -> Region:
5959
"""Get an AWS Region by region_id."""
60-
return self.get_by_key(RegionKey(region_id=region_id))
60+
return self.get_by_key(
61+
RegionKey(
62+
provider_id=self.resource_manager.provider_id, region_id=region_id
63+
)
64+
)

tests/conftest.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from firebolt.common.settings import Settings
1010
from firebolt.model.instance_type import InstanceType, InstanceTypeKey
11+
from firebolt.model.provider import Provider
1112
from firebolt.model.region import Region, RegionKey
13+
from tests.util import list_to_paginated_response
1214

1315

1416
@pytest.fixture
@@ -27,19 +29,34 @@ def access_token() -> str:
2729

2830

2931
@pytest.fixture
30-
def region_1() -> Region:
32+
def provider() -> Provider:
33+
return Provider(
34+
provider_id="mock_provider_id",
35+
name="mock_provider_name",
36+
)
37+
38+
39+
@pytest.fixture
40+
def mock_providers(provider) -> list[Provider]:
41+
return [provider]
42+
43+
44+
@pytest.fixture
45+
def region_1(provider) -> Region:
3146
return Region(
3247
key=RegionKey(
48+
provider_id=provider.provider_id,
3349
region_id="mock_region_id_1",
3450
),
3551
name="mock_region_1",
3652
)
3753

3854

3955
@pytest.fixture
40-
def region_2() -> Region:
56+
def region_2(provider) -> Region:
4157
return Region(
4258
key=RegionKey(
59+
provider_id=provider.provider_id,
4360
region_id="mock_region_id_2",
4461
),
4562
name="mock_region_2",
@@ -52,9 +69,10 @@ def mock_regions(region_1, region_2) -> list[Region]:
5269

5370

5471
@pytest.fixture
55-
def instance_type_1(region_1) -> InstanceType:
72+
def instance_type_1(provider, region_1) -> InstanceType:
5673
return InstanceType(
5774
key=InstanceTypeKey(
75+
provider_id=provider.provider_id,
5876
region_id=region_1.key.region_id,
5977
instance_type_id="instance_type_id_1",
6078
),
@@ -63,9 +81,10 @@ def instance_type_1(region_1) -> InstanceType:
6381

6482

6583
@pytest.fixture
66-
def instance_type_2(region_2) -> InstanceType:
84+
def instance_type_2(provider, region_2) -> InstanceType:
6785
return InstanceType(
6886
key=InstanceTypeKey(
87+
provider_id=provider.provider_id,
6988
region_id=region_2.key.region_id,
7089
instance_type_id="instance_type_id_2",
7190
),
@@ -108,6 +127,26 @@ def auth_url(settings: Settings) -> str:
108127
return f"https://{settings.server}/auth/v1/login"
109128

110129

130+
@pytest.fixture
131+
def provider_callback(provider_url: str, mock_providers) -> Callable:
132+
def do_mock(
133+
request: httpx.Request = None,
134+
**kwargs,
135+
) -> Response:
136+
assert request.url == provider_url
137+
return to_response(
138+
status_code=httpx.codes.OK,
139+
json=list_to_paginated_response(mock_providers),
140+
)
141+
142+
return do_mock
143+
144+
145+
@pytest.fixture
146+
def provider_url(settings: Settings) -> str:
147+
return f"https://{settings.server}/compute/v1/providers"
148+
149+
111150
@pytest.fixture
112151
def db_name() -> str:
113152
return "database"

tests/model/test_instance_type.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
def test_instance_type(
1313
httpx_mock: HTTPXMock,
1414
auth_callback: Callable,
15+
provider_callback: Callable,
1516
settings: Settings,
1617
mock_instance_types: List[InstanceType],
1718
):
19+
httpx_mock.add_callback(auth_callback)
20+
httpx_mock.add_callback(provider_callback)
1821
httpx_mock.add_callback(auth_callback)
1922
httpx_mock.add_response(
2023
url=f"https://{settings.server}/compute/v1/instanceTypes?page.first=5000",

tests/model/test_region.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
def test_region(
1313
httpx_mock: HTTPXMock,
1414
auth_callback: Callable,
15+
provider_callback: Callable,
1516
settings: Settings,
1617
mock_regions: List[Region],
1718
):
19+
httpx_mock.add_callback(auth_callback)
20+
httpx_mock.add_callback(provider_callback)
1821
httpx_mock.add_callback(auth_callback)
1922
httpx_mock.add_response(
2023
url=f"https://{settings.server}/compute/v1/regions?page.first=5000",

0 commit comments

Comments
 (0)