Skip to content

Commit 6452993

Browse files
authored
[ENHANCEMENT] Argilla SDK: Updating record fields and vectors (#5026)
<!-- Thanks for your contribution! As part of our Community Growers initiative 🌱, we're donating Justdiggit bunds in your name to reforest sub-Saharan Africa. To claim your Community Growers certificate, please contact David Berenstein in our Slack community or fill in this form https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR reviews the record attributes and normalizes how to work with fields, vectors, and metadata. Now all three are treated as dictionaries, and users can update them the same way they would work with a dictionary, or by setting new attributes:

```python
record = Record(fields={"name": "John"})
record.fields.update({"name": "Jane", "age": "30"})
record.fields.new_field = "value"
record.vectors["new-vector"] = [1.0, 2.0, 3.0]
record.vectors.vector = [1.0, 2.0, 3.0]
record.metadata["new-key"] = "new_value"
record.metadata.key = "new_value"
```

Once this approach is approved, I will create a new PR changing the docs.

**Type of change**

(Please delete options that are not relevant. Remember to title the PR according to the type of change)

- [ ] New feature (non-breaking change which adds functionality)
- [x] Refactor (change restructuring the codebase without changing functionality)
- [x] Improvement (change adding some improvement to an existing functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes and, ideally, reference `tests`.)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I added relevant documentation
- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See https://keepachangelog.com/)
1 parent 90f1ef6 commit 6452993

File tree

10 files changed

+154
-121
lines changed

10 files changed

+154
-121
lines changed

argilla/src/argilla/_models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
from argilla._models._workspace import WorkspaceModel
1919
from argilla._models._user import UserModel, Role
2020
from argilla._models._dataset import DatasetModel
21-
from argilla._models._record._record import RecordModel
21+
from argilla._models._record._record import RecordModel, FieldValue
2222
from argilla._models._record._suggestion import SuggestionModel
2323
from argilla._models._record._response import UserResponseModel, ResponseStatus
24-
from argilla._models._record._vector import VectorModel
24+
from argilla._models._record._vector import VectorModel, VectorValue
2525
from argilla._models._record._metadata import MetadataModel, MetadataValue
2626
from argilla._models._search import (
2727
SearchQueryModel,

argilla/src/argilla/_models/_record/_record.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,20 @@
1616

1717
from pydantic import Field, field_serializer, field_validator
1818

19-
from argilla._models._resource import ResourceModel
2019
from argilla._models._record._metadata import MetadataModel, MetadataValue
2120
from argilla._models._record._response import UserResponseModel
2221
from argilla._models._record._suggestion import SuggestionModel
2322
from argilla._models._record._vector import VectorModel
23+
from argilla._models._resource import ResourceModel
24+
25+
__all__ = ["RecordModel", "FieldValue"]
2426

27+
FieldValue = Union[str, None]
2528

2629
class RecordModel(ResourceModel):
2730
"""Schema for the records of a `Dataset`"""
2831

29-
fields: Optional[Dict[str, Union[str, None]]] = None
32+
fields: Optional[Dict[str, FieldValue]] = None
3033
metadata: Optional[Union[List[MetadataModel], Dict[str, MetadataValue]]] = Field(default_factory=dict)
3134
vectors: Optional[List[VectorModel]] = Field(default_factory=list)
3235
responses: Optional[List[UserResponseModel]] = Field(default_factory=list)
@@ -49,7 +52,7 @@ def serialize_metadata(self, value: List[MetadataModel]) -> Dict[str, Any]:
4952
return {metadata.name: metadata.value for metadata in value}
5053

5154
@field_serializer("fields", when_used="always")
52-
def serialize_empty_fields(self, value: Dict[str, Union[str, None]]) -> Dict[str, Union[str, None]]:
55+
def serialize_empty_fields(self, value: Dict[str, Union[str, None]]) -> Optional[Dict[str, Union[str, None]]]:
5356
"""Serialize empty fields to None."""
5457
if isinstance(value, dict) and len(value) == 0:
5558
return None

argilla/src/argilla/_models/_record/_vector.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,21 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import re
1415
from typing import List
1516

17+
from pydantic import field_validator
18+
1619
from argilla._models import ResourceModel
1720

18-
import re
19-
from pydantic import field_validator
21+
__all__ = ["VectorModel", "VectorValue"]
2022

21-
__all__ = ["VectorModel"]
23+
VectorValue = List[float]
2224

2325

2426
class VectorModel(ResourceModel):
2527
name: str
26-
vector_values: List[float]
28+
vector_values: VectorValue
2729

2830
@field_validator("name")
2931
@classmethod

argilla/src/argilla/records/_dataset_records.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from argilla._api import RecordsAPI
2323
from argilla._helpers import LoggingMixin
24-
from argilla._models import RecordModel, MetadataValue
24+
from argilla._models import RecordModel, MetadataValue, VectorValue, FieldValue
2525
from argilla.client import Argilla
2626
from argilla.records._io import GenericIO, HFDataset, HFDatasetsIO, JsonIO
2727
from argilla.records._resource import Record
@@ -405,13 +405,15 @@ def _infer_record_from_mapping(
405405
Returns:
406406
A Record object.
407407
"""
408-
fields: Dict[str, str] = {}
409-
responses: List[Response] = []
410408
record_id: Optional[str] = None
411-
suggestion_values = defaultdict(dict)
412-
vectors: List[Vector] = []
409+
410+
fields: Dict[str, FieldValue] = {}
411+
vectors: Dict[str, VectorValue] = {}
413412
metadata: Dict[str, MetadataValue] = {}
414413

414+
responses: List[Response] = []
415+
suggestion_values: Dict[str, dict] = defaultdict(dict)
416+
415417
schema = self.__dataset.schema
416418

417419
for attribute, value in data.items():
@@ -466,7 +468,7 @@ def _infer_record_from_mapping(
466468
{"value": value, "question_name": attribute, "question_id": schema_item.id}
467469
)
468470
elif isinstance(schema_item, VectorField):
469-
vectors.append(Vector(name=attribute, values=value))
471+
vectors[attribute] = value
470472
elif isinstance(schema_item, MetadataPropertyBase):
471473
metadata[attribute] = value
472474
else:
@@ -478,9 +480,9 @@ def _infer_record_from_mapping(
478480
return Record(
479481
id=record_id,
480482
fields=fields,
481-
suggestions=suggestions,
482-
responses=responses,
483483
vectors=vectors,
484484
metadata=metadata,
485+
suggestions=suggestions,
486+
responses=responses,
485487
_dataset=self.__dataset,
486488
)

argilla/src/argilla/records/_resource.py

Lines changed: 78 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
SuggestionModel,
2424
VectorModel,
2525
MetadataValue,
26+
FieldValue,
27+
VectorValue,
2628
)
2729
from argilla._resource import Resource
2830
from argilla.responses import Response, UserResponse
@@ -54,9 +56,9 @@ class Record(Resource):
5456
def __init__(
5557
self,
5658
id: Optional[Union[UUID, str]] = None,
57-
fields: Optional[Dict[str, Union[str, None]]] = None,
59+
fields: Optional[Dict[str, FieldValue]] = None,
5860
metadata: Optional[Dict[str, MetadataValue]] = None,
59-
vectors: Optional[List[Vector]] = None,
61+
vectors: Optional[Dict[str, VectorValue]] = None,
6062
responses: Optional[List[Response]] = None,
6163
suggestions: Optional[List[Suggestion]] = None,
6264
_server_id: Optional[UUID] = None,
@@ -93,7 +95,7 @@ def __init__(
9395
# Initialize the fields
9496
self.__fields = RecordFields(fields=self._model.fields)
9597
# Initialize the vectors
96-
self.__vectors = RecordVectors(vectors=vectors, record=self)
98+
self.__vectors = RecordVectors(vectors=vectors)
9799
# Initialize the metadata
98100
self.__metadata = RecordMetadata(metadata=metadata)
99101
self.__responses = RecordResponses(responses=responses, record=self)
@@ -158,8 +160,8 @@ def api_model(self) -> RecordModel:
158160
id=self._model.id,
159161
external_id=self._model.external_id,
160162
fields=self.fields.to_dict(),
161-
metadata=self.metadata.models,
162-
vectors=self.vectors.models,
163+
metadata=self.metadata.api_models(),
164+
vectors=self.vectors.api_models(),
163165
responses=self.responses.api_models(),
164166
suggestions=self.suggestions.api_models(),
165167
)
@@ -181,19 +183,22 @@ def to_dict(self) -> Dict[str, Dict]:
181183
represented as a key-value pair in the dictionary of the respective key. i.e.
182184
`{"fields": {"prompt": "...", "response": "..."}, "responses": {"rating": "..."}}`.
183185
"""
186+
id = str(self.id) if self.id else None
187+
server_id = str(self._model.id) if self._model.id else None
184188
fields = self.fields.to_dict()
185-
metadata = dict(self.metadata)
189+
metadata = self.metadata.to_dict()
186190
suggestions = self.suggestions.to_dict()
187191
responses = self.responses.to_dict()
188192
vectors = self.vectors.to_dict()
193+
189194
return {
190-
"id": self.id,
195+
"id": id,
191196
"fields": fields,
192197
"metadata": metadata,
193198
"suggestions": suggestions,
194199
"responses": responses,
195200
"vectors": vectors,
196-
"_server_id": str(self._model.id) if self._model.id else None,
201+
"_server_id": server_id,
197202
}
198203

199204
@classmethod
@@ -219,7 +224,6 @@ def from_dict(cls, data: Dict[str, Dict], dataset: Optional["Dataset"] = None) -
219224
for question_name, _responses in responses.items()
220225
for value in _responses
221226
]
222-
vectors = [Vector(name=vector_name, values=values) for vector_name, values in vectors.items()]
223227

224228
return cls(
225229
id=record_id,
@@ -245,7 +249,7 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record":
245249
id=model.external_id,
246250
fields=model.fields,
247251
metadata={meta.name: meta.value for meta in model.metadata},
248-
vectors=[Vector.from_model(model=vector) for vector in model.vectors],
252+
vectors={vector.name: vector.vector_values for vector in model.vectors},
249253
# Responses and their models are not aligned 1-1.
250254
responses=[
251255
response
@@ -258,27 +262,62 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record":
258262
)
259263

260264

261-
class RecordFields:
265+
class RecordFields(dict):
262266
"""This is a container class for the fields of a Record.
263-
It allows for accessing fields by attribute and iterating over them.
267+
It allows for accessing fields by attribute and key name.
264268
"""
265269

266-
def __init__(self, fields: Dict[str, Union[str, None]]) -> None:
267-
self.__fields = fields or {}
268-
for key, value in self.__fields.items():
269-
setattr(self, key, value)
270+
def __init__(self, fields: Optional[Dict[str, FieldValue]] = None) -> None:
271+
super().__init__(fields or {})
270272

271-
def __getitem__(self, key: str) -> Optional[str]:
272-
return self.__fields.get(key)
273+
def __getattr__(self, item: str):
274+
return self[item]
273275

274-
def __iter__(self):
275-
return iter(self.__fields)
276+
def __setattr__(self, key: str, value: MetadataValue):
277+
self[key] = value
276278

277-
def to_dict(self) -> Dict[str, Union[str, None]]:
278-
return self.__fields
279+
def to_dict(self) -> dict:
280+
return dict(self.items())
279281

280-
def __repr__(self) -> str:
281-
return self.to_dict().__repr__()
282+
283+
class RecordMetadata(dict):
284+
"""This is a container class for the metadata of a Record."""
285+
286+
def __init__(self, metadata: Optional[Dict[str, MetadataValue]] = None) -> None:
287+
super().__init__(metadata or {})
288+
289+
def __getattr__(self, item: str):
290+
return self[item]
291+
292+
def __setattr__(self, key: str, value: MetadataValue):
293+
self[key] = value
294+
295+
def to_dict(self) -> dict:
296+
return dict(self.items())
297+
298+
def api_models(self) -> List[MetadataModel]:
299+
return [MetadataModel(name=key, value=value) for key, value in self.items()]
300+
301+
302+
class RecordVectors(dict):
303+
"""This is a container class for the vectors of a Record.
304+
It allows for accessing vectors by attribute and key name.
305+
"""
306+
307+
def __init__(self, vectors: Dict[str, VectorValue]) -> None:
308+
super().__init__(vectors or {})
309+
310+
def __getattr__(self, item: str):
311+
return self[item]
312+
313+
def __setattr__(self, key: str, value: VectorValue):
314+
self[key] = value
315+
316+
def to_dict(self) -> Dict[str, List[float]]:
317+
return dict(self.items())
318+
319+
def api_models(self) -> List[VectorModel]:
320+
return [Vector(name=name, values=value).api_model() for name, value in self.items()]
282321

283322

284323
class RecordResponses(Iterable[Response]):
@@ -309,6 +348,16 @@ def __getattr__(self, name) -> List[Response]:
309348
def __repr__(self) -> str:
310349
return {k: [{"value": v["value"]} for v in values] for k, values in self.to_dict().items()}.__repr__()
311350

351+
def to_dict(self) -> Dict[str, List[Dict]]:
352+
"""Converts the responses to a dictionary.
353+
Returns:
354+
A dictionary of responses.
355+
"""
356+
response_dict = defaultdict(list)
357+
for response in self.__responses:
358+
response_dict[response.question_name].append({"value": response.value, "user_id": str(response.user_id)})
359+
return response_dict
360+
312361
def api_models(self) -> List[UserResponseModel]:
313362
"""Returns a list of ResponseModel objects."""
314363

@@ -321,15 +370,6 @@ def api_models(self) -> List[UserResponseModel]:
321370
for responses in responses_by_user_id.values()
322371
]
323372

324-
def to_dict(self) -> Dict[str, List[Dict]]:
325-
"""Converts the responses to a dictionary.
326-
Returns:
327-
A dictionary of responses.
328-
"""
329-
response_dict = defaultdict(list)
330-
for response in self.__responses:
331-
response_dict[response.question_name].append({"value": response.value, "user_id": response.user_id})
332-
return response_dict
333373

334374

335375
class RecordSuggestions(Iterable[Suggestion]):
@@ -345,15 +385,15 @@ def __init__(self, suggestions: List[Suggestion], record: Record) -> None:
345385
suggestion.record = self.record
346386
setattr(self, suggestion.question_name, suggestion)
347387

348-
def api_models(self) -> List[SuggestionModel]:
349-
return [suggestion.api_model() for suggestion in self.__suggestions]
350-
351388
def __iter__(self):
352389
return iter(self.__suggestions)
353390

354391
def __getitem__(self, index: int):
355392
return self.__suggestions[index]
356393

394+
def __repr__(self) -> str:
395+
return self.to_dict().__repr__()
396+
357397
def to_dict(self) -> Dict[str, List[str]]:
358398
"""Converts the suggestions to a dictionary.
359399
Returns:
@@ -368,48 +408,6 @@ def to_dict(self) -> Dict[str, List[str]]:
368408
}
369409
return suggestion_dict
370410

371-
def __repr__(self) -> str:
372-
return self.to_dict().__repr__()
373-
374-
375-
class RecordVectors:
376-
"""This is a container class for the vectors of a Record.
377-
It allows for accessing suggestions by attribute and iterating over them.
378-
"""
379-
380-
def __init__(self, vectors: List[Vector], record: Record) -> None:
381-
self.__vectors = vectors or []
382-
self.record = record
383-
for vector in self.__vectors:
384-
setattr(self, vector.name, vector.values)
385-
386-
def __repr__(self) -> str:
387-
return {vector.name: f"{len(vector.values)}" for vector in self.__vectors}.__repr__()
388-
389-
@property
390-
def models(self) -> List[VectorModel]:
391-
return [vector.api_model() for vector in self.__vectors]
392-
393-
def to_dict(self) -> Dict[str, List[float]]:
394-
"""Converts the vectors to a dictionary.
395-
Returns:
396-
A dictionary of vectors.
397-
"""
398-
return {vector.name: list(map(float, vector.values)) for vector in self.__vectors}
399-
400-
401-
class RecordMetadata(dict):
402-
"""This is a container class for the metadata of a Record."""
403-
404-
def __init__(self, metadata: Optional[Dict[str, MetadataValue]] = None) -> None:
405-
super().__init__(metadata or {})
406-
407-
def __getattr__(self, item: str):
408-
return self[item]
409-
410-
def __setattr__(self, key: str, value: MetadataValue):
411-
self[key] = value
411+
def api_models(self) -> List[SuggestionModel]:
412+
return [suggestion.api_model() for suggestion in self.__suggestions]
412413

413-
@property
414-
def models(self) -> List[MetadataModel]:
415-
return [MetadataModel(name=key, value=value) for key, value in self.items()]

0 commit comments

Comments
 (0)