Skip to content

Commit 6129ce8

Browse files
frascuchon and jfcalvo authored
[ENHANCEMENT] argilla server: Return users on dataset progress (#5701)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR adds support to return a list of usernames in the dataset progress endpoint. **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: José Francisco Calvo <[email protected]>
1 parent ab6b2f0 commit 6129ce8

File tree

15 files changed

+306
-61
lines changed

15 files changed

+306
-61
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2021-present, the Recognai S.L. team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""add datasets_users table
16+
17+
Revision ID: 580a6553186f
18+
Revises: 6ed1b8bf8e08
19+
Create Date: 2024-11-20 12:15:24.631417
20+
21+
"""
22+
23+
import sqlalchemy as sa
24+
from alembic import op
25+
26+
# revision identifiers, used by Alembic.
revision = "580a6553186f"  # this migration's id (mapped to release 2.6 in database.py)
down_revision = "6ed1b8bf8e08"  # previous migration in the chain
branch_labels = None
depends_on = None
31+
32+
33+
def upgrade() -> None:
    """Create the ``datasets_users`` join table and backfill it.

    The table links each dataset to every user that has responded to one of
    its records (one row per distinct (dataset_id, user_id) pair). After
    creating the table and its indexes, existing pairs are derived from the
    ``responses``/``records`` tables and inserted with the database's current
    timestamp.

    Raises:
        Exception: if the connected database dialect is neither PostgreSQL
            nor SQLite (no portable "now" function is known for it).
    """
    op.create_table(
        "datasets_users",
        sa.Column("dataset_id", sa.Uuid(), nullable=False),
        sa.Column("user_id", sa.Uuid(), nullable=False),
        sa.Column("inserted_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("dataset_id", "user_id"),
    )
    op.create_index(op.f("ix_datasets_users_dataset_id"), "datasets_users", ["dataset_id"], unique=False)
    op.create_index(op.f("ix_datasets_users_user_id"), "datasets_users", ["user_id"], unique=False)

    bind = op.get_bind()

    statement = """
        INSERT INTO datasets_users (dataset_id, user_id, inserted_at, updated_at)
        SELECT dataset_id, user_id, {now_func}, {now_func} FROM (
            SELECT DISTINCT records.dataset_id AS dataset_id, responses.user_id as user_id
            FROM responses
            JOIN records ON records.id = responses.record_id
        ) AS subquery
    """

    # Dialect-specific SQL function returning the current timestamp; a single
    # mapping avoids duplicating the op.execute() call per branch.
    now_functions = {
        "postgresql": "NOW()",
        "sqlite": "datetime('now')",
    }

    now_func = now_functions.get(bind.dialect.name)
    if now_func is None:
        # Name the dialect so a failed migration is diagnosable from the log.
        raise Exception(f"Unsupported database dialect: {bind.dialect.name!r}")

    op.execute(statement.format(now_func=now_func))
64+
65+
66+
def downgrade() -> None:
    """Revert the upgrade: drop both indexes, then the datasets_users table."""
    for index_name in ("ix_datasets_users_user_id", "ix_datasets_users_dataset_id"):
        op.drop_index(op.f(index_name), table_name="datasets_users")
    op.drop_table("datasets_users")

argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,12 @@ async def get_current_user_dataset_metrics(
154154

155155
await authorize(current_user, DatasetPolicy.get(dataset))
156156

157-
result = await datasets.get_user_dataset_metrics(search_engine, current_user, dataset)
157+
result = await datasets.get_user_dataset_metrics(db, search_engine, current_user, dataset)
158158

159159
return DatasetMetrics(responses=result)
160160

161161

162-
@router.get("/datasets/{dataset_id}/progress", response_model=DatasetProgress)
162+
@router.get("/datasets/{dataset_id}/progress", response_model=DatasetProgress, response_model_exclude_unset=True)
163163
async def get_dataset_progress(
164164
*,
165165
dataset_id: UUID,
@@ -171,7 +171,7 @@ async def get_dataset_progress(
171171

172172
await authorize(current_user, DatasetPolicy.get(dataset))
173173

174-
result = await datasets.get_dataset_progress(search_engine, dataset)
174+
result = await datasets.get_dataset_progress(db, search_engine, dataset)
175175

176176
return DatasetProgress(**result)
177177

@@ -181,14 +181,13 @@ async def get_dataset_users_progress(
181181
*,
182182
dataset_id: UUID,
183183
db: AsyncSession = Depends(get_async_db),
184-
search_engine: SearchEngine = Depends(get_search_engine),
185184
current_user: User = Security(auth.get_current_user),
186185
):
187186
dataset = await Dataset.get_or_raise(db, dataset_id)
188187

189188
await authorize(current_user, DatasetPolicy.get(dataset))
190189

191-
progress = await datasets.get_dataset_users_progress(dataset.id)
190+
progress = await datasets.get_dataset_users_progress(db, dataset)
192191

193192
return UsersProgress(users=progress)
194193

argilla-server/src/argilla_server/api/schemas/v1/datasets.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,6 @@ class DatasetMetrics(BaseModel):
8484
responses: ResponseMetrics
8585

8686

87-
class DatasetProgress(BaseModel):
88-
total: int
89-
completed: int
90-
pending: int
91-
92-
9387
class RecordResponseDistribution(BaseModel):
9488
submitted: int = 0
9589
discarded: int = 0
@@ -101,6 +95,15 @@ class UserProgress(BaseModel):
10195
completed: RecordResponseDistribution = RecordResponseDistribution()
10296
pending: RecordResponseDistribution = RecordResponseDistribution()
10397

98+
model_config = ConfigDict(from_attributes=True)
99+
100+
101+
class DatasetProgress(BaseModel):
    """Overall annotation progress of a dataset, plus the users involved."""

    total: int  # total number of records reported by the search engine
    completed: int  # records reported as completed
    pending: int  # records reported as pending
    # Users attached to the dataset; defaults to empty so endpoints that do not
    # populate it can omit the field via response_model_exclude_unset.
    users: List[UserProgress] = Field(default_factory=list)
106+
104107

105108
class UsersProgress(BaseModel):
    """Response body wrapping the per-user progress list for a dataset."""

    users: List[UserProgress]

argilla-server/src/argilla_server/bulk/records_bulk.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
from argilla_server.api.schemas.v1.responses import UserResponseCreate
3232
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate
33+
from argilla_server.models.database import DatasetUser
3334
from argilla_server.webhooks.v1.enums import RecordEvent
3435
from argilla_server.webhooks.v1.records import notify_record_event as notify_record_event_v1
3536
from argilla_server.contexts import distribution
@@ -109,13 +110,22 @@ async def _upsert_records_responses(
109110
self, records_and_responses: List[Tuple[Record, List[UserResponseCreate]]]
110111
) -> List[Response]:
111112
upsert_many_responses = []
113+
datasets_users = set()
112114
for idx, (record, responses) in enumerate(records_and_responses):
113115
for response_create in responses or []:
114116
upsert_many_responses.append(dict(**response_create.model_dump(), record_id=record.id))
117+
datasets_users.add((response_create.user_id, record.dataset_id))
115118

116119
if not upsert_many_responses:
117120
return []
118121

122+
await DatasetUser.upsert_many(
123+
self._db,
124+
objects=[{"user_id": user_id, "dataset_id": dataset_id} for user_id, dataset_id in datasets_users],
125+
constraints=[DatasetUser.user_id, DatasetUser.dataset_id],
126+
autocommit=False,
127+
)
128+
119129
return await Response.upsert_many(
120130
self._db,
121131
objects=upsert_many_responses,

argilla-server/src/argilla_server/contexts/datasets.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
ResponseCreate,
5050
ResponseUpdate,
5151
ResponseUpsert,
52-
UserResponseCreate,
5352
)
5453
from argilla_server.api.schemas.v1.vector_settings import (
5554
VectorSettings as VectorSettingsSchema,
@@ -58,6 +57,7 @@
5857
VectorSettingsCreate,
5958
)
6059
from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema
60+
from argilla_server.models.database import DatasetUser
6161
from argilla_server.webhooks.v1.enums import DatasetEvent, ResponseEvent, RecordEvent
6262
from argilla_server.webhooks.v1.records import (
6363
build_record_event as build_record_event_v1,
@@ -391,11 +391,12 @@ async def _configure_query_relationships(
391391

392392

393393
async def get_user_dataset_metrics(
394+
db: AsyncSession,
394395
search_engine: SearchEngine,
395396
user: User,
396397
dataset: Dataset,
397398
) -> dict:
398-
total_records = (await get_dataset_progress(search_engine, dataset))["total"]
399+
total_records = (await get_dataset_progress(db, search_engine, dataset))["total"]
399400
result = await search_engine.get_dataset_user_progress(dataset, user)
400401

401402
submitted_responses = result.get("submitted", 0)
@@ -413,34 +414,52 @@ async def get_user_dataset_metrics(
413414

414415

415416
async def get_dataset_progress(
    db: AsyncSession,
    search_engine: SearchEngine,
    dataset: Dataset,
) -> dict:
    """Return the dataset's progress counters and the users with responses.

    The total/completed/pending counts come from the search engine result and
    default to 0 when a key is missing; "users" holds the User rows linked to
    the dataset (see get_users_with_responses_for_dataset).
    """
    result = await search_engine.get_dataset_progress(dataset)
    users = await get_users_with_responses_for_dataset(db, dataset)

    return {
        "total": result.get("total", 0),
        "completed": result.get("completed", 0),
        "pending": result.get("pending", 0),
        "users": users,
    }

426431

427-
async def get_dataset_users_progress(dataset_id: UUID) -> List[dict]:
432+
async def get_users_with_responses_for_dataset(
433+
db: AsyncSession,
434+
dataset: Dataset,
435+
) -> Sequence[User]:
436+
query = (
437+
select(DatasetUser)
438+
.filter_by(dataset_id=dataset.id)
439+
.options(selectinload(DatasetUser.user))
440+
.order_by(DatasetUser.inserted_at.asc())
441+
)
442+
443+
result = await db.scalars(query)
444+
return [r.user for r in result.all()]
445+
446+
447+
async def get_dataset_users_progress(db: AsyncSession, dataset: Dataset) -> List[dict]:
    """Aggregate response counts per annotator for a dataset.

    Returns one dict per username of the form
    ``{"username": ..., <record_status>: {<response_status>: count}}``,
    counting responses grouped by record status and response status.
    """
    query = (
        select(User.username, Record.status, Response.status, func.count(Response.id))
        .join(Record)
        .join(User)
        .where(Record.dataset_id == dataset.id)
        .group_by(User.username, Record.status, Response.status)
    )

    # username -> record_status -> response_status -> count
    annotators_progress = defaultdict(lambda: defaultdict(dict))
    results = (await db.execute(query)).all()

    for username, record_status, response_status, count in results:
        annotators_progress[username][record_status][response_status] = count

    return [{"username": username, **progress} for username, progress in annotators_progress.items()]
444463

445464

446465
_EXTRA_METADATA_FLAG = "extra"
@@ -567,38 +586,6 @@ async def _validate_record_metadata(
567586
raise UnprocessableEntityError(f"metadata is not valid: {e}") from e
568587

569588

570-
async def _build_record_responses(
571-
db: AsyncSession,
572-
record: Record,
573-
responses_create: Optional[List[UserResponseCreate]],
574-
cache: Optional[Set[UUID]] = None,
575-
) -> List[Response]:
576-
"""Create responses for a record."""
577-
if not responses_create:
578-
return []
579-
580-
responses = []
581-
582-
for idx, response_create in enumerate(responses_create):
583-
try:
584-
cache = await validate_user_exists(db, response_create.user_id, cache)
585-
586-
ResponseCreateValidator.validate(response_create, record)
587-
588-
responses.append(
589-
Response(
590-
values=jsonable_encoder(response_create.values),
591-
status=response_create.status,
592-
user_id=response_create.user_id,
593-
record=record,
594-
)
595-
)
596-
except (UnprocessableEntityError, ValueError) as e:
597-
raise UnprocessableEntityError(f"response at position {idx} is not valid: {e}") from e
598-
599-
return responses
600-
601-
602589
async def _build_record_suggestions(
603590
db: AsyncSession,
604591
record: Record,
@@ -830,8 +817,13 @@ async def create_response(
830817
user_id=user.id,
831818
autocommit=False,
832819
)
833-
834820
await _touch_dataset_last_activity_at(db, record.dataset)
821+
await DatasetUser.upsert(
822+
db,
823+
schema={"dataset_id": record.dataset_id, "user_id": user.id},
824+
constraints=[DatasetUser.dataset_id, DatasetUser.user_id],
825+
autocommit=False,
826+
)
835827

836828
await db.commit()
837829

@@ -888,7 +880,12 @@ async def upsert_response(
888880
autocommit=False,
889881
)
890882
await _touch_dataset_last_activity_at(db, response.record.dataset)
891-
883+
await DatasetUser.upsert(
884+
db,
885+
schema={"dataset_id": record.dataset_id, "user_id": user.id},
886+
constraints=[DatasetUser.dataset_id, DatasetUser.user_id],
887+
autocommit=False,
888+
)
892889
await db.commit()
893890

894891
await distribution.update_record_status(search_engine, record.id)

argilla-server/src/argilla_server/database.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"2.0": "237f7c674d74",
3939
"2.4": "660d6c6b3360",
4040
"2.5": "6ed1b8bf8e08",
41+
"2.6": "580a6553186f",
4142
}
4243
)
4344

0 commit comments

Comments
 (0)