Skip to content

Commit bafd92f

Browse files
frascuchonpre-commit-ci[bot]nataliaElv
authored
[ENHANCEMENT] expose dataset progress values within the python SDK (#5479)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> Closes #5476 **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Natalia Elvira <[email protected]>
1 parent f5ff647 commit bafd92f

File tree

5 files changed

+167
-1
lines changed

5 files changed

+167
-1
lines changed

argilla/docs/how_to_guides/annotate.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ You can track the progress of an annotation task in the progress bar shown in th
125125

126126
You can also track your own progress in real time expanding the right-bottom panel inside the dataset page. There you can see the number of records for which you have `Pending``Draft``Submitted` and `Discarded` responses.
127127

128+
!!! note
129+
You can also explore the dataset progress from the SDK. Check the [Track your team's progress](./distribution.md#track-your-teams-progress) to know more about it.
130+
128131
## Use search, filters, and sort
129132

130133
The UI offers various features designed for data exploration and understanding. Combining these features with bulk labelling can save you and your team hours of time.

argilla/docs/how_to_guides/distribution.md

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,61 @@ dataset = client.datasets("my_dataset")
7777
dataset.settings.distribution.min_submitted = 4
7878

7979
dataset.update()
80-
```
80+
```
81+
82+
## Track your team's progress
83+
84+
You can check the progress of the annotation task by using the `dataset.progress` method.
85+
This method will return the number of records that have the status `completed`, `pending`, and the
86+
total number of records in the dataset.
87+
88+
```python
89+
import argilla as rg
90+
91+
client = rg.Argilla(api_url="<api_url>", api_key="<api_key>")
92+
93+
dataset = client.datasets("my_dataset")
94+
95+
progress = dataset.progress()
96+
```
97+
```json
98+
{
99+
"total": 100,
100+
"completed": 10,
101+
"pending": 90
102+
}
103+
```
104+
105+
You can see also include to the progress the users distribution by setting the `with_users_distribution` parameter to `True`.
106+
This will return the number of records that have the status `completed`, `pending`, and the total number of records in the dataset,
107+
as well as the number of completed submissions per user. You can visit the [Annotation Progress](../how_to_guides/annotate.md#annotation-progress) section for more information.
108+
109+
```python
110+
import argilla as rg
111+
112+
client = rg.Argilla(api_url="<api_url>", api_key="<api_key>")
113+
114+
dataset = client.datasets("my_dataset")
115+
116+
progress = dataset.progress(with_users_distribution=True)
117+
```
118+
```json
119+
{
120+
"total": 100,
121+
"completed": 50,
122+
"pending": 50,
123+
"users": {
124+
"user1": {
125+
"completed": { "submitted": 10, "draft": 5, "discarded": 5},
126+
"pending": { "submitted": 5, "draft": 10, "discarded": 10},
127+
},
128+
"user2": {
129+
"completed": { "submitted": 20, "draft": 10, "discarded": 5},
130+
"pending": { "submitted": 2, "draft": 25, "discarded": 0},
131+
},
132+
...
133+
}
134+
```
135+
136+
!!! note
137+
Since the completed records can contain submissions from multiple users, the number of completed submissions per user may not match the total number of completed records.

argilla/src/argilla/_api/_datasets.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
__all__ = ["DatasetsAPI"]
2424

25+
from argilla._models._dataset_progress import UserProgressModel, DatasetProgressModel
26+
2527

2628
class DatasetsAPI(ResourceAPI[DatasetModel]):
2729
"""Manage datasets via the API"""
@@ -80,6 +82,24 @@ def exists(self, dataset_id: UUID) -> bool:
8082
# Utility methods #
8183
####################
8284

85+
@api_error_handler
86+
def get_progress(self, dataset_id: UUID) -> DatasetProgressModel:
87+
response = self.http_client.get(f"{self.url_stub}/{dataset_id}/progress")
88+
response.raise_for_status()
89+
response_json = response.json()
90+
91+
self._log_message(message=f"Got progress for dataset {dataset_id}")
92+
return DatasetProgressModel.model_validate(response_json)
93+
94+
@api_error_handler
95+
def list_users_progress(self, dataset_id: UUID) -> List[UserProgressModel]:
96+
response = self.http_client.get(f"{self.url_stub}/{dataset_id}/users/progress")
97+
response.raise_for_status()
98+
response_json = response.json()
99+
100+
self._log_message(message=f"Got users progress for dataset {dataset_id}")
101+
return [UserProgressModel.model_validate(data) for data in response_json["users"]]
102+
83103
@api_error_handler
84104
def publish(self, dataset_id: UUID) -> "DatasetModel":
85105
response = self.http_client.put(url=f"{self.url_stub}/{dataset_id}/publish")
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2024-present, Argilla, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pydantic import BaseModel
16+
17+
18+
class DatasetProgressModel(BaseModel):
19+
"""Dataset progress model."""
20+
21+
total: int = 0
22+
completed: int = 0
23+
pending: int = 0
24+
25+
26+
class RecordResponseDistributionModel(BaseModel):
27+
"""Response distribution model."""
28+
29+
submitted: int = 0
30+
draft: int = 0
31+
discarded: int = 0
32+
33+
34+
class UserProgressModel(BaseModel):
35+
"""User progress model."""
36+
37+
username: str
38+
completed: RecordResponseDistributionModel = RecordResponseDistributionModel()
39+
pending: RecordResponseDistributionModel = RecordResponseDistributionModel()

argilla/src/argilla/datasets/_resource.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,53 @@ def update(self) -> "Dataset":
174174
self.settings.update()
175175
return self
176176

177+
def progress(self, with_users_distribution: bool = False) -> dict:
178+
"""Returns the team's progress on the dataset.
179+
180+
Parameters:
181+
with_users_distribution (bool): If True, the progress of the dataset is returned
182+
with users distribution. This includes the number of responses made by each user.
183+
184+
Returns:
185+
dict: The team's progress on the dataset.
186+
187+
An example of a response when `with_users_distribution` is `True`:
188+
```json
189+
{
190+
"total": 100,
191+
"completed": 50,
192+
"pending": 50,
193+
"users": {
194+
"user1": {
195+
"completed": { "submitted": 10, "draft": 5, "discarded": 5},
196+
"pending": { "submitted": 5, "draft": 10, "discarded": 10},
197+
},
198+
"user2": {
199+
"completed": { "submitted": 20, "draft": 10, "discarded": 5},
200+
"pending": { "submitted": 2, "draft": 25, "discarded": 0},
201+
},
202+
...
203+
}
204+
```
205+
206+
"""
207+
208+
progress = self._api.get_progress(dataset_id=self._model.id).model_dump()
209+
210+
if with_users_distribution:
211+
users_progress = self._api.list_users_progress(dataset_id=self._model.id)
212+
users_distribution = {
213+
user.username: {
214+
"completed": user.completed.model_dump(),
215+
"pending": user.pending.model_dump(),
216+
}
217+
for user in users_progress
218+
}
219+
220+
progress.update({"users": users_distribution})
221+
222+
return progress
223+
177224
@classmethod
178225
def from_model(cls, model: DatasetModel, client: "Argilla") -> "Dataset":
179226
instance = cls(client=client, workspace=model.workspace_id, name=model.name)

0 commit comments

Comments
 (0)