Skip to content

Commit 6c5d5e9

Browse files
authored
refactor: add DatasetPublishValidator class (#5568)
# Description A small refactor moving validation logic for `publish_dataset` context function to `DatasetPublishValidator` class. Once we merge this PR then I can do changes to the validations so we can increase flexibility to import datasets from the Hub. I was bored of creating "helper functions" for counting queries for database so I have added `count_by` function into our database models mixin class. I hope it can be useful. **Type of change** - Refactor (change restructuring the codebase without changing functionality) **How Has This Been Tested** - [x] Running test suite. **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)
1 parent 9677435 commit 6c5d5e9

File tree

4 files changed

+34
-24
lines changed

4 files changed

+34
-24
lines changed

argilla-server/src/argilla_server/contexts/datasets.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
)
8282
from argilla_server.models.suggestions import SuggestionCreateWithRecordId
8383
from argilla_server.search_engine import SearchEngine
84-
from argilla_server.validators.datasets import DatasetCreateValidator, DatasetUpdateValidator
84+
from argilla_server.validators.datasets import DatasetCreateValidator, DatasetPublishValidator, DatasetUpdateValidator
8585
from argilla_server.validators.responses import (
8686
ResponseCreateValidator,
8787
ResponseUpdateValidator,
@@ -145,16 +145,6 @@ async def create_dataset(db: AsyncSession, dataset_attrs: dict):
145145
return await dataset.save(db)
146146

147147

148-
async def _count_required_fields_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int:
149-
return (await db.execute(select(func.count(Field.id)).filter_by(dataset_id=dataset_id, required=True))).scalar_one()
150-
151-
152-
async def _count_required_questions_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int:
153-
return (
154-
await db.execute(select(func.count(Question.id)).filter_by(dataset_id=dataset_id, required=True))
155-
).scalar_one()
156-
157-
158148
def _allowed_roles_for_metadata_property_create(metadata_property_create: MetadataPropertyCreate) -> List[UserRole]:
159149
if metadata_property_create.visible_for_annotators:
160150
return VISIBLE_FOR_ANNOTATORS_ALLOWED_ROLES
@@ -163,14 +153,7 @@ def _allowed_roles_for_metadata_property_create(metadata_property_create: Metada
163153

164154

165155
async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset:
166-
if dataset.is_ready:
167-
raise UnprocessableEntityError("Dataset is already published")
168-
169-
if await _count_required_fields_by_dataset_id(db, dataset.id) == 0:
170-
raise UnprocessableEntityError("Dataset cannot be published without required fields")
171-
172-
if await _count_required_questions_by_dataset_id(db, dataset.id) == 0:
173-
raise UnprocessableEntityError("Dataset cannot be published without required questions")
156+
await DatasetPublishValidator.validate(db, dataset)
174157

175158
async with db.begin_nested():
176159
dataset = await dataset.update(db, status=DatasetStatus.ready, autocommit=False)

argilla-server/src/argilla_server/models/mixins.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import TYPE_CHECKING, Any, Dict, List, Set, TypeVar, Union
1717
from uuid import UUID
1818

19-
from sqlalchemy import select, sql
19+
from sqlalchemy import select, func, sql
2020
from sqlalchemy.dialects.mysql import insert as mysql_insert
2121
from sqlalchemy.dialects.postgresql import insert as postgres_insert
2222
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
@@ -100,6 +100,10 @@ async def get_by_or_raise(cls, db: AsyncSession, **conditions) -> Self:
100100

101101
raise NotFoundError(f"{cls.__name__} not found filtering by {conditions_str}")
102102

103+
@classmethod
104+
async def count_by(cls, db: AsyncSession, **conditions) -> int:
105+
return (await db.execute(select(func.count(cls.id)).filter_by(**conditions))).scalar_one()
106+
103107
async def update(
104108
self,
105109
db: AsyncSession,

argilla-server/src/argilla_server/validators/datasets.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616

1717
from sqlalchemy.ext.asyncio import AsyncSession
1818

19+
from argilla_server.models import Dataset, Field, Question, Workspace
1920
from argilla_server.errors.future import (
2021
NotUniqueError,
2122
UnprocessableEntityError,
2223
UpdateDistributionWithExistingResponsesError,
2324
)
24-
from argilla_server.models import Dataset, Workspace
2525

2626

2727
class DatasetCreateValidator:
@@ -41,6 +41,29 @@ async def _validate_name_is_not_duplicated(cls, db: AsyncSession, name: str, wor
4141
raise NotUniqueError(f"Dataset with name `{name}` already exists for workspace with id `{workspace_id}`")
4242

4343

44+
class DatasetPublishValidator:
45+
@classmethod
46+
async def validate(cls, db: AsyncSession, dataset: Dataset) -> None:
47+
await cls._validate_has_not_been_published_yet(db, dataset)
48+
await cls._validate_has_at_least_one_required_field(db, dataset)
49+
await cls._validate_has_at_least_one_required_question(db, dataset)
50+
51+
@classmethod
52+
async def _validate_has_not_been_published_yet(cls, db: AsyncSession, dataset: Dataset) -> None:
53+
if dataset.is_ready:
54+
raise UnprocessableEntityError("Dataset has already been published")
55+
56+
@classmethod
57+
async def _validate_has_at_least_one_required_field(cls, db: AsyncSession, dataset: Dataset) -> None:
58+
if await Field.count_by(db, dataset_id=dataset.id, required=True) == 0:
59+
raise UnprocessableEntityError("Dataset cannot be published without required fields")
60+
61+
@classmethod
62+
async def _validate_has_at_least_one_required_question(cls, db: AsyncSession, dataset: Dataset) -> None:
63+
if await Question.count_by(db, dataset_id=dataset.id, required=True) == 0:
64+
raise UnprocessableEntityError("Dataset cannot be published without required questions")
65+
66+
4467
class DatasetUpdateValidator:
4568
@classmethod
4669
async def validate(cls, db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> None:

argilla-server/tests/unit/api/handlers/v1/test_datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4672,10 +4672,10 @@ async def test_publish_dataset_already_published(
46724672
response = await async_client.put(f"/api/v1/datasets/{dataset.id}/publish", headers=owner_auth_header)
46734673

46744674
assert response.status_code == 422
4675-
assert response.json() == {"detail": "Dataset is already published"}
4675+
assert response.json() == {"detail": "Dataset has already been published"}
46764676
assert (await db.execute(select(func.count(Record.id)))).scalar() == 0
46774677

4678-
async def test_publish_dataset_without_fields(
4678+
async def test_publish_dataset_without_required_fields(
46794679
self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
46804680
):
46814681
dataset = await DatasetFactory.create()
@@ -4688,7 +4688,7 @@ async def test_publish_dataset_without_fields(
46884688
assert response.json() == {"detail": "Dataset cannot be published without required fields"}
46894689
assert (await db.execute(select(func.count(Record.id)))).scalar() == 0
46904690

4691-
async def test_publish_dataset_without_questions(
4691+
async def test_publish_dataset_without_required_questions(
46924692
self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
46934693
):
46944694
dataset = await DatasetFactory.create()

0 commit comments

Comments
 (0)