Skip to content

Commit 5875d59

Browse files
[ENHANCEMENT] argilla: Add limit argument when getting records (#5525)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR adds a new `limit` argument when getting records from the Argilla server. **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (see https://keepachangelog.com/) --------- Co-authored-by: burtenshaw <[email protected]>
1 parent 86008cd commit 5875d59

File tree

3 files changed

+118
-1
lines changed

3 files changed

+118
-1
lines changed

argilla/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ These are the section headers that we use:
1616

1717
## [Unreleased]()
1818

19+
### Added
20+
21+
- Added `limit` argument when fetching records. ([#5525](https://github.com/argilla-io/argilla/pull/5525))
22+
1923
### Fixed
2024

2125
- Fixed the deployment yaml used to create a new Argilla server in K8s. Added `USERNAME` and `PASSWORD` to the environment variables of pod template. ([#5434](https://github.com/argilla-io/argilla/issues/5434))

argilla/src/argilla/records/_dataset_records.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(
5252
with_suggestions: bool = False,
5353
with_responses: bool = False,
5454
with_vectors: Optional[Union[str, List[str], bool]] = None,
55+
limit: Optional[int] = None,
5556
):
5657
self.__dataset = dataset
5758
self.__client = client
@@ -62,19 +63,41 @@ def __init__(
6263
self.__with_responses = with_responses
6364
self.__with_vectors = with_vectors
6465
self.__records_batch = []
66+
self.__limit = limit
67+
68+
if self.__limit is not None and self.__limit <= 0:
69+
warnings.warn(f"Limit {self.__limit} is invalid: must be greater than 0. Setting limit to 1.")
70+
self.__limit = 1
71+
72+
if self.__limit is not None and self.__limit < self.__batch_size:
73+
self.__batch_size = self.__limit
6574

6675
def __iter__(self):
6776
return self
6877

6978
def __next__(self) -> Record:
79+
if self._limit_reached():
80+
raise StopIteration()
81+
7082
if not self._has_local_records():
7183
self._fetch_next_batch()
7284
if not self._has_local_records():
7385
raise StopIteration()
86+
7487
return self._next_record()
7588

89+
def _limit_reached(self) -> bool:
90+
if self.__limit is None:
91+
return False
92+
return self.__limit <= 0
93+
7694
def _next_record(self) -> Record:
77-
return self.__records_batch.pop(0)
95+
record = self.__records_batch.pop(0)
96+
97+
if self.__limit is not None:
98+
self.__limit -= 1
99+
100+
return record
78101

79102
def _has_local_records(self) -> bool:
80103
return len(self.__records_batch) > 0
@@ -170,6 +193,7 @@ def __call__(
170193
with_suggestions: bool = True,
171194
with_responses: bool = True,
172195
with_vectors: Optional[Union[List, bool, str]] = None,
196+
limit: Optional[int] = None,
173197
) -> DatasetRecordsIterator:
174198
"""Returns an iterator over the records in the dataset on the server.
175199
@@ -182,6 +206,7 @@ def __call__(
182206
with_vectors: A list of vector names to include in the records. The default is None.
183207
If a list is provided, only the specified vectors will be included.
184208
If True is provided, all vectors will be included.
209+
limit: The maximum number of records to fetch. The default is None.
185210
186211
Returns:
187212
An iterator over the records in the dataset on the server.
@@ -202,6 +227,7 @@ def __call__(
202227
with_suggestions=with_suggestions,
203228
with_responses=with_responses,
204229
with_vectors=with_vectors,
230+
limit=limit,
205231
)
206232

207233
def __repr__(self) -> str:

argilla/tests/integration/test_list_records.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,93 @@ def test_list_records_with_start_offset(client: Argilla, dataset: Dataset):
6666
]
6767

6868

69+
def test_list_records_with_limit(client: Argilla, dataset: Dataset):
70+
dataset.records.log(
71+
[
72+
{"text": "The record text field", "id": 1},
73+
{"text": "The record text field", "id": 2},
74+
{"text": "The record text field", "id": 3},
75+
{"text": "The record text field", "id": 4},
76+
{"text": "The record text field", "id": 5},
77+
]
78+
)
79+
80+
records = list(dataset.records(limit=2))
81+
assert len(records) == 2
82+
83+
assert [record.to_dict() for record in records] == [
84+
{
85+
"_server_id": str(records[0]._server_id),
86+
"fields": {"text": "The record text field"},
87+
"id": "1",
88+
"status": "pending",
89+
"metadata": {},
90+
"responses": {},
91+
"suggestions": {},
92+
"vectors": {},
93+
},
94+
{
95+
"_server_id": str(records[1]._server_id),
96+
"fields": {"text": "The record text field"},
97+
"id": "2",
98+
"status": "pending",
99+
"metadata": {},
100+
"responses": {},
101+
"suggestions": {},
102+
"vectors": {},
103+
},
104+
]
105+
106+
107+
def test_list_records_with_limit_greater_than_batch_size(client: Argilla, dataset: Dataset):
108+
dataset.records.log(
109+
[
110+
{"text": "The record text field", "id": 1},
111+
{"text": "The record text field", "id": 2},
112+
{"text": "The record text field", "id": 3},
113+
{"text": "The record text field", "id": 4},
114+
{"text": "The record text field", "id": 5},
115+
]
116+
)
117+
118+
records = list(dataset.records(limit=2, batch_size=1))
119+
120+
assert len(records) == 2
121+
assert records[0].id == "1"
122+
assert records[1].id == "2"
123+
124+
125+
@pytest.mark.parametrize("limit", [0, -1, -10])
126+
def test_list_records_with_invalid_limit(client: Argilla, dataset: Dataset, limit: int):
127+
dataset.records.log(
128+
[
129+
{"text": "The record text field", "id": 1},
130+
{"text": "The record text field", "id": 2},
131+
{"text": "The record text field", "id": 3},
132+
{"text": "The record text field", "id": 4},
133+
{"text": "The record text field", "id": 5},
134+
]
135+
)
136+
with pytest.warns(UserWarning, match=f"Limit {limit} is invalid: must be greater than 0. Setting limit to 1."):
137+
records = list(dataset.records(limit=limit))
138+
assert len(records) == 1
139+
140+
141+
def test_list_records_with_limit_greater_than_total(client: Argilla, dataset: Dataset):
142+
dataset.records.log(
143+
[
144+
{"text": "The record text field", "id": 1},
145+
{"text": "The record text field", "id": 2},
146+
{"text": "The record text field", "id": 3},
147+
{"text": "The record text field", "id": 4},
148+
{"text": "The record text field", "id": 5},
149+
]
150+
)
151+
152+
records = list(dataset.records(limit=10))
153+
assert len(records) == 5
154+
155+
69156
def test_get_record_by_id(client: Argilla, dataset: Dataset):
70157
dataset.records.log(
71158
[

0 commit comments

Comments
 (0)