Skip to content

Commit b7ac946

Browse files
authored
[REFACTOR]: Unify validators signature (#5368)
# Description <!-- Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change. --> This PR reviews and unifies the existing validators: 1. Each validator class exposes a single `validate` class method. 2. All dependencies are passed as arguments to the `validate` method. 3. If a validator requires access to the DB, the DB session is passed as the first `validate` argument. **Type of change** <!-- Please delete options that are not relevant. Remember to title the PR according to the type of change --> - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** <!-- Please add some reference about how your feature has been tested. --> **Checklist** <!-- Please go over the list and make sure you've taken everything into account --> - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent 1bbf543 commit b7ac946

File tree

14 files changed

+263
-245
lines changed

14 files changed

+263
-245
lines changed

argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,9 @@ async def _build_response_status_filter_for_search(
236236
return user_response_status_filter
237237

238238

239-
async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID):
239+
async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset: Dataset):
240240
try:
241-
await search.validate_search_records_query(db, query, dataset_id)
241+
await search.validate_search_records_query(db, query, dataset)
242242
except (ValueError, NotFoundError) as e:
243243
raise UnprocessableEntityError(str(e))
244244

@@ -324,7 +324,7 @@ async def search_current_user_dataset_records(
324324

325325
await authorize(current_user, DatasetPolicy.search_records(dataset))
326326

327-
await _validate_search_records_query(db, body, dataset_id)
327+
await _validate_search_records_query(db, body, dataset)
328328

329329
search_responses = await _get_search_responses(
330330
db=db,
@@ -385,7 +385,7 @@ async def search_dataset_records(
385385

386386
await authorize(current_user, DatasetPolicy.search_records_with_all_responses(dataset))
387387

388-
await _validate_search_records_query(db, body, dataset_id)
388+
await _validate_search_records_query(db, body, dataset)
389389

390390
search_responses = await _get_search_responses(
391391
db=db,

argilla-server/src/argilla_server/api/schemas/v1/responses.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,6 @@ class ResponseCreate(BaseModel):
102102
values: Optional[ResponseValuesCreate]
103103
status: ResponseStatus
104104

105-
@property
106-
def is_submitted(self):
107-
return self.status == ResponseStatus.submitted
108-
109105

110106
class ResponseFilterScope(BaseModel):
111107
entity: Literal["response"]

argilla-server/src/argilla_server/bulk/records_bulk.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(self, db: AsyncSession, search_engine: SearchEngine):
5050
self._search_engine = search_engine
5151

5252
async def create_records_bulk(self, dataset: Dataset, bulk_create: RecordsBulkCreate) -> RecordsBulk:
53-
await RecordsBulkCreateValidator(bulk_create, db=self._db).validate_for(dataset)
53+
await RecordsBulkCreateValidator.validate(self._db, bulk_create, dataset)
5454

5555
async with self._db.begin_nested():
5656
records = [
@@ -98,7 +98,7 @@ async def _upsert_records_suggestions(
9898
raise ValueError(f"question with question_id={suggestion_create.question_id} does not exist")
9999

100100
try:
101-
SuggestionCreateValidator(suggestion_create).validate_for(question.parsed_settings, record)
101+
SuggestionCreateValidator.validate(suggestion_create, question.parsed_settings, record)
102102
upsert_many_suggestions.append(dict(**suggestion_create.dict(), record_id=record.id))
103103
except (UnprocessableEntityError, ValueError) as ex:
104104
raise ValueError(f"suggestion for question name={question.name} is not valid: {ex}")
@@ -131,7 +131,7 @@ async def _upsert_records_responses(
131131
if response_create.user_id not in users_by_id:
132132
raise ValueError(f"user with id {response_create.user_id} not found")
133133

134-
ResponseCreateValidator(response_create).validate_for(record)
134+
ResponseCreateValidator.validate(response_create, record)
135135
upsert_many_responses.append(dict(**response_create.dict(), record_id=record.id))
136136
except (UnprocessableEntityError, ValueError) as ex:
137137
raise UnprocessableEntityError(
@@ -159,7 +159,7 @@ async def _upsert_records_vectors(
159159
if not settings:
160160
raise ValueError(f"vector with name={name} does not exist for dataset_id={record.dataset.id}")
161161

162-
VectorValidator(value).validate_for(settings)
162+
VectorValidator.validate(value, settings)
163163
upsert_many_vectors.append(dict(value=value, record_id=record.id, vector_settings_id=settings.id))
164164
except (UnprocessableEntityError, ValueError) as ex:
165165
raise UnprocessableEntityError(
@@ -185,7 +185,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp
185185
found_records = await self._fetch_existing_dataset_records(dataset, bulk_upsert.items)
186186
# found_records is passed to the validator to avoid querying the database again, but ideally, it should be
187187
# computed inside the validator
188-
RecordsBulkUpsertValidator(bulk_upsert, self._db, found_records).validate_for(dataset)
188+
RecordsBulkUpsertValidator.validate(bulk_upsert, dataset, found_records)
189189

190190
records = []
191191
async with self._db.begin_nested():

argilla-server/src/argilla_server/contexts/datasets.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ async def _build_record_responses(
615615
try:
616616
cache = await validate_user_exists(db, response_create.user_id, cache)
617617

618-
ResponseCreateValidator(response_create).validate_for(record)
618+
ResponseCreateValidator.validate(response_create, record)
619619

620620
responses.append(
621621
Response(
@@ -656,7 +656,7 @@ async def _build_record_suggestions(
656656
raise UnprocessableEntityError(f"question_id={str(suggestion_create.question_id)} does not exist")
657657
questions_cache[suggestion_create.question_id] = question
658658

659-
SuggestionCreateValidator(suggestion_create).validate_for(question.parsed_settings, record)
659+
SuggestionCreateValidator.validate(suggestion_create, question.parsed_settings, record)
660660

661661
suggestions.append(
662662
Suggestion(
@@ -840,7 +840,7 @@ async def create_response(
840840
f"Response already exists for record with id `{record.id}` and by user with id `{user.id}`"
841841
)
842842

843-
ResponseCreateValidator(response_create).validate_for(record)
843+
ResponseCreateValidator.validate(response_create, record)
844844

845845
async with db.begin_nested():
846846
response = await Response.create(
@@ -866,7 +866,7 @@ async def create_response(
866866
async def update_response(
867867
db: AsyncSession, search_engine: SearchEngine, response: Response, response_update: ResponseUpdate
868868
):
869-
ResponseUpdateValidator(response_update).validate_for(response.record)
869+
ResponseUpdateValidator.validate(response_update, response.record)
870870

871871
async with db.begin_nested():
872872
response = await response.update(
@@ -890,7 +890,7 @@ async def update_response(
890890
async def upsert_response(
891891
db: AsyncSession, search_engine: SearchEngine, record: Record, user: User, response_upsert: ResponseUpsert
892892
) -> Response:
893-
ResponseUpsertValidator(response_upsert).validate_for(record)
893+
ResponseUpsertValidator.validate(response_upsert, record)
894894

895895
async with db.begin_nested():
896896
response = await Response.upsert(
@@ -963,7 +963,7 @@ async def upsert_suggestion(
963963
question: Question,
964964
suggestion_create: "SuggestionCreate",
965965
) -> Suggestion:
966-
SuggestionCreateValidator(suggestion_create).validate_for(question.parsed_settings, record)
966+
SuggestionCreateValidator.validate(suggestion_create, question.parsed_settings, record)
967967

968968
async with db.begin_nested():
969969
suggestion = await Suggestion.upsert(

argilla-server/src/argilla_server/contexts/questions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def create_question(db: AsyncSession, dataset: Dataset, question_create: Q
3434
f"Question with name `{question_create.name}` already exists for dataset with id `{dataset.id}`"
3535
)
3636

37-
QuestionCreateValidator(question_create).validate_for(dataset)
37+
QuestionCreateValidator.validate(question_create, dataset)
3838

3939
return await Question.create(
4040
db,
@@ -48,14 +48,14 @@ async def create_question(db: AsyncSession, dataset: Dataset, question_create: Q
4848

4949

5050
async def update_question(db: AsyncSession, question: Question, question_update: QuestionUpdate) -> Question:
51-
QuestionUpdateValidator(question_update).validate_for(question)
51+
QuestionUpdateValidator.validate(question_update, question)
5252

5353
params = question_update.dict(exclude_unset=True)
5454

5555
return await question.update(db, **params)
5656

5757

5858
async def delete_question(db: AsyncSession, question: Question) -> Question:
59-
QuestionDeleteValidator().validate_for(question.dataset)
59+
QuestionDeleteValidator.validate(question.dataset)
6060

6161
return await question.delete(db)

argilla-server/src/argilla_server/contexts/search.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,55 +26,61 @@
2626
)
2727
from argilla_server.api.schemas.v1.responses import ResponseFilterScope
2828
from argilla_server.api.schemas.v1.suggestions import SuggestionFilterScope
29-
from argilla_server.models import MetadataProperty, Question, Suggestion
29+
from argilla_server.models import MetadataProperty, Question, Suggestion, Dataset
3030

3131

3232
class SearchRecordsQueryValidator:
33-
def __init__(self, db: AsyncSession, query: SearchRecordsQuery, dataset_id: UUID):
34-
self._db = db
35-
self._query = query
36-
self._dataset_id = dataset_id
37-
38-
async def validate(self) -> None:
39-
if self._query.filters:
40-
for filter in self._query.filters.and_:
41-
await self._validate_filter_scope(filter.scope)
42-
43-
if self._query.sort:
44-
for order in self._query.sort:
45-
await self._validate_filter_scope(order.scope)
46-
47-
async def _validate_filter_scope(self, filter_scope: FilterScope) -> None:
33+
@classmethod
34+
async def validate(cls, db: AsyncSession, dataset: Dataset, query: SearchRecordsQuery) -> None:
35+
if query.filters:
36+
for filter in query.filters.and_:
37+
await cls._validate_filter_scope(db, dataset, filter.scope)
38+
39+
if query.sort:
40+
for order in query.sort:
41+
await cls._validate_filter_scope(db, dataset, order.scope)
42+
43+
@classmethod
44+
async def _validate_filter_scope(cls, db: AsyncSession, dataset: Dataset, filter_scope: FilterScope) -> None:
4845
if isinstance(filter_scope, RecordFilterScope):
4946
return
5047
elif isinstance(filter_scope, ResponseFilterScope):
51-
await self._validate_response_filter_scope(filter_scope)
48+
await cls._validate_response_filter_scope(db, dataset, filter_scope)
5249
elif isinstance(filter_scope, SuggestionFilterScope):
53-
await self._validate_suggestion_filter_scope(filter_scope)
50+
await cls._validate_suggestion_filter_scope(db, dataset, filter_scope)
5451
elif isinstance(filter_scope, MetadataFilterScope):
55-
await self._validate_metadata_filter_scope(filter_scope)
52+
await cls._validate_metadata_filter_scope(db, dataset, filter_scope)
5653
else:
5754
raise ValueError(f"Unknown filter scope entity `{filter_scope.entity}`")
5855

59-
async def _validate_response_filter_scope(self, filter_scope: ResponseFilterScope) -> None:
56+
@staticmethod
57+
async def _validate_response_filter_scope(
58+
db: AsyncSession, dataset: Dataset, filter_scope: ResponseFilterScope
59+
) -> None:
6060
if filter_scope.question is None:
6161
return
6262

63-
await Question.get_by_or_raise(self._db, name=filter_scope.question, dataset_id=self._dataset_id)
63+
await Question.get_by_or_raise(db, name=filter_scope.question, dataset_id=dataset.id)
6464

65-
async def _validate_suggestion_filter_scope(self, filter_scope: SuggestionFilterScope) -> None:
66-
await Question.get_by_or_raise(self._db, name=filter_scope.question, dataset_id=self._dataset_id)
65+
@staticmethod
66+
async def _validate_suggestion_filter_scope(
67+
db: AsyncSession, dataset: Dataset, filter_scope: SuggestionFilterScope
68+
) -> None:
69+
await Question.get_by_or_raise(db, name=filter_scope.question, dataset_id=dataset.id)
6770

68-
async def _validate_metadata_filter_scope(self, filter_scope: MetadataFilterScope) -> None:
71+
@staticmethod
72+
async def _validate_metadata_filter_scope(
73+
db: AsyncSession, dataset: Dataset, filter_scope: MetadataFilterScope
74+
) -> None:
6975
await MetadataProperty.get_by_or_raise(
70-
self._db,
76+
db,
7177
name=filter_scope.metadata_property,
72-
dataset_id=self._dataset_id,
78+
dataset_id=dataset.id,
7379
)
7480

7581

76-
async def validate_search_records_query(db: AsyncSession, query: SearchRecordsQuery, dataset_id: UUID) -> None:
77-
await SearchRecordsQueryValidator(db, query, dataset_id).validate()
82+
async def validate_search_records_query(db: AsyncSession, query: SearchRecordsQuery, dataset: Dataset) -> None:
83+
await SearchRecordsQueryValidator.validate(db, dataset, query)
7884

7985

8086
async def get_dataset_suggestion_agents_by_question(db: AsyncSession, dataset_id: UUID) -> List[Mapping[str, Any]]:

argilla-server/src/argilla_server/validators/questions.py

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,22 @@
2727

2828

2929
class QuestionCreateValidator:
30-
def __init__(self, question_create: QuestionCreate):
31-
self._question_create = question_create
30+
@classmethod
31+
def validate(cls, question_create: QuestionCreate, dataset: Dataset):
32+
cls._validate_dataset_is_not_ready(dataset)
33+
cls._validate_span_question_settings(question_create, dataset)
3234

33-
def validate_for(self, dataset: Dataset):
34-
self._validate_dataset_is_not_ready(dataset)
35-
self._validate_span_question_settings(dataset)
36-
37-
def _validate_dataset_is_not_ready(self, dataset):
35+
@staticmethod
36+
def _validate_dataset_is_not_ready(dataset):
3837
if dataset.is_ready:
3938
raise UnprocessableEntityError("questions cannot be created for a published dataset")
4039

41-
def _validate_span_question_settings(self, dataset: Dataset):
42-
if self._question_create.settings.type != QuestionType.span:
40+
@classmethod
41+
def _validate_span_question_settings(cls, question_create: QuestionCreate, dataset: Dataset):
42+
if question_create.settings.type != QuestionType.span:
4343
return
4444

45-
field = self._question_create.settings.field
45+
field = question_create.settings.field
4646
field_names = [field.name for field in dataset.fields]
4747

4848
if field not in field_names:
@@ -68,41 +68,44 @@ class QuestionUpdateValidator:
6868
QuestionType.span,
6969
]
7070

71-
def __init__(self, question_update: QuestionUpdate):
72-
self._question_update = question_update
73-
74-
def validate_for(self, question: Question):
75-
self._validate_question_settings(question.parsed_settings)
71+
@classmethod
72+
def validate(cls, question_update: QuestionUpdate, question: Question):
73+
cls._validate_question_settings(question_update, question.parsed_settings)
7674

77-
def _validate_question_settings(self, question_settings: QuestionSettings):
78-
if not self._question_update.settings:
75+
@classmethod
76+
def _validate_question_settings(cls, question_update: QuestionUpdate, question_settings: QuestionSettings):
77+
if not question_update.settings:
7978
return
8079

81-
self._validate_question_settings_type_is_the_same(question_settings, self._question_update.settings)
82-
self._validate_question_settings_label_options(question_settings, self._question_update.settings)
83-
self._validate_question_settings_visible_options(question_settings, self._question_update.settings)
84-
self._validate_span_question_settings(question_settings, self._question_update.settings)
80+
cls._validate_question_settings_type_is_the_same(question_settings, question_update.settings)
81+
cls._validate_question_settings_label_options(question_settings, question_update.settings)
82+
cls._validate_question_settings_visible_options(question_settings, question_update.settings)
83+
cls._validate_span_question_settings(question_settings, question_update.settings)
8584

85+
@staticmethod
8686
def _validate_question_settings_type_is_the_same(
87-
self, question_settings: QuestionSettings, question_settings_update: QuestionSettingsUpdate
87+
question_settings: QuestionSettings, question_settings_update: QuestionSettingsUpdate
8888
):
8989
if question_settings.type != question_settings_update.type:
9090
raise UnprocessableEntityError(
91-
f"question type cannot be changed. expected '{question_settings.type}' but got '{question_settings_update.type}'"
91+
"question type cannot be changed. "
92+
f"expected '{question_settings.type}' but got '{question_settings_update.type}'"
9293
)
9394

95+
@classmethod
9496
def _validate_question_settings_label_options(
95-
self, question_settings: QuestionSettings, question_settings_update: QuestionSettingsUpdate
97+
cls, question_settings: QuestionSettings, question_settings_update: QuestionSettingsUpdate
9698
):
97-
if question_settings.type not in self.QUESTION_TYPES_WITH_LABEL_OPTIONS:
99+
if question_settings.type not in cls.QUESTION_TYPES_WITH_LABEL_OPTIONS:
98100
return
99101

100102
if question_settings_update.options is None:
101103
return
102104

103105
if len(question_settings.options) != len(question_settings_update.options):
104106
raise UnprocessableEntityError(
105-
f"the number of options cannot be modified. expected {len(question_settings.options)} but got {len(question_settings_update.options)}"
107+
"the number of options cannot be modified. "
108+
f"expected {len(question_settings.options)} but got {len(question_settings_update.options)}"
106109
)
107110

108111
sorted_options = sorted(question_settings.options, key=lambda option: option.value)
@@ -118,10 +121,11 @@ def _validate_question_settings_label_options(
118121
f"the option values cannot be modified. found unexpected option values: {unexpected_options!r}"
119122
)
120123

124+
@classmethod
121125
def _validate_question_settings_visible_options(
122-
self, question_settings: QuestionSettings, question_settings_update: QuestionSettingsUpdate
126+
cls, question_settings: QuestionSettings, question_settings_update: QuestionSettingsUpdate
123127
):
124-
if question_settings_update.type not in self.QUESTION_TYPES_WITH_VISIBLE_OPTIONS:
128+
if question_settings_update.type not in cls.QUESTION_TYPES_WITH_VISIBLE_OPTIONS:
125129
return
126130

127131
if question_settings_update.visible_options is None:
@@ -130,11 +134,13 @@ def _validate_question_settings_visible_options(
130134
number_of_options = len(question_settings.options)
131135
if question_settings_update.visible_options > number_of_options:
132136
raise UnprocessableEntityError(
133-
f"the value for 'visible_options' must be less or equal to the number of items in 'options' ({number_of_options})"
137+
"the value for 'visible_options' must be less or equal to "
138+
f"the number of items in 'options' ({number_of_options})"
134139
)
135140

141+
@staticmethod
136142
def _validate_span_question_settings(
137-
self, question_settings: SpanQuestionSettings, question_settings_update: QuestionSettingsUpdate
143+
question_settings: SpanQuestionSettings, question_settings_update: QuestionSettingsUpdate
138144
) -> None:
139145
if question_settings_update.type != QuestionType.span:
140146
return
@@ -146,9 +152,11 @@ def _validate_span_question_settings(
146152

147153

148154
class QuestionDeleteValidator:
149-
def validate_for(self, dataset: Dataset):
150-
self._validate_dataset_is_not_ready(dataset)
155+
@classmethod
156+
def validate(cls, dataset: Dataset):
157+
cls._validate_dataset_is_not_ready(dataset)
151158

152-
def _validate_dataset_is_not_ready(self, dataset):
159+
@staticmethod
160+
def _validate_dataset_is_not_ready(dataset):
153161
if dataset.is_ready:
154162
raise UnprocessableEntityError("questions cannot be deleted for a published dataset")

0 commit comments

Comments
 (0)