Skip to content

Commit 9bc6ff3

Browse files
authored
perf: Using search engine to compute the total number of records for user metrics (#5641)
# Description
<!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. -->

Use the search engine's total instead of a DB count to get the total number of records for user metrics. This improves performance when running the Argilla server on an HF Space with persistent storage enabled.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the PR according to the type of change -->

- Improvement (change adding some improvement to an existing functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested. -->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature works
- I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent 432f557 commit 9bc6ff3

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ async def get_current_user_dataset_metrics(
152152

153153
await authorize(current_user, DatasetPolicy.get(dataset))
154154

155-
result = await datasets.get_user_dataset_metrics(db, search_engine, current_user, dataset)
155+
result = await datasets.get_user_dataset_metrics(search_engine, current_user, dataset)
156156

157157
return DatasetMetrics(responses=result)
158158

argilla-server/src/argilla_server/contexts/datasets.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,12 +363,11 @@ async def _configure_query_relationships(
363363

364364

365365
async def get_user_dataset_metrics(
366-
db: AsyncSession,
367366
search_engine: SearchEngine,
368367
user: User,
369368
dataset: Dataset,
370369
) -> dict:
371-
total_records = await Record.count_by(db, dataset_id=dataset.id)
370+
total_records = (await get_dataset_progress(search_engine, dataset))["total"]
372371
result = await search_engine.get_dataset_user_progress(dataset, user)
373372

374373
submitted_responses = result.get("submitted", 0)

argilla-server/tests/unit/api/handlers/v1/test_datasets.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,8 @@ async def test_get_current_user_dataset_metrics(
742742
dataset = await DatasetFactory.create()
743743
records = await RecordFactory.create_batch(size=8, dataset=dataset)
744744

745+
mock_search_engine.get_dataset_progress.return_value = {"total": len(records)}
746+
745747
mock_search_engine.get_dataset_user_progress.return_value = {
746748
"total": 6,
747749
"submitted": 3,
@@ -772,6 +774,7 @@ async def test_get_current_user_dataset_metrics_with_empty_dataset(
772774
):
773775
dataset = await DatasetFactory.create()
774776

777+
mock_search_engine.get_dataset_progress.return_value = {}
775778
mock_search_engine.get_dataset_user_progress.return_value = {}
776779

777780
response = await async_client.get(
@@ -791,7 +794,7 @@ async def test_get_current_user_dataset_metrics_with_empty_dataset(
791794
}
792795

793796
@pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin])
794-
async def test_get_current_user_dataset_metrics_as_annotator(
797+
async def test_get_current_user_dataset_metrics_as_different_role(
795798
self,
796799
async_client: "AsyncClient",
797800
mock_search_engine: SearchEngine,
@@ -800,10 +803,13 @@ async def test_get_current_user_dataset_metrics_as_annotator(
800803
dataset = await DatasetFactory.create()
801804
records = await RecordFactory.create_batch(size=6, dataset=dataset)
802805

803-
user = await AnnotatorFactory.create(workspaces=[dataset.workspace], role=role)
806+
user = await UserFactory.create(workspaces=[dataset.workspace], role=role)
804807

805-
mock_search_engine.get_dataset_user_progress.return_value = {
808+
mock_search_engine.get_dataset_progress.return_value = {
806809
"total": len(records),
810+
}
811+
812+
mock_search_engine.get_dataset_user_progress.return_value = {
807813
"submitted": 2,
808814
"discarded": 1,
809815
"draft": 1,

0 commit comments

Comments
 (0)