Skip to content

Commit 0487936

Browse files
authored
[FEATURE] Add support to update record fields (#5685)
# Description
<!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. -->

This PR adds backend support to update record fields.

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the PR according to the type of change -->

- Improvement (change adding some improvement to an existing functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested. -->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature works
- I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent 21d07c9 commit 0487936

File tree

14 files changed

+397
-403
lines changed

14 files changed

+397
-403
lines changed

argilla-server/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ These are the section headers that we use:
1616

1717
## [Unreleased]()
1818

19+
### Added
20+
21+
- Added support for updating record fields in the `PATCH /api/v1/records/:record_id` endpoint. ([#5685](https://github.com/argilla-io/argilla/pull/5685))
22+
- Added support for updating record fields in the `PUT /api/v1/datasets/:dataset_id/records/bulk` endpoint. ([#5685](https://github.com/argilla-io/argilla/pull/5685))
23+
1924
## [2.5.0](https://github.com/argilla-io/argilla/compare/v2.4.1...v2.5.0)
2025

2126
### Added

argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ async def delete_dataset_records(
264264
if num_records > DELETE_DATASET_RECORDS_LIMIT:
265265
raise UnprocessableEntityError(f"Cannot delete more than {DELETE_DATASET_RECORDS_LIMIT} records at once")
266266

267-
await datasets.delete_records(db, search_engine, dataset, record_ids)
267+
await records.delete_records(db, search_engine, dataset, record_ids)
268268

269269

270270
@router.post(

argilla-server/src/argilla_server/api/handlers/v1/records.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from argilla_server.api.schemas.v1.responses import Response, ResponseCreate
2626
from argilla_server.api.schemas.v1.suggestions import Suggestion as SuggestionSchema
2727
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate, Suggestions
28-
from argilla_server.contexts import datasets
28+
from argilla_server.contexts import datasets, records
2929
from argilla_server.database import get_async_db
3030
from argilla_server.errors.future.base_errors import NotFoundError, UnprocessableEntityError
3131
from argilla_server.models import Dataset, Question, Record, Suggestion, User
@@ -74,16 +74,21 @@ async def update_record(
7474
db,
7575
record_id,
7676
options=[
77-
selectinload(Record.dataset).selectinload(Dataset.questions),
78-
selectinload(Record.dataset).selectinload(Dataset.metadata_properties),
77+
selectinload(Record.dataset).options(
78+
selectinload(Dataset.questions),
79+
selectinload(Dataset.metadata_properties),
80+
selectinload(Dataset.vectors_settings),
81+
selectinload(Dataset.fields),
82+
),
7983
selectinload(Record.suggestions),
84+
selectinload(Record.responses),
8085
selectinload(Record.vectors),
8186
],
8287
)
8388

8489
await authorize(current_user, RecordPolicy.update(record))
8590

86-
return await datasets.update_record(db, search_engine, record, record_update)
91+
return await records.update_record(db, search_engine, record, record_update)
8792

8893

8994
@router.post("/records/{record_id}/responses", status_code=status.HTTP_201_CREATED, response_model=Response)
@@ -233,4 +238,4 @@ async def delete_record(
233238

234239
await authorize(current_user, RecordPolicy.delete(record))
235240

236-
return await datasets.delete_record(db, search_engine, record)
241+
return await records.delete_record(db, search_engine, record)

argilla-server/src/argilla_server/api/schemas/v1/records.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
BaseModel,
2626
Field,
2727
StrictStr,
28-
root_validator,
29-
validator,
3028
ValidationError,
3129
ConfigDict,
3230
model_validator,
@@ -183,17 +181,12 @@ def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict
183181

184182

185183
class RecordUpdate(UpdateSchema):
186-
metadata_: Optional[Dict[str, Any]] = Field(None, alias="metadata")
184+
fields: Optional[Dict[str, FieldValueCreate]] = None
185+
metadata: Optional[Dict[str, Any]] = None
187186
suggestions: Optional[List[SuggestionCreate]] = None
188187
vectors: Optional[Dict[str, List[float]]] = None
189188

190-
@property
191-
def metadata(self) -> Optional[Dict[str, Any]]:
192-
# Align with the RecordCreate model. Both should have the same name for the metadata field.
193-
# TODO(@frascuchon): This will be properly adapted once the bulk records refactor is completed.
194-
return self.metadata_
195-
196-
@field_validator("metadata_")
189+
@field_validator("metadata")
197190
@classmethod
198191
def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
199192
if metadata is None:
@@ -205,15 +198,20 @@ def prevent_nan_values(cls, metadata: Optional[Dict[str, Any]]) -> Optional[Dict
205198

206199
return {k: v for k, v in metadata.items() if v == v} # By definition, NaN != NaN
207200

201+
def is_set(self, attribute: str) -> bool:
202+
return attribute in self.model_fields_set
208203

209-
class RecordUpdateWithId(RecordUpdate):
210-
id: UUID
204+
def has_changes(self) -> bool:
205+
return self.model_dump(exclude_unset=True) != {}
211206

212207

213208
class RecordUpsert(RecordCreate):
214209
id: Optional[UUID] = None
215210
fields: Optional[Dict[str, FieldValueCreate]] = None
216211

212+
def is_set(self, attribute: str) -> bool:
213+
return attribute in self.model_fields_set
214+
217215

218216
class RecordIncludeParam(BaseModel):
219217
relationships: Optional[List[RecordInclude]] = Field(None, alias="keys")
@@ -278,13 +276,6 @@ class RecordsCreate(BaseModel):
278276
items: List[RecordCreate] = Field(..., min_length=RECORDS_CREATE_MIN_ITEMS, max_length=RECORDS_CREATE_MAX_ITEMS)
279277

280278

281-
class RecordsUpdate(BaseModel):
282-
# TODO: review this definition and align to create model
283-
items: List[RecordUpdateWithId] = Field(
284-
..., min_length=RECORDS_UPDATE_MIN_ITEMS, max_length=RECORDS_UPDATE_MAX_ITEMS
285-
)
286-
287-
288279
class MetadataParsedQueryParam:
289280
def __init__(self, string: str):
290281
k, *v = string.split(":", maxsplit=1)

argilla-server/src/argilla_server/api/schemas/v1/records_bulk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class RecordsBulk(BaseModel):
3030
items: List[Record]
3131

3232

33-
class RecordsBulkWithUpdateInfo(RecordsBulk):
33+
class RecordsBulkWithUpdatedItemIds(RecordsBulk):
3434
updated_item_ids: List[UUID]
3535

3636

argilla-server/src/argilla_server/bulk/records_bulk.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Dict, List, Sequence, Tuple, Union
1717
from uuid import UUID
1818

19+
from datetime import UTC
1920
from fastapi.encoders import jsonable_encoder
2021
from sqlalchemy import select
2122
from sqlalchemy.ext.asyncio import AsyncSession
@@ -26,7 +27,7 @@
2627
RecordsBulk,
2728
RecordsBulkCreate,
2829
RecordsBulkUpsert,
29-
RecordsBulkWithUpdateInfo,
30+
RecordsBulkWithUpdatedItemIds,
3031
)
3132
from argilla_server.api.schemas.v1.responses import UserResponseCreate
3233
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate
@@ -39,7 +40,7 @@
3940
fetch_records_by_ids_as_dict,
4041
)
4142
from argilla_server.errors.future import UnprocessableEntityError
42-
from argilla_server.models import Dataset, Record, Response, Suggestion, Vector, VectorSettings
43+
from argilla_server.models import Dataset, Record, Response, Suggestion, Vector
4344
from argilla_server.search_engine import SearchEngine
4445
from argilla_server.validators.records import RecordsBulkCreateValidator, RecordUpsertValidator
4546

@@ -154,15 +155,11 @@ async def _upsert_records_vectors(
154155
autocommit=False,
155156
)
156157

157-
@classmethod
158-
def _metadata_is_set(cls, record_create: RecordCreate) -> bool:
159-
return "metadata" in record_create.model_fields_set
160-
161158

162159
class UpsertRecordsBulk(CreateRecordsBulk):
163160
async def upsert_records_bulk(
164161
self, dataset: Dataset, bulk_upsert: RecordsBulkUpsert, raise_on_error: bool = True
165-
) -> RecordsBulkWithUpdateInfo:
162+
) -> RecordsBulkWithUpdatedItemIds:
166163
found_records = await self._fetch_existing_dataset_records(dataset, bulk_upsert.items)
167164

168165
records = []
@@ -185,9 +182,14 @@ async def upsert_records_bulk(
185182
external_id=record_upsert.external_id,
186183
dataset_id=dataset.id,
187184
)
188-
elif self._metadata_is_set(record_upsert):
189-
record.metadata_ = record_upsert.metadata
190-
record.updated_at = datetime.utcnow()
185+
else:
186+
if record_upsert.is_set("metadata"):
187+
record.metadata_ = record_upsert.metadata
188+
if record_upsert.is_set("fields"):
189+
record.fields = jsonable_encoder(record_upsert.fields)
190+
191+
if self._db.is_modified(record):
192+
record.updated_at = datetime.now(UTC)
191193

192194
records.append(record)
193195

@@ -203,7 +205,7 @@ async def upsert_records_bulk(
203205

204206
await self._notify_upsert_record_events(records)
205207

206-
return RecordsBulkWithUpdateInfo(
208+
return RecordsBulkWithUpdatedItemIds(
207209
items=records,
208210
updated_item_ids=[record.id for record in found_records.values()],
209211
)

0 commit comments

Comments
 (0)