Skip to content

Commit e3c221f

Browse files
authored
Allow getting by ID in /api/project/_/fleets/get (#2200)
1 parent bc5f0ac commit e3c221f

File tree

6 files changed

+111
-14
lines changed

6 files changed

+111
-14
lines changed

src/dstack/_internal/server/routers/fleets.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,13 @@ async def get_fleet(
7575
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
7676
) -> Fleet:
7777
"""
78-
Returns a fleet given a fleet name.
78+
Returns a fleet given `name` or `id`.
79+
If given `name`, does not return deleted fleets.
80+
If given `id`, returns deleted fleets.
7981
"""
8082
_, project = user_project
81-
fleet = await fleets_services.get_fleet_by_name(
82-
session=session, project=project, name=body.name
83+
fleet = await fleets_services.get_fleet(
84+
session=session, project=project, name=body.name, fleet_id=body.id
8385
)
8486
if fleet is None:
8587
raise ResourceNotExistsError()

src/dstack/_internal/server/schemas/fleets.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class ListFleetsRequest(CoreModel):
1818

1919

2020
class GetFleetRequest(CoreModel):
21-
name: str
21+
name: Optional[str]
22+
id: Optional[UUID] = None
2223

2324

2425
class GetFleetPlanRequest(CoreModel):

src/dstack/_internal/server/services/fleets.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,17 +179,42 @@ async def list_project_fleet_models(
179179
return list(res.unique().scalars().all())
180180

181181

182-
async def get_fleet_by_name(
183-
session: AsyncSession, project: ProjectModel, name: str
182+
async def get_fleet(
183+
session: AsyncSession,
184+
project: ProjectModel,
185+
name: Optional[str],
186+
fleet_id: Optional[uuid.UUID],
184187
) -> Optional[Fleet]:
185-
fleet_model = await get_project_fleet_model_by_name(
186-
session=session, project=project, name=name
187-
)
188+
if fleet_id is not None:
189+
fleet_model = await get_project_fleet_model_by_id(
190+
session=session, project=project, fleet_id=fleet_id
191+
)
192+
elif name is not None:
193+
fleet_model = await get_project_fleet_model_by_name(
194+
session=session, project=project, name=name
195+
)
196+
else:
197+
raise ServerClientError("name or id must be specified")
188198
if fleet_model is None:
189199
return None
190200
return fleet_model_to_fleet(fleet_model)
191201

192202

203+
async def get_project_fleet_model_by_id(
204+
session: AsyncSession,
205+
project: ProjectModel,
206+
fleet_id: uuid.UUID,
207+
) -> Optional[FleetModel]:
208+
filters = [
209+
FleetModel.id == fleet_id,
210+
FleetModel.project_id == project.id,
211+
]
212+
res = await session.execute(
213+
select(FleetModel).where(*filters).options(joinedload(FleetModel.instances))
214+
)
215+
return res.unique().scalar_one_or_none()
216+
217+
193218
async def get_project_fleet_model_by_name(
194219
session: AsyncSession,
195220
project: ProjectModel,

src/dstack/_internal/server/testing/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ async def create_fleet(
418418
spec: Optional[FleetSpec] = None,
419419
fleet_id: Optional[UUID] = None,
420420
status: FleetStatus = FleetStatus.ACTIVE,
421+
deleted: bool = False,
421422
) -> FleetModel:
422423
if fleet_id is None:
423424
fleet_id = uuid.uuid4()
@@ -426,6 +427,7 @@ async def create_fleet(
426427
fm = FleetModel(
427428
id=fleet_id,
428429
project=project,
430+
deleted=deleted,
429431
name=spec.configuration.name,
430432
status=status,
431433
created_at=created_at,

src/dstack/api/server/_fleets.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ def list(self, project_name: str) -> List[Fleet]:
2020

2121
def get(self, project_name: str, name: str) -> Fleet:
2222
body = GetFleetRequest(name=name)
23-
resp = self._request(f"/api/project/{project_name}/fleets/get", body=body.json())
23+
resp = self._request(
24+
f"/api/project/{project_name}/fleets/get",
25+
body=body.json(exclude={"id"}), # `id` is not supported in pre-0.18.36 servers
26+
)
2427
return parse_obj_as(Fleet.__response__, resp.json())
2528

2629
def get_plan(

src/tests/_internal/server/routers/test_fleets.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from datetime import datetime, timezone
33
from unittest.mock import Mock, patch
4-
from uuid import UUID
4+
from uuid import UUID, uuid4
55

66
import pytest
77
from freezegun import freeze_time
@@ -183,7 +183,10 @@ async def test_returns_40x_if_not_authenticated(
183183

184184
@pytest.mark.asyncio
185185
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
186-
async def test_returns_fleet(self, test_db, session: AsyncSession, client: AsyncClient):
186+
@pytest.mark.parametrize("deleted", [False, True])
187+
async def test_returns_fleet_by_id(
188+
self, test_db, session: AsyncSession, client: AsyncClient, deleted: bool
189+
):
187190
user = await create_user(session, global_role=GlobalRole.USER)
188191
project = await create_project(session)
189192
await add_project_member(
@@ -193,11 +196,12 @@ async def test_returns_fleet(self, test_db, session: AsyncSession, client: Async
193196
session=session,
194197
project=project,
195198
created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
199+
deleted=deleted,
196200
)
197201
response = await client.post(
198202
f"/api/project/{project.name}/fleets/get",
199203
headers=get_auth_headers(user.token),
200-
json={"name": fleet.name},
204+
json={"id": str(fleet.id)},
201205
)
202206
assert response.status_code == 200
203207
assert response.json() == {
@@ -213,7 +217,67 @@ async def test_returns_fleet(self, test_db, session: AsyncSession, client: Async
213217

214218
@pytest.mark.asyncio
215219
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
216-
async def test_returns_400_if_fleet_does_not_exist(
220+
async def test_returns_not_deleted_fleet_by_name(
221+
self, test_db, session: AsyncSession, client: AsyncClient
222+
):
223+
user = await create_user(session, global_role=GlobalRole.USER)
224+
project = await create_project(session)
225+
await add_project_member(
226+
session=session, project=project, user=user, project_role=ProjectRole.USER
227+
)
228+
active_fleet = await create_fleet(
229+
session=session,
230+
project=project,
231+
created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
232+
fleet_id=uuid4(),
233+
)
234+
deleted_fleet = await create_fleet(
235+
session=session,
236+
project=project,
237+
created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc),
238+
fleet_id=uuid4(),
239+
deleted=True,
240+
)
241+
assert active_fleet.name == deleted_fleet.name
242+
assert active_fleet.id != deleted_fleet.id
243+
response = await client.post(
244+
f"/api/project/{project.name}/fleets/get",
245+
headers=get_auth_headers(user.token),
246+
json={"name": active_fleet.name},
247+
)
248+
assert response.status_code == 200
249+
assert response.json() == {
250+
"id": str(active_fleet.id),
251+
"name": active_fleet.name,
252+
"project_name": project.name,
253+
"spec": json.loads(active_fleet.spec),
254+
"created_at": "2023-01-02T03:04:00+00:00",
255+
"status": active_fleet.status.value,
256+
"status_message": None,
257+
"instances": [],
258+
}
259+
260+
@pytest.mark.asyncio
261+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
262+
async def test_not_returns_by_name_if_fleet_deleted(
263+
self, test_db, session: AsyncSession, client: AsyncClient
264+
):
265+
user = await create_user(session, global_role=GlobalRole.USER)
266+
project = await create_project(session)
267+
await add_project_member(
268+
session=session, project=project, user=user, project_role=ProjectRole.USER
269+
)
270+
fleet = await create_fleet(session=session, project=project, deleted=True)
271+
response = await client.post(
272+
f"/api/project/{project.name}/fleets/get",
273+
headers=get_auth_headers(user.token),
274+
json={"name": fleet.name},
275+
)
276+
assert response.status_code == 400
277+
278+
@pytest.mark.asyncio
279+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
280+
async def test_not_returns_by_name_if_fleet_does_not_exist(
217281
self, test_db, session: AsyncSession, client: AsyncClient
218282
):
219283
user = await create_user(session, global_role=GlobalRole.USER)

0 commit comments

Comments
 (0)