Skip to content

Commit ac8b9fc

Browse files
committed
[DOP-22530] Implement GET /v1/datasets?location_type=...
1 parent 00e2ea6 commit ac8b9fc

File tree

6 files changed

+192
-11
lines changed

6 files changed

+192
-11
lines changed

data_rentgen/db/repositories/dataset.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,18 @@ async def paginate(
9999
dataset_ids: Collection[int],
100100
tag_value_ids: Collection[int],
101101
location_id: int | None,
102+
location_type: Collection[str],
102103
search_query: str | None,
103104
) -> PaginationDTO[Dataset]:
104105
where = []
106+
location_join_clause = Location.id == Dataset.location_id
105107
if dataset_ids:
106108
where.append(Dataset.id == any_(list(dataset_ids))) # type: ignore[arg-type]
107-
108109
if location_id:
109110
where.append(Dataset.location_id == location_id)
111+
if location_type:
112+
location_type_lower = [location_type.lower() for location_type in location_type]
113+
where.append(Location.type == any_(location_type_lower)) # type: ignore[arg-type]
110114

111115
if tag_value_ids:
112116
tv_ids = list(tag_value_ids)
@@ -125,19 +129,20 @@ async def paginate(
125129
if search_query:
126130
tsquery = make_tsquery(search_query)
127131

128-
dataset_stmt = select(Dataset, ts_rank(Dataset.search_vector, tsquery).label("search_rank")).where(
129-
ts_match(Dataset.search_vector, tsquery),
130-
*where,
132+
dataset_stmt = (
133+
select(Dataset, ts_rank(Dataset.search_vector, tsquery).label("search_rank"))
134+
.join(Location, location_join_clause)
135+
.where(ts_match(Dataset.search_vector, tsquery), *where)
131136
)
132137
location_stmt = (
133138
select(Dataset, ts_rank(Location.search_vector, tsquery).label("search_rank"))
134-
.join(Dataset, Location.id == Dataset.location_id)
139+
.join(Location, location_join_clause)
135140
.where(ts_match(Location.search_vector, tsquery), *where)
136141
)
137142
address_stmt = (
138143
select(Dataset, func.max(ts_rank(Address.search_vector, tsquery)).label("search_rank"))
139-
.join(Location, Address.location_id == Location.id)
140-
.join(Dataset, Location.id == Dataset.location_id)
144+
.join(Location, location_join_clause)
145+
.join(Address, Address.location_id == Dataset.location_id)
141146
.where(ts_match(Address.search_vector, tsquery), *where)
142147
.group_by(Dataset.id, Location.id, Address.id)
143148
)
@@ -152,7 +157,7 @@ async def paginate(
152157
).group_by(*dataset_columns)
153158
order_by = [desc("search_rank"), asc("name")]
154159
else:
155-
query = select(Dataset).where(*where)
160+
query = select(Dataset).join(Location, location_join_clause).where(*where)
156161
order_by = [Dataset.name]
157162

158163
options = [

data_rentgen/server/api/v1/router/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ async def paginate_datasets(
3939
dataset_ids=query_args.dataset_id,
4040
tag_value_ids=query_args.tag_value_id,
4141
location_id=query_args.location_id,
42+
location_type=query_args.location_type,
4243
search_query=query_args.search_query,
4344
)
4445
return PageResponseV1[DatasetDetailedResponseV1].from_pagination(pagination)

data_rentgen/server/schemas/v1/dataset.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,28 @@ class DatasetDetailedResponseV1(BaseModel):
5555
class DatasetPaginateQueryV1(PaginateQueryV1):
5656
"""Query params for Dataset paginate request."""
5757

58-
dataset_id: list[int] = Field(default_factory=list, description="Dataset id")
59-
tag_value_id: list[int] = Field(default_factory=list, description="Tag value id")
60-
location_id: int | None = Field(default=None, description="Location id to filter dataset")
58+
dataset_id: list[int] = Field(
59+
default_factory=list,
60+
description="Get specific datasets by their ids",
61+
)
62+
tag_value_id: list[int] = Field(
63+
default_factory=list,
64+
description="Get datasets with specific tag values (AND)",
65+
)
66+
location_id: int | None = Field(
67+
default=None,
68+
description="Get datasets by location id",
69+
)
70+
location_type: list[str] = Field(
71+
default_factory=list,
72+
description="Get datasets by location types",
73+
examples=[["yarn"]],
74+
)
6175
search_query: str | None = Field(
6276
default=None,
6377
min_length=3,
6478
description="Search query",
79+
examples=[["my dataset"]],
6580
)
6681

6782
model_config = ConfigDict(extra="forbid")

data_rentgen/server/services/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ async def paginate(
4343
dataset_ids: Collection[int],
4444
tag_value_ids: Collection[int],
4545
location_id: int | None,
46+
location_type: Collection[str],
4647
search_query: str | None,
4748
) -> DatasetServicePaginatedResult:
4849
pagination = await self._uow.dataset.paginate(
@@ -51,6 +52,7 @@ async def paginate(
5152
dataset_ids=dataset_ids,
5253
tag_value_ids=tag_value_ids,
5354
location_id=location_id,
55+
location_type=location_type,
5456
search_query=search_query,
5557
)
5658

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Add new filter to ``GET /v1/datasets``:
2+
- location_type: ``list[str]``
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
from http import HTTPStatus
2+
3+
import pytest
4+
from httpx import AsyncClient
5+
from sqlalchemy import select
6+
from sqlalchemy.ext.asyncio import AsyncSession
7+
8+
from data_rentgen.db.models import Dataset, Location
9+
from tests.fixtures.mocks import MockedUser
10+
from tests.test_server.utils.convert_to_json import dataset_to_json, tag_values_to_json
11+
from tests.test_server.utils.enrich import enrich_datasets
12+
13+
pytestmark = [pytest.mark.server, pytest.mark.asyncio]
14+
15+
16+
async def test_get_datasets_by_location_id(
    test_client: AsyncClient,
    async_session: AsyncSession,
    datasets_search: tuple[dict[str, Dataset], ...],
    mocked_user: MockedUser,
) -> None:
    """Filtering by an existing location_id returns only the dataset at that location."""
    _, _, datasets_by_address = datasets_search
    enriched = await enrich_datasets(
        [datasets_by_address["hdfs://my-cluster-namenode:2080"]],
        async_session,
    )
    target_location_id = enriched[0].location_id

    response = await test_client.get(
        "/v1/datasets",
        headers={"Authorization": f"Bearer {mocked_user.access_token}"},
        params={"location_id": target_location_id},
    )

    assert response.status_code == HTTPStatus.OK, response.json()

    expected_items = [
        {
            "id": str(dataset.id),
            "data": dataset_to_json(dataset),
            "tags": [],
        }
        for dataset in enriched
    ]
    assert response.json() == {
        "meta": {
            "page": 1,
            "page_size": 20,
            "total_count": 1,
            "pages_count": 1,
            "has_next": False,
            "has_previous": False,
            "next_page": None,
            "previous_page": None,
        },
        "items": expected_items,
    }
53+
54+
55+
async def test_get_datasets_by_location_id_non_existent(
    test_client: AsyncClient,
    async_session: AsyncSession,
    datasets_search: tuple[dict[str, Dataset], ...],
    mocked_user: MockedUser,
) -> None:
    """A location_id that matches no location yields an empty page, not an error."""
    response = await test_client.get(
        "/v1/datasets",
        headers={"Authorization": f"Bearer {mocked_user.access_token}"},
        params={"location_id": -1},
    )

    assert response.status_code == HTTPStatus.OK, response.json()

    empty_page_meta = {
        "page": 1,
        "page_size": 20,
        "total_count": 0,
        "pages_count": 1,
        "has_next": False,
        "has_previous": False,
        "next_page": None,
        "previous_page": None,
    }
    assert response.json() == {"meta": empty_page_meta, "items": []}
81+
82+
83+
async def test_get_datasets_by_location_type(
    test_client: AsyncClient,
    async_session: AsyncSession,
    datasets_search: tuple[dict[str, Dataset], ...],
    mocked_user: MockedUser,
) -> None:
    """Filtering by location_type is case-insensitive and returns every dataset on a matching location."""
    # NOTE: random locations created by the datasets_search fixture can also
    # have type=hdfs, so the expected set is read back from the DB directly.
    # NOTE(review): assumes the fixture never produces more than page_size (20)
    # hdfs datasets, otherwise the single-page meta below would not hold — TODO confirm.
    hdfs_datasets_stmt = (
        select(Dataset)
        .join(Location, Location.id == Dataset.location_id)
        .where(Location.type == "hdfs")
        .order_by(Dataset.name)
    )
    matching = await async_session.scalars(hdfs_datasets_stmt)
    async_session.expunge_all()
    expected_datasets = await enrich_datasets(list(matching.all()), async_session)

    response = await test_client.get(
        "/v1/datasets",
        headers={"Authorization": f"Bearer {mocked_user.access_token}"},
        # uppercase on purpose: the filter must be case-insensitive
        params={"location_type": ["HDFS"]},
    )

    assert response.status_code == HTTPStatus.OK, response.json()

    expected_items = [
        {
            "id": str(dataset.id),
            "data": dataset_to_json(dataset),
            "tags": [],
        }
        for dataset in expected_datasets
    ]
    assert response.json() == {
        "meta": {
            "page": 1,
            "page_size": 20,
            "total_count": len(expected_datasets),
            "pages_count": 1,
            "has_next": False,
            "has_previous": False,
            "next_page": None,
            "previous_page": None,
        },
        "items": expected_items,
    }
129+
130+
131+
async def test_get_datasets_by_location_type_non_existent(
    test_client: AsyncClient,
    async_session: AsyncSession,
    datasets_search: tuple[dict[str, Dataset], ...],
    mocked_user: MockedUser,
) -> None:
    """An unknown location_type matches nothing and yields an empty page, not an error."""
    response = await test_client.get(
        "/v1/datasets",
        headers={"Authorization": f"Bearer {mocked_user.access_token}"},
        params={"location_type": "non_existent_location_type"},
    )

    assert response.status_code == HTTPStatus.OK, response.json()

    empty_page_meta = {
        "page": 1,
        "page_size": 20,
        "total_count": 0,
        "pages_count": 1,
        "has_next": False,
        "has_previous": False,
        "next_page": None,
        "previous_page": None,
    }
    assert response.json() == {"meta": empty_page_meta, "items": []}

0 commit comments

Comments
 (0)