Skip to content

Commit dc3deab

Browse files
Saikiranbonu1661Sai Kiran Bonupre-commit-ci[bot]frascuchon
authored
Fix to Return Similarity Scores Along with Records (#5778)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> Closes #5777 **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> Passes similar param in query method while fetching records and could see returning similarity score. **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Sai Kiran Bonu <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Francisco Aranda <[email protected]> Co-authored-by: Paco Aranda <[email protected]>
1 parent 8efa39c commit dc3deab

File tree

8 files changed

+83
-24
lines changed

8 files changed

+83
-24
lines changed

argilla/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ These are the section headers that we use:
1616

1717
## [Unreleased]()
1818

19+
### Added
20+
21+
- Return similarity score when searching by similarity. ([#5778](https://github.com/argilla-io/argilla/pull/5778))
22+
1923
## [2.6.0](https://github.com/argilla-io/argilla/compare/v2.5.0...v2.6.0)
2024

2125
### Fixed

argilla/src/argilla/records/_dataset_records.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import warnings
1515
from pathlib import Path
16-
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union
16+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
1717
from uuid import UUID
1818
from enum import Enum
1919

@@ -86,7 +86,7 @@ def _limit_reached(self) -> bool:
8686
return False
8787
return self.__limit <= 0
8888

89-
def _next_record(self) -> Record:
89+
def _next_record(self) -> Union[Record, Tuple[Record, float]]:
9090
if self._limit_reached() or self._no_records():
9191
raise StopIteration()
9292

@@ -104,15 +104,23 @@ def _fetch_next_batch(self) -> None:
104104
self.__records_batch = list(self._list())
105105
self.__offset += len(self.__records_batch)
106106

107-
def _list(self) -> Sequence[Record]:
108-
for record_model in self._fetch_from_server():
109-
yield Record.from_model(model=record_model, dataset=self.__dataset)
110-
111-
def _fetch_from_server(self) -> List[RecordModel]:
107+
def _list(self) -> Sequence[Union[Record, Tuple[Record, float]]]:
112108
if not self.__client.api.datasets.exists(self.__dataset.id):
113109
warnings.warn(f"Dataset {self.__dataset.id!r} does not exist on the server. Skipping...")
114110
return []
115-
return self._fetch_from_server_with_search() if self._is_search_query() else self._fetch_from_server_with_list()
111+
112+
if self._is_search_query():
113+
records = self._fetch_from_server_with_search()
114+
115+
if self.__query.has_similar():
116+
for record_model, score in records:
117+
yield Record.from_model(model=record_model, dataset=self.__dataset), score
118+
else:
119+
for record_model, _ in records:
120+
yield Record.from_model(model=record_model, dataset=self.__dataset)
121+
else:
122+
for record_model in self._fetch_from_server_with_list():
123+
yield Record.from_model(model=record_model, dataset=self.__dataset)
116124

117125
def _fetch_from_server_with_list(self) -> List[RecordModel]:
118126
return self.__client.api.records.list(
@@ -124,7 +132,7 @@ def _fetch_from_server_with_list(self) -> List[RecordModel]:
124132
with_vectors=self.__with_vectors,
125133
)
126134

127-
def _fetch_from_server_with_search(self) -> List[RecordModel]:
135+
def _fetch_from_server_with_search(self) -> List[Tuple[RecordModel, float]]:
128136
search_items, total = self.__client.api.records.search(
129137
dataset_id=self.__dataset.id,
130138
query=self.__query.api_model(),
@@ -134,7 +142,7 @@ def _fetch_from_server_with_search(self) -> List[RecordModel]:
134142
with_suggestions=self.__with_suggestions,
135143
with_vectors=self.__with_vectors,
136144
)
137-
return [record_model for record_model, _ in search_items]
145+
return search_items
138146

139147
def _is_search_query(self) -> bool:
140148
return self.__query.has_search()

argilla/src/argilla/records/_io/_datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import warnings
16-
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional
16+
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, Tuple
1717

1818
from datasets import Dataset as HFDataset, Sequence
1919
from datasets import Image, ClassLabel, Value
@@ -194,7 +194,7 @@ def _is_hf_dataset(dataset: Any) -> bool:
194194
return isinstance(dataset, HFDataset)
195195

196196
@staticmethod
197-
def to_datasets(records: List["Record"], dataset: "Dataset") -> HFDataset:
197+
def to_datasets(records: List[Union["Record", Tuple["Record", float]]], dataset: "Dataset") -> HFDataset:
198198
"""
199199
Export the records to a Hugging Face dataset.
200200

argilla/src/argilla/records/_io/_generic.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from collections import defaultdict
16-
from typing import Any, Dict, List, TYPE_CHECKING, Union
16+
from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Union
1717

1818
if TYPE_CHECKING:
1919
from argilla import Record
@@ -24,7 +24,9 @@ class GenericIO:
2424
It handles methods for exporting records to generic python formats."""
2525

2626
@staticmethod
27-
def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Union[str, float, int, list]]]:
27+
def to_list(
28+
records: List[Union["Record", Tuple["Record", float]]], flatten: bool = False
29+
) -> List[Dict[str, Union[str, float, int, list]]]:
2830
"""Export records to a list of dictionaries with either names or record index as keys.
2931
Args:
3032
flatten (bool): The structure of the exported dictionary.
@@ -48,7 +50,7 @@ def to_list(records: List["Record"], flatten: bool = False) -> List[Dict[str, Un
4850

4951
@classmethod
5052
def to_dict(
51-
cls, records: List["Record"], flatten: bool = False, orient: str = "names"
53+
cls, records: List[Union["Record", Tuple["Record", float]]], flatten: bool = False, orient: str = "names"
5254
) -> Dict[str, Union[str, float, int, list]]:
5355
"""Export records to a dictionary with either names or record index as keys.
5456
Args:
@@ -79,10 +81,10 @@ def to_dict(
7981
############################
8082

8183
@staticmethod
82-
def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]:
84+
def _record_to_dict(record: Union["Record", Tuple["Record", float]], flatten=False) -> Dict[str, Any]:
8385
"""Converts a Record object to a dictionary for export.
8486
Args:
85-
record (Record): The Record object to convert.
87+
record (Record): The Record object or the record and the linked score to convert.
8688
flatten (bool): The structure of the exported dictionary.
8789
- True: The record fields, metadata, suggestions and responses will be flattened
8890
so that their keys becomes the keys of the record dictionary, using
@@ -92,6 +94,12 @@ def _record_to_dict(record: "Record", flatten=False) -> Dict[str, Any]:
9294
Returns:
9395
A dictionary representing the record.
9496
"""
97+
if isinstance(record, tuple):
98+
record, score = record
99+
100+
record_dict = GenericIO._record_to_dict(record, flatten)
101+
record_dict["score"] = score
102+
return record_dict
95103

96104
record_dict = record.to_dict()
97105
if flatten:

argilla/src/argilla/records/_io/_json.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
# limitations under the License.
1414
import json
1515
from pathlib import Path
16-
from typing import List, Union
16+
from typing import List, Tuple, Union
1717

1818
from argilla.records._resource import Record
1919
from argilla.records._io import GenericIO
2020

2121

2222
class JsonIO:
2323
@staticmethod
24-
def to_json(records: List["Record"], path: Union[Path, str]) -> Path:
24+
def to_json(records: List[Union["Record", Tuple["Record", float]]], path: Union[Path, str]) -> Path:
2525
"""
2626
Export the records to a file on disk. This is a convenient shortcut for dataset.records(...).to_disk().
2727
@@ -37,6 +37,7 @@ def to_json(records: List["Record"], path: Union[Path, str]) -> Path:
3737
path = Path(path)
3838
if path.exists():
3939
raise FileExistsError(f"File {path} already exists.")
40+
4041
record_dicts = GenericIO.to_list(records, flatten=False)
4142
with open(path, "w") as f:
4243
json.dump(record_dicts, f)

argilla/src/argilla/records/_search.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ def __init__(
175175
self.similar = similar
176176

177177
def has_search(self) -> bool:
178-
return bool(self.query or self.similar or self.filter)
178+
return bool(self.query or self.has_similar() or self.filter)
179+
180+
def has_similar(self) -> bool:
181+
return bool(self.similar)
179182

180183
def api_model(self) -> SearchQueryModel:
181184
model = SearchQueryModel()

argilla/tests/integration/test_search_records.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def test_search_records_by_similar_value(self, client: Argilla, dataset: Dataset
173173
)
174174
)
175175
assert len(records) == 1000
176-
assert records[0].id == str(data[3]["id"])
176+
assert records[0][0].id == str(data[3]["id"])
177177

178178
def test_search_records_by_least_similar_value(self, client: Argilla, dataset: Dataset):
179179
data = [
@@ -194,7 +194,7 @@ def test_search_records_by_least_similar_value(self, client: Argilla, dataset: D
194194
)
195195
)
196196
)
197-
assert records[-1].id == str(data[3]["id"])
197+
assert records[-1][0].id == str(data[3]["id"])
198198

199199
def test_search_records_by_similar_record(self, client: Argilla, dataset: Dataset):
200200
data = [
@@ -218,7 +218,7 @@ def test_search_records_by_similar_record(self, client: Argilla, dataset: Datase
218218
)
219219
)
220220
assert len(records) == 1000
221-
assert records[0].id != str(record.id)
221+
assert records[0][0].id != str(record.id)
222222

223223
def test_search_records_by_least_similar_record(self, client: Argilla, dataset: Dataset):
224224
data = [
@@ -241,4 +241,4 @@ def test_search_records_by_least_similar_record(self, client: Argilla, dataset:
241241
)
242242
)
243243
)
244-
assert all(r.id != str(record.id) for r in records)
244+
assert all(r.id != str(record.id) for r, s in records)

argilla/tests/unit/test_io/test_generic.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,38 @@ def test_to_list_flatten(self):
6060
"q2.suggestion.agent": None,
6161
}
6262
]
63+
64+
def test_records_tuple_to_list(self):
65+
record = rg.Record(fields={"field": "The field"}, metadata={"key": "value"})
66+
67+
records_list = GenericIO.to_list(
68+
[
69+
(record, 1.0),
70+
(record, 0.5),
71+
]
72+
)
73+
74+
assert records_list == [
75+
{
76+
"id": str(record.id),
77+
"status": record.status,
78+
"_server_id": record._server_id,
79+
"fields": {"field": "The field"},
80+
"metadata": {"key": "value"},
81+
"responses": {},
82+
"vectors": {},
83+
"suggestions": {},
84+
"score": 1.0,
85+
},
86+
{
87+
"id": str(record.id),
88+
"status": record.status,
89+
"_server_id": record._server_id,
90+
"fields": {"field": "The field"},
91+
"metadata": {"key": "value"},
92+
"responses": {},
93+
"vectors": {},
94+
"suggestions": {},
95+
"score": 0.5,
96+
},
97+
]

0 commit comments

Comments
 (0)