Skip to content

Commit 1bbf543

Browse files
frascuchonjfcalvo
andauthored
[FEATURE] argilla server: Add annotators progress endpoint (#5367)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR adds a new endpoint to compute the dataset annotators' progress. The endpoint is defined as ```json GET /api/v1/datasets/:dataset_id/progress/annotators { "annotators": [ { "username": "userA", "completed": { "submitted": 20, "draft": 5, "discarded": 1 }, "pending": { "submitted": 2, "draft": 50, "discarded": 10 } }, ... ] } ``` Progress distribution is split by record status, so values represent user responses for each kind of record status. **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - New feature (non-breaking change which adds functionality) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: José Francisco Calvo <[email protected]>
1 parent 08d6ed5 commit 1bbf543

File tree

6 files changed

+197
-11
lines changed

6 files changed

+197
-11
lines changed

argilla-server/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ These are the section headers that we use:
1616

1717
## [Unreleased]()
1818

19+
### Added
20+
21+
- Added new endpoint `GET /api/v1/datsets/:dataset_id/users/progress` to compute the users progress. ([#5367](https://github.com/argilla-io/argilla/pull/5367))
22+
1923
## [2.0.0](https://github.com/argilla-io/argilla/compare/v2.0.0rc1...v2.0.0)
2024

2125
> [!IMPORTANT]

argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from argilla_server.api.policies.v1 import DatasetPolicy, MetadataPropertyPolicy, authorize, is_authorized
2323
from argilla_server.api.schemas.v1.datasets import (
2424
Dataset as DatasetSchema,
25+
UsersProgress,
2526
)
2627
from argilla_server.api.schemas.v1.datasets import (
2728
DatasetCreate,
@@ -39,7 +40,6 @@
3940
from argilla_server.api.schemas.v1.vector_settings import VectorSettings, VectorSettingsCreate, VectorsSettings
4041
from argilla_server.contexts import datasets
4142
from argilla_server.database import get_async_db
42-
from argilla_server.enums import ResponseStatus
4343
from argilla_server.models import Dataset, User
4444
from argilla_server.search_engine import (
4545
SearchEngine,
@@ -161,7 +161,23 @@ async def get_dataset_progress(
161161

162162
await authorize(current_user, DatasetPolicy.get(dataset))
163163

164-
return await datasets.get_dataset_progress(db, dataset_id)
164+
return await datasets.get_dataset_progress(db, dataset.id)
165+
166+
167+
@router.get("/datasets/{dataset_id}/users/progress", response_model=UsersProgress, response_model_exclude_unset=True)
168+
async def get_dataset_users_progress(
169+
*,
170+
current_user: User = Security(auth.get_current_user),
171+
dataset_id: UUID,
172+
db: AsyncSession = Depends(get_async_db),
173+
):
174+
dataset = await Dataset.get_or_raise(db, dataset_id)
175+
176+
await authorize(current_user, DatasetPolicy.get(dataset))
177+
178+
progress = await datasets.get_dataset_users_progress(dataset.id)
179+
180+
return UsersProgress(users=progress)
165181

166182

167183
@router.post("/datasets", status_code=status.HTTP_201_CREATED, response_model=DatasetSchema)

argilla-server/src/argilla_server/api/schemas/v1/datasets.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,11 @@
3131
DATASET_GUIDELINES_MIN_LENGTH = 1
3232
DATASET_GUIDELINES_MAX_LENGTH = 10000
3333

34-
3534
DatasetName = Annotated[
3635
constr(regex=DATASET_NAME_REGEX, min_length=DATASET_NAME_MIN_LENGTH, max_length=DATASET_NAME_MAX_LENGTH),
3736
Field(..., description="Dataset name"),
3837
]
3938

40-
4139
DatasetGuidelines = Annotated[
4240
constr(min_length=DATASET_GUIDELINES_MIN_LENGTH, max_length=DATASET_GUIDELINES_MAX_LENGTH),
4341
Field(..., description="Dataset guidelines"),
@@ -88,6 +86,22 @@ class DatasetProgress(BaseModel):
8886
pending: int
8987

9088

89+
class RecordResponseDistribution(BaseModel):
90+
submitted: Optional[int]
91+
discarded: Optional[int]
92+
draft: Optional[int]
93+
94+
95+
class UserProgress(BaseModel):
96+
username: str
97+
completed: RecordResponseDistribution
98+
pending: RecordResponseDistribution
99+
100+
101+
class UsersProgress(BaseModel):
102+
users: List[UserProgress]
103+
104+
91105
class Dataset(BaseModel):
92106
id: UUID
93107
name: str

argilla-server/src/argilla_server/contexts/datasets.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import asyncio
1616
import copy
17+
from collections import defaultdict
18+
1719
import sqlalchemy
1820

1921
from datetime import datetime
@@ -59,6 +61,7 @@
5961
)
6062
from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema
6163
from argilla_server.contexts import accounts, distribution
64+
from argilla_server.database import get_async_db
6265
from argilla_server.enums import DatasetStatus, UserRole, RecordStatus
6366
from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError
6467
from argilla_server.models import (
@@ -425,15 +428,19 @@ async def get_user_dataset_metrics(db: AsyncSession, user_id: UUID, dataset_id:
425428
async def get_dataset_progress(db: AsyncSession, dataset_id: UUID) -> dict:
426429
records_completed, records_pending = await asyncio.gather(
427430
db.execute(
428-
select(func.count(Record.id)).filter(
429-
Record.dataset_id == dataset_id,
430-
Record.status == RecordStatus.completed,
431+
select(func.count(Record.id)).where(
432+
and_(
433+
Record.dataset_id == dataset_id,
434+
Record.status == RecordStatus.completed,
435+
)
431436
),
432437
),
433438
db.execute(
434-
select(func.count(Record.id)).filter(
435-
Record.dataset_id == dataset_id,
436-
Record.status == RecordStatus.pending,
439+
select(func.count(Record.id)).where(
440+
and_(
441+
Record.dataset_id == dataset_id,
442+
Record.status == RecordStatus.pending,
443+
)
437444
),
438445
),
439446
)
@@ -449,6 +456,25 @@ async def get_dataset_progress(db: AsyncSession, dataset_id: UUID) -> dict:
449456
}
450457

451458

459+
async def get_dataset_users_progress(dataset_id: UUID) -> List[dict]:
460+
query = (
461+
select(User.username, Record.status, Response.status, func.count(Response.id))
462+
.join(Record)
463+
.join(User)
464+
.where(Record.dataset_id == dataset_id)
465+
.group_by(User.username, Record.status, Response.status)
466+
)
467+
468+
async for session in get_async_db():
469+
annotators_progress = defaultdict(lambda: defaultdict(dict))
470+
results = (await session.execute(query)).all()
471+
472+
for username, record_status, response_status, count in results:
473+
annotators_progress[username][record_status][response_status] = count
474+
475+
return [{"username": username, **progress} for username, progress in annotators_progress.items()]
476+
477+
452478
_EXTRA_METADATA_FLAG = "extra"
453479

454480

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2021-present, the Recognai S.L. team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from uuid import UUID, uuid4
16+
17+
import pytest
18+
from httpx import AsyncClient
19+
20+
from argilla_server.constants import API_KEY_HEADER_NAME
21+
from argilla_server.enums import RecordStatus, UserRole, ResponseStatus
22+
from tests.factories import DatasetFactory, RecordFactory, AnnotatorFactory, ResponseFactory, UserFactory
23+
24+
25+
@pytest.mark.asyncio
26+
class TestGetDatasetUsersProgress:
27+
def url(self, dataset_id: UUID) -> str:
28+
return f"/api/v1/datasets/{dataset_id}/users/progress"
29+
30+
async def test_get_dataset_users_progress(self, async_client: AsyncClient, owner_auth_header: dict):
31+
dataset = await DatasetFactory.create()
32+
33+
user_with_submitted = await AnnotatorFactory.create()
34+
user_with_draft = await AnnotatorFactory.create()
35+
user_with_discarded = await AnnotatorFactory.create()
36+
37+
records_completed = await RecordFactory.create_batch(3, status=RecordStatus.completed, dataset=dataset)
38+
records_pending = await RecordFactory.create_batch(2, status=RecordStatus.pending, dataset=dataset)
39+
40+
for record in records_completed + records_pending:
41+
await ResponseFactory.create(record=record, user=user_with_submitted, status=ResponseStatus.submitted)
42+
await ResponseFactory.create(record=record, user=user_with_draft, status=ResponseStatus.draft)
43+
await ResponseFactory.create(record=record, user=user_with_discarded, status=ResponseStatus.discarded)
44+
45+
response = await async_client.get(self.url(dataset.id), headers=owner_auth_header)
46+
47+
assert response.status_code == 200, response.json()
48+
assert response.json() == {
49+
"users": [
50+
{
51+
"username": user_with_submitted.username,
52+
"completed": {"submitted": 3},
53+
"pending": {"submitted": 2},
54+
},
55+
{
56+
"username": user_with_draft.username,
57+
"completed": {"draft": 3},
58+
"pending": {"draft": 2},
59+
},
60+
{
61+
"username": user_with_discarded.username,
62+
"completed": {"discarded": 3},
63+
"pending": {"discarded": 2},
64+
},
65+
]
66+
}
67+
68+
async def test_get_dataset_users_progress_with_empty_dataset(
69+
self, async_client: AsyncClient, owner_auth_header: dict
70+
):
71+
dataset = await DatasetFactory.create()
72+
73+
response = await async_client.get(self.url(dataset.id), headers=owner_auth_header)
74+
75+
assert response.status_code == 200
76+
assert response.json() == {"users": []}
77+
78+
@pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator])
79+
async def test_get_dataset_users_progress_as_restricted_user(self, async_client: AsyncClient, user_role: UserRole):
80+
dataset = await DatasetFactory.create()
81+
user = await UserFactory.create(workspaces=[dataset.workspace], role=user_role)
82+
83+
response = await async_client.get(self.url(dataset.id), headers={API_KEY_HEADER_NAME: user.api_key})
84+
85+
assert response.status_code == 200
86+
87+
@pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator])
88+
async def test_get_dataset_users_progress_as_restricted_user_from_different_workspace(
89+
self, async_client: AsyncClient, user_role: UserRole
90+
):
91+
dataset = await DatasetFactory.create()
92+
93+
other_dataset = await DatasetFactory.create()
94+
user = await UserFactory.create(workspaces=[other_dataset.workspace], role=user_role)
95+
96+
response = await async_client.get(self.url(dataset.id), headers={API_KEY_HEADER_NAME: user.api_key})
97+
98+
assert response.status_code == 403
99+
assert response.json() == {
100+
"detail": {
101+
"code": "argilla.api.errors::ForbiddenOperationError",
102+
"params": {"detail": "Operation not allowed"},
103+
},
104+
}
105+
106+
async def test_get_dataset_users_progress_without_authentication(self, async_client: AsyncClient):
107+
response = await async_client.get(self.url(uuid4()))
108+
109+
assert response.status_code == 401
110+
assert response.json() == {
111+
"detail": {
112+
"code": "argilla.api.errors::UnauthorizedError",
113+
"params": {"detail": "Could not validate credentials"},
114+
},
115+
}
116+
117+
async def test_get_dataset_users_progress_with_nonexistent_dataset_id(
118+
self, async_client: AsyncClient, owner_auth_header: dict
119+
):
120+
dataset_id = uuid4()
121+
122+
response = await async_client.get(self.url(dataset_id), headers=owner_auth_header)
123+
124+
assert response.status_code == 404
125+
assert response.json() == {"detail": f"Dataset with id `{dataset_id}` not found"}

argilla-server/tests/unit/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from opensearchpy import OpenSearch
2323

2424
from argilla_server import telemetry
25-
from argilla_server.contexts import distribution
25+
from argilla_server.contexts import distribution, datasets
2626
from argilla_server.api.routes import api_v1
2727
from argilla_server.constants import API_KEY_HEADER_NAME, DEFAULT_API_KEY
2828
from argilla_server.database import get_async_db
@@ -92,6 +92,7 @@ async def override_get_search_engine():
9292
yield mock_search_engine
9393

9494
mocker.patch.object(distribution, "_get_async_db", override_get_async_db)
95+
mocker.patch.object(datasets, "get_async_db", override_get_async_db)
9596

9697
api_v1.dependency_overrides.update(
9798
{

0 commit comments

Comments
 (0)