Skip to content

Commit 668643c

Browse files
authored
[REFACTOR] argilla server: moving all record validators (#5603)
# Description

<!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. -->

This PR moves all the record relationship validations (suggestions, responses, and vectors) into the record validator. This helps centralize the validations and prevents objects from being flushed to the DB early.

**Type of change**

<!-- Please delete options that are not relevant. Remember to title the PR according to the type of change -->

- Refactor (change restructuring the codebase without changing functionality)

**How Has This Been Tested**

<!-- Please add some reference about how your feature has been tested. -->

**Checklist**

<!-- Please go over the list and make sure you've taken everything into account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature works
- I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent de153ea commit 668643c

File tree

11 files changed

+166
-145
lines changed

11 files changed

+166
-145
lines changed

argilla-server/src/argilla_server/bulk/records_bulk.py

Lines changed: 12 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from typing import Dict, List, Sequence, Tuple, Union
1717
from uuid import UUID
1818

19+
from fastapi.encoders import jsonable_encoder
1920
from sqlalchemy import select
2021
from sqlalchemy.ext.asyncio import AsyncSession
2122
from sqlalchemy.orm import selectinload
22-
from fastapi.encoders import jsonable_encoder
2323

2424
from argilla_server.api.schemas.v1.records import RecordCreate, RecordUpsert
2525
from argilla_server.api.schemas.v1.records_bulk import (
@@ -31,18 +31,13 @@
3131
from argilla_server.api.schemas.v1.responses import UserResponseCreate
3232
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate
3333
from argilla_server.contexts import distribution
34-
from argilla_server.contexts.accounts import fetch_users_by_ids_as_dict
3534
from argilla_server.contexts.records import (
3635
fetch_records_by_external_ids_as_dict,
3736
fetch_records_by_ids_as_dict,
3837
)
39-
from argilla_server.errors.future import UnprocessableEntityError
4038
from argilla_server.models import Dataset, Record, Response, Suggestion, Vector, VectorSettings
4139
from argilla_server.search_engine import SearchEngine
4240
from argilla_server.validators.records import RecordsBulkCreateValidator, RecordsBulkUpsertValidator
43-
from argilla_server.validators.responses import ResponseCreateValidator
44-
from argilla_server.validators.suggestions import SuggestionCreateValidator
45-
from argilla_server.validators.vectors import VectorValidator
4641

4742

4843
class CreateRecordsBulk:
@@ -84,30 +79,16 @@ async def _upsert_records_relationships(self, records: List[Record], records_cre
8479
# https://github.com/sqlalchemy/sqlalchemy/discussions/9312
8580

8681
await self._upsert_records_suggestions(records_and_suggestions)
87-
await self._upsert_records_responses(records_and_responses)
8882
await self._upsert_records_vectors(records_and_vectors)
83+
await self._upsert_records_responses(records_and_responses)
8984

9085
async def _upsert_records_suggestions(
9186
self, records_and_suggestions: List[Tuple[Record, List[SuggestionCreate]]]
9287
) -> List[Suggestion]:
9388
upsert_many_suggestions = []
9489
for idx, (record, suggestions) in enumerate(records_and_suggestions):
95-
try:
96-
for suggestion_create in suggestions or []:
97-
question = record.dataset.question_by_id(suggestion_create.question_id)
98-
if question is None:
99-
raise ValueError(f"question with question_id={suggestion_create.question_id} does not exist")
100-
101-
try:
102-
SuggestionCreateValidator.validate(suggestion_create, question.parsed_settings, record)
103-
upsert_many_suggestions.append(dict(**suggestion_create.dict(), record_id=record.id))
104-
except (UnprocessableEntityError, ValueError) as ex:
105-
raise ValueError(f"suggestion for question name={question.name} is not valid: {ex}")
106-
107-
except (UnprocessableEntityError, ValueError) as ex:
108-
raise UnprocessableEntityError(
109-
f"Record at position {idx} does not have valid suggestions because {ex}"
110-
) from ex
90+
for suggestion_create in suggestions or []:
91+
upsert_many_suggestions.append(dict(**suggestion_create.dict(), record_id=record.id))
11192

11293
if not upsert_many_suggestions:
11394
return []
@@ -122,22 +103,10 @@ async def _upsert_records_suggestions(
122103
async def _upsert_records_responses(
123104
self, records_and_responses: List[Tuple[Record, List[UserResponseCreate]]]
124105
) -> List[Response]:
125-
user_ids = [response.user_id for _, responses in records_and_responses for response in responses or []]
126-
users_by_id = await fetch_users_by_ids_as_dict(self._db, user_ids)
127-
128106
upsert_many_responses = []
129107
for idx, (record, responses) in enumerate(records_and_responses):
130-
try:
131-
for response_create in responses or []:
132-
if response_create.user_id not in users_by_id:
133-
raise ValueError(f"user with id {response_create.user_id} not found")
134-
135-
ResponseCreateValidator.validate(response_create, record)
136-
upsert_many_responses.append(dict(**response_create.dict(), record_id=record.id))
137-
except (UnprocessableEntityError, ValueError) as ex:
138-
raise UnprocessableEntityError(
139-
f"Record at position {idx} does not have valid responses because {ex}"
140-
) from ex
108+
for response_create in responses or []:
109+
upsert_many_responses.append(dict(**response_create.dict(), record_id=record.id))
141110

142111
if not upsert_many_responses:
143112
return []
@@ -154,18 +123,11 @@ async def _upsert_records_vectors(
154123
) -> List[Vector]:
155124
upsert_many_vectors = []
156125
for idx, (record, vectors) in enumerate(records_and_vectors):
157-
try:
158-
for name, value in (vectors or {}).items():
159-
settings = _get_vector_settings_by_name(record.dataset, name)
160-
if not settings:
161-
raise ValueError(f"vector with name={name} does not exist for dataset_id={record.dataset.id}")
162-
163-
VectorValidator.validate(value, settings)
164-
upsert_many_vectors.append(dict(value=value, record_id=record.id, vector_settings_id=settings.id))
165-
except (UnprocessableEntityError, ValueError) as ex:
166-
raise UnprocessableEntityError(
167-
f"Record at position {idx} does not have valid vectors because {ex}"
168-
) from ex
126+
dataset = record.dataset
127+
128+
for name, value in (vectors or {}).items():
129+
settings = dataset.vector_settings_by_name(name)
130+
upsert_many_vectors.append(dict(value=value, record_id=record.id, vector_settings_id=settings.id))
169131

170132
if not upsert_many_vectors:
171133
return []
@@ -186,7 +148,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp
186148
found_records = await self._fetch_existing_dataset_records(dataset, bulk_upsert.items)
187149
# found_records is passed to the validator to avoid querying the database again, but ideally, it should be
188150
# computed inside the validator
189-
RecordsBulkUpsertValidator.validate(bulk_upsert, dataset, found_records)
151+
await RecordsBulkUpsertValidator.validate(bulk_upsert, dataset, found_records)
190152

191153
records = []
192154
async with self._db.begin_nested():
@@ -246,9 +208,3 @@ async def _preload_records_relationships_before_index(db: "AsyncSession", record
246208
selectinload(Record.vectors),
247209
)
248210
)
249-
250-
251-
def _get_vector_settings_by_name(dataset: Dataset, name: str) -> Union[VectorSettings, None]:
252-
for vector_settings in dataset.vectors_settings:
253-
if vector_settings.name == name:
254-
return vector_settings

argilla-server/src/argilla_server/contexts/accounts.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import secrets
15-
from typing import Dict, Iterable, List, Sequence, Union
15+
from typing import Iterable, List, Sequence, Union
1616
from uuid import UUID
1717

1818
from passlib.context import CryptContext
@@ -193,8 +193,3 @@ def generate_user_token(user: User) -> str:
193193
role=user.role,
194194
),
195195
)
196-
197-
198-
async def fetch_users_by_ids_as_dict(db: "AsyncSession", user_ids: List[UUID]) -> Dict[UUID, User]:
199-
users = await list_users_by_ids(db, set(user_ids))
200-
return {user.id: user for user in users}

argilla-server/src/argilla_server/models/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
from typing import Optional
1515
from uuid import UUID, uuid4
1616

17-
from sqlalchemy.ext.asyncio import AsyncAttrs
17+
from sqlalchemy.ext.asyncio import AsyncAttrs, async_object_session, AsyncSession
1818
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
1919

2020
from argilla_server.models.mixins import CRUDMixin, TimestampMixin
@@ -32,3 +32,7 @@ class DatabaseModel(DeclarativeBase, AsyncAttrs, CRUDMixin, TimestampMixin):
3232

3333
def is_relationship_loaded(self, relationship: str) -> bool:
3434
return relationship in self.__dict__
35+
36+
@property
37+
def current_async_session(self) -> Optional[AsyncSession]:
38+
return async_object_session(self)

argilla-server/src/argilla_server/models/database.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
sql,
2929
)
3030
from sqlalchemy.engine.default import DefaultExecutionContext
31-
from sqlalchemy.ext.asyncio import async_object_session
3231
from sqlalchemy.ext.mutable import MutableDict, MutableList
3332
from sqlalchemy.orm import Mapped, mapped_column, relationship
3433

@@ -509,7 +508,7 @@ def is_annotator(self):
509508
async def is_member(self, workspace_id: UUID) -> bool:
510509
# TODO: Change query to use exists may improve performance
511510
return (
512-
await WorkspaceUser.get_by(async_object_session(self), workspace_id=workspace_id, user_id=self.id)
511+
await WorkspaceUser.get_by(self.current_async_session, workspace_id=workspace_id, user_id=self.id)
513512
is not None
514513
)
515514

0 commit comments

Comments
 (0)