Skip to content

Commit 66454e8

Browse files
authored
[PERF][IMPROVEMENT] argilla server: improve computation for dataset progress and metrics (#5618)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> Computing dataset progress and metrics using the search engine drastically reduces the time required to run the Argilla server in HF spaces when persistent storage is enabled. ## TODO - [x] Add search engine tests - [x] Adapt existing test to mock engine results **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent eb6741e commit 66454e8

File tree

8 files changed

+228
-153
lines changed

8 files changed

+228
-153
lines changed

argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,37 +143,44 @@ async def get_dataset(
143143
@router.get("/me/datasets/{dataset_id}/metrics", response_model=DatasetMetrics)
144144
async def get_current_user_dataset_metrics(
145145
*,
146-
db: AsyncSession = Depends(get_async_db),
147146
dataset_id: UUID,
147+
db: AsyncSession = Depends(get_async_db),
148+
search_engine: SearchEngine = Depends(get_search_engine),
148149
current_user: User = Security(auth.get_current_user),
149150
):
150151
dataset = await Dataset.get_or_raise(db, dataset_id)
151152

152153
await authorize(current_user, DatasetPolicy.get(dataset))
153154

154-
return await datasets.get_user_dataset_metrics(db, current_user.id, dataset.id)
155+
result = await datasets.get_user_dataset_metrics(db, search_engine, current_user, dataset)
156+
157+
return DatasetMetrics(responses=result)
155158

156159

157160
@router.get("/datasets/{dataset_id}/progress", response_model=DatasetProgress)
158161
async def get_dataset_progress(
159162
*,
160-
db: AsyncSession = Depends(get_async_db),
161163
dataset_id: UUID,
164+
db: AsyncSession = Depends(get_async_db),
165+
search_engine: SearchEngine = Depends(get_search_engine),
162166
current_user: User = Security(auth.get_current_user),
163167
):
164168
dataset = await Dataset.get_or_raise(db, dataset_id)
165169

166170
await authorize(current_user, DatasetPolicy.get(dataset))
167171

168-
return await datasets.get_dataset_progress(db, dataset.id)
172+
result = await datasets.get_dataset_progress(search_engine, dataset)
173+
174+
return DatasetProgress(**result)
169175

170176

171177
@router.get("/datasets/{dataset_id}/users/progress", response_model=UsersProgress)
172178
async def get_dataset_users_progress(
173179
*,
174-
current_user: User = Security(auth.get_current_user),
175180
dataset_id: UUID,
176181
db: AsyncSession = Depends(get_async_db),
182+
search_engine: SearchEngine = Depends(get_search_engine),
183+
current_user: User = Security(auth.get_current_user),
177184
):
178185
dataset = await Dataset.get_or_raise(db, dataset_id)
179186

argilla-server/src/argilla_server/contexts/datasets.py

Lines changed: 28 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import asyncio
1615
import copy
1716
from collections import defaultdict
18-
19-
import sqlalchemy
20-
2117
from datetime import datetime
2218
from typing import (
2319
TYPE_CHECKING,
@@ -35,6 +31,8 @@
3531
Union,
3632
)
3733
from uuid import UUID
34+
35+
import sqlalchemy
3836
from fastapi.encoders import jsonable_encoder
3937
from sqlalchemy import Select, and_, func, select
4038
from sqlalchemy.ext.asyncio import AsyncSession
@@ -62,7 +60,7 @@
6260
from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema
6361
from argilla_server.contexts import accounts, distribution
6462
from argilla_server.database import get_async_db
65-
from argilla_server.enums import DatasetStatus, UserRole, RecordStatus
63+
from argilla_server.enums import DatasetStatus, UserRole
6664
from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError
6765
from argilla_server.jobs import dataset_jobs
6866
from argilla_server.models import (
@@ -72,7 +70,6 @@
7270
Question,
7371
Record,
7472
Response,
75-
ResponseStatus,
7673
Suggestion,
7774
User,
7875
Vector,
@@ -377,88 +374,38 @@ async def _configure_query_relationships(
377374
return query
378375

379376

380-
async def get_user_dataset_metrics(db: AsyncSession, user_id: UUID, dataset_id: UUID) -> dict:
381-
responses_submitted, responses_discarded, responses_draft, responses_pending = await asyncio.gather(
382-
db.execute(
383-
select(func.count(Response.id))
384-
.join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id))
385-
.filter(
386-
Response.user_id == user_id,
387-
Response.status == ResponseStatus.submitted,
388-
),
389-
),
390-
db.execute(
391-
select(func.count(Response.id))
392-
.join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id))
393-
.filter(
394-
Response.user_id == user_id,
395-
Response.status == ResponseStatus.discarded,
396-
),
397-
),
398-
db.execute(
399-
select(func.count(Response.id))
400-
.join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id))
401-
.filter(
402-
Response.user_id == user_id,
403-
Response.status == ResponseStatus.draft,
404-
),
405-
),
406-
db.execute(
407-
select(func.count(Record.id))
408-
.outerjoin(Response, and_(Response.record_id == Record.id, Response.user_id == user_id))
409-
.filter(
410-
Record.dataset_id == dataset_id,
411-
Record.status == RecordStatus.pending,
412-
Response.id == None, # noqa
413-
),
414-
),
415-
)
377+
async def get_user_dataset_metrics(
378+
db: AsyncSession,
379+
search_engine: SearchEngine,
380+
user: User,
381+
dataset: Dataset,
382+
) -> dict:
383+
total_records = await Record.count_by(db, dataset_id=dataset.id)
384+
result = await search_engine.get_dataset_user_progress(dataset, user)
416385

417-
responses_submitted = responses_submitted.scalar_one()
418-
responses_discarded = responses_discarded.scalar_one()
419-
responses_draft = responses_draft.scalar_one()
420-
responses_pending = responses_pending.scalar_one()
421-
responses_total = responses_submitted + responses_discarded + responses_draft + responses_pending
386+
submitted_responses = result.get("submitted", 0)
387+
discarded_responses = result.get("discarded", 0)
388+
draft_responses = result.get("draft", 0)
389+
pending_responses = total_records - submitted_responses - discarded_responses - draft_responses
422390

423391
return {
424-
"responses": {
425-
"total": responses_total,
426-
"submitted": responses_submitted,
427-
"discarded": responses_discarded,
428-
"draft": responses_draft,
429-
"pending": responses_pending,
430-
},
392+
"total": total_records,
393+
"submitted": submitted_responses,
394+
"discarded": discarded_responses,
395+
"draft": draft_responses,
396+
"pending": pending_responses,
431397
}
432398

433399

434-
async def get_dataset_progress(db: AsyncSession, dataset_id: UUID) -> dict:
435-
records_completed, records_pending = await asyncio.gather(
436-
db.execute(
437-
select(func.count(Record.id)).where(
438-
and_(
439-
Record.dataset_id == dataset_id,
440-
Record.status == RecordStatus.completed,
441-
)
442-
),
443-
),
444-
db.execute(
445-
select(func.count(Record.id)).where(
446-
and_(
447-
Record.dataset_id == dataset_id,
448-
Record.status == RecordStatus.pending,
449-
)
450-
),
451-
),
452-
)
453-
454-
records_completed = records_completed.scalar_one()
455-
records_pending = records_pending.scalar_one()
456-
records_total = records_completed + records_pending
457-
400+
async def get_dataset_progress(
401+
search_engine: SearchEngine,
402+
dataset: Dataset,
403+
) -> dict:
404+
result = await search_engine.get_dataset_progress(dataset)
458405
return {
459-
"total": records_total,
460-
"completed": records_completed,
461-
"pending": records_pending,
406+
"total": result.get("total", 0),
407+
"completed": result.get("completed", 0),
408+
"pending": result.get("pending", 0),
462409
}
463410

464411

argilla-server/src/argilla_server/search_engine/base.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Optional,
2323
Union,
2424
TypeVar,
25+
Literal,
2526
)
2627
from uuid import UUID
2728

@@ -45,7 +46,7 @@
4546
"SearchResponses",
4647
"SortBy",
4748
"MetadataMetrics",
48-
"TermsMetadataMetrics",
49+
"TermsMetrics",
4950
"IntegerMetadataMetrics",
5051
"FloatMetadataMetrics",
5152
"SuggestionFilterScope",
@@ -149,12 +150,12 @@ class Config:
149150
arbitrary_types_allowed = True
150151

151152

152-
class TermsMetadataMetrics(BaseModel):
153+
class TermsMetrics(BaseModel):
153154
class TermCount(BaseModel):
154155
term: str
155156
count: int
156157

157-
type: MetadataPropertyType = Field(MetadataPropertyType.terms)
158+
type: Literal["terms"] = "terms"
158159
total: int
159160
values: List[TermCount] = Field(default_factory=list)
160161

@@ -175,7 +176,7 @@ class FloatMetadataMetrics(NumericMetadataMetrics[float]):
175176
type: MetadataPropertyType = Field(MetadataPropertyType.float)
176177

177178

178-
MetadataMetrics = Union[TermsMetadataMetrics, IntegerMetadataMetrics, FloatMetadataMetrics]
179+
MetadataMetrics = Union[TermsMetrics, IntegerMetadataMetrics, FloatMetadataMetrics]
179180

180181

181182
class SearchEngine(metaclass=ABCMeta):
@@ -267,6 +268,14 @@ async def update_record_suggestion(self, suggestion: Suggestion):
267268
async def delete_record_suggestion(self, suggestion: Suggestion):
268269
pass
269270

271+
@abstractmethod
272+
async def get_dataset_progress(self, dataset: Dataset) -> dict:
273+
pass
274+
275+
@abstractmethod
276+
async def get_dataset_user_progress(self, dataset: Dataset, user: User) -> dict:
277+
pass
278+
270279
@abstractmethod
271280
async def search(
272281
self,

argilla-server/src/argilla_server/search_engine/commons.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
Suggestion,
3333
Vector,
3434
VectorSettings,
35+
User,
3536
)
3637
from argilla_server.search_engine.base import (
3738
AndFilter,
@@ -50,7 +51,7 @@
5051
SearchResponses,
5152
SuggestionFilterScope,
5253
TermsFilter,
53-
TermsMetadataMetrics,
54+
TermsMetrics,
5455
TextQuery,
5556
)
5657

@@ -390,6 +391,26 @@ async def delete_record_suggestion(self, suggestion: Suggestion):
390391
body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'},
391392
)
392393

394+
async def get_dataset_progress(self, dataset: Dataset) -> dict:
395+
index_name = es_index_name_for_dataset(dataset)
396+
397+
metrics = await self._compute_terms_metrics_for(index_name, "status")
398+
399+
return {"total": metrics.total, **{metric.term: metric.count for metric in metrics.values}}
400+
401+
async def get_dataset_user_progress(self, dataset: Dataset, user: User) -> dict:
402+
index_name = es_index_name_for_dataset(dataset)
403+
404+
result = await self._compute_terms_metrics_for(
405+
index_name,
406+
field_name=es_field_for_response_property("status"),
407+
query=es_nested_query(
408+
path="responses",
409+
query=es_term_query(es_field_for_response_property("user_id"), str(user.id)),
410+
),
411+
)
412+
return {"total": result.total, **{metric.term: metric.count for metric in result.values}}
413+
393414
async def search(
394415
self,
395416
dataset: Dataset,
@@ -598,21 +619,27 @@ async def _metrics_for_numeric_property(
598619

599620
return metrics_class(min=stats["min"], max=stats["max"])
600621

601-
async def _metrics_for_terms_property(
602-
self, index_name: str, metadata_property: MetadataProperty, query: Optional[dict] = None
603-
) -> TermsMetadataMetrics:
604-
field_name = es_field_for_metadata_property(metadata_property)
622+
async def _compute_terms_metrics_for(
623+
self, index_name: str, field_name: str, query: Optional[dict] = None
624+
) -> TermsMetrics:
605625
query = query or {"match_all": {}}
606626

607627
total_terms = await self.__value_count_aggregation(index_name, field_name=field_name, query=query)
608628
if total_terms == 0:
609-
return TermsMetadataMetrics(total=total_terms)
629+
return TermsMetrics(total=total_terms)
610630

611-
terms_buckets = await self.__terms_aggregation(index_name, field_name=field_name, query=query, size=total_terms)
631+
terms_buckets = await self._terms_aggregation(index_name, field_name=field_name, query=query, size=total_terms)
612632
terms_values = [
613-
TermsMetadataMetrics.TermCount(term=bucket["key"], count=bucket["doc_count"]) for bucket in terms_buckets
633+
TermsMetrics.TermCount(term=bucket["key"], count=bucket["doc_count"]) for bucket in terms_buckets
614634
]
615-
return TermsMetadataMetrics(total=total_terms, values=terms_values)
635+
636+
return TermsMetrics(total=total_terms, values=terms_values)
637+
638+
async def _metrics_for_terms_property(
639+
self, index_name: str, metadata_property: MetadataProperty, query: Optional[dict] = None
640+
) -> TermsMetrics:
641+
field_name = es_field_for_metadata_property(metadata_property)
642+
return await self._compute_terms_metrics_for(index_name, field_name, query)
616643

617644
def _configure_index_mappings(self, dataset: Dataset) -> dict:
618645
return {
@@ -845,7 +872,7 @@ def _map_record_fields_to_es(cls, fields: dict, dataset_fields: List[Field]) ->
845872

846873
return fields
847874

848-
async def __terms_aggregation(self, index_name: str, field_name: str, query: dict, size: int) -> List[dict]:
875+
async def _terms_aggregation(self, index_name: str, field_name: str, query: dict, size: int) -> List[dict]:
849876
aggregation_name = "terms_agg"
850877

851878
terms_agg = {aggregation_name: {"terms": {"field": field_name, "size": min(size, self.max_terms_size)}}}

0 commit comments

Comments (0)