Skip to content

Commit 1acbaec

Browse files
committed
Update score set view models to include update model with all optional fields and class method to handle multi-part form data
1 parent d493d12 commit 1acbaec

File tree

1 file changed

+96
-6
lines changed

1 file changed

+96
-6
lines changed

src/mavedb/view_models/score_set.py

Lines changed: 96 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
from __future__ import annotations
33

44
from datetime import date
5-
from typing import Any, Collection, Optional, Sequence, Union
5+
import json
6+
from typing import Any, Collection, Optional, Sequence, Union, Type, TypeVar, Callable
67
from typing_extensions import Self
8+
from copy import deepcopy
79

810
from humps import camelize
9-
from pydantic import field_validator, model_validator
11+
from pydantic import field_validator, model_validator, create_model
12+
from pydantic.fields import FieldInfo
13+
from fastapi import Form
1014

1115
from mavedb.lib.validation import urn_re
1216
from mavedb.lib.validation.exceptions import ValidationError
@@ -45,6 +49,44 @@
4549

4650
UnboundedRange = tuple[Union[float, None], Union[float, None]]
4751

52+
Model = TypeVar("Model", bound=Type[BaseModel])
53+
54+
def partial_model(exclude_fields: Optional[list[str]] = None) -> Callable[[Model], Model]:
55+
"""A decorator that create a partial model.
56+
57+
Args:
58+
model (Type[BaseModel]): BaseModel model.
59+
60+
Returns:
61+
Type[BaseModel]: ModelBase partial model.
62+
"""
63+
if exclude_fields is None:
64+
exclude_fields = []
65+
66+
def wrapper(model: Type[Model]) -> Type[Model]:
67+
base_model: Type[Model] = model
68+
69+
def make_field_optional(field: FieldInfo, default: Any = None) -> tuple[Any, FieldInfo]:
70+
new = deepcopy(field)
71+
new.default = default
72+
new.annotation = Optional[field.annotation]
73+
return new.annotation, new
74+
75+
if exclude_fields:
76+
base_model = BaseModel
77+
78+
return create_model(
79+
model.__name__,
80+
__base__=base_model,
81+
__module__=model.__module__,
82+
**{
83+
field_name: make_field_optional(field_info)
84+
for field_name, field_info in model.model_fields.items()
85+
if field_name not in exclude_fields
86+
},
87+
)
88+
89+
return wrapper
4890

4991
class ExternalLink(BaseModel):
5092
url: Optional[str] = None
@@ -70,15 +112,17 @@ class ScoreSetBase(BaseModel):
70112
extra_metadata: Optional[dict] = None
71113
data_usage_policy: Optional[str] = None
72114

73-
74-
class ScoreSetModify(ScoreSetBase):
115+
class ScoreSetModifyBase(ScoreSetBase):
75116
contributors: Optional[list[ContributorCreate]] = None
76117
primary_publication_identifiers: Optional[list[PublicationIdentifierCreate]] = None
77118
secondary_publication_identifiers: Optional[list[PublicationIdentifierCreate]] = None
78119
doi_identifiers: Optional[list[DoiIdentifierCreate]] = None
79120
target_genes: list[TargetGeneCreate]
80121
score_ranges: Optional[ScoreSetRangesCreate] = None
81122

123+
class ScoreSetModify(ScoreSetModifyBase):
124+
"""View model that adds custom validators to ScoreSetModifyBase."""
125+
82126
@field_validator("title", "short_description", "abstract_text", "method_text")
83127
def validate_field_is_non_empty(cls, v: str) -> str:
84128
if is_null(v):
@@ -89,7 +133,7 @@ def validate_field_is_non_empty(cls, v: str) -> str:
89133
def max_one_primary_publication_identifier(
90134
cls, v: list[PublicationIdentifierCreate]
91135
) -> list[PublicationIdentifierCreate]:
92-
if len(v) > 1:
136+
if v is not None and len(v) > 1:
93137
raise ValidationError("Multiple primary publication identifiers are not allowed.")
94138
return v
95139

@@ -270,12 +314,58 @@ def validate_experiment_urn_required_except_for_meta_analyses(self) -> Self:
270314
raise ValidationError("experiment URN should not be supplied when your score set is a meta-analysis")
271315
return self
272316

317+
class ScoreSetUpdateBase(ScoreSetModifyBase):
318+
"""View model for updating a score set with no custom validators."""
319+
320+
license_id: Optional[int] = None
273321

274322
class ScoreSetUpdate(ScoreSetModify):
275-
"""View model for updating a score set."""
323+
"""View model for updating a score set that includes custom validators."""
276324

277325
license_id: Optional[int] = None
278326

327+
328+
@partial_model()
329+
class ScoreSetUpdateAllOptional(ScoreSetUpdateBase):
330+
@classmethod
331+
def as_form(
332+
cls,
333+
334+
# ScoreSetBase fields
335+
title: Optional[str] = Form(None),
336+
method_text: Optional[str] = Form(None),
337+
abstract_text: Optional[str] = Form(None),
338+
short_description: Optional[str] = Form(None),
339+
extra_metadata: Optional[str] = Form(None),
340+
data_usage_policy: Optional[str] = Form(None),
341+
342+
# ScoreSetModify fields
343+
contributors: Optional[str] = Form(None),
344+
primary_publication_identifiers: Optional[str] = Form(None),
345+
secondary_publication_identifiers: Optional[str] = Form(None),
346+
doi_identifiers: Optional[str] = Form(None),
347+
target_genes: Optional[str] = Form(None),
348+
score_ranges: Optional[str] = Form(None),
349+
350+
# ScoreSetUpdate fields
351+
license_id: Optional[int] = Form(None),
352+
) -> "ScoreSetUpdateAllOptional":
353+
return cls(
354+
title=title,
355+
method_text=method_text,
356+
abstract_text=abstract_text,
357+
short_description=short_description,
358+
extra_metadata=json.loads(extra_metadata) if extra_metadata else None,
359+
data_usage_policy=data_usage_policy,
360+
contributors=[ContributorCreate.model_validate(c) for c in json.loads(contributors)] if contributors else None,
361+
primary_publication_identifiers=[PublicationIdentifierCreate.model_validate(p) for p in json.loads(primary_publication_identifiers)] if primary_publication_identifiers else None,
362+
secondary_publication_identifiers=[PublicationIdentifierCreate.model_validate(s) for s in json.loads(secondary_publication_identifiers)] if secondary_publication_identifiers else None,
363+
doi_identifiers=[DoiIdentifierCreate.model_validate(d) for d in json.loads(doi_identifiers)] if doi_identifiers else None,
364+
target_genes=[TargetGeneCreate.model_validate(t) for t in json.loads(target_genes)] if target_genes else None,
365+
score_ranges=ScoreSetRangesCreate.model_validate(json.loads(score_ranges)) if score_ranges else None,
366+
license_id=license_id,
367+
)
368+
279369
class DatasetColumnMetadata(BaseModel):
280370
"""Metadata for individual dataset columns."""
281371

0 commit comments

Comments
 (0)