|
| 1 | +from contextlib import contextmanager |
1 | 2 | from enum import StrEnum |
2 | 3 | from typing import Annotated, Literal |
3 | 4 |
|
@@ -127,9 +128,9 @@ class _StrictBaseModel(BaseModel): |
127 | 128 | - model_config: Configuration for the model |
128 | 129 | """ |
129 | 130 |
|
130 | | - model_config = ConfigDict(extra="forbid", strict=True) |
| 131 | + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) |
131 | 132 | """ |
132 | | - Configuration for the model which forbids extra fields and is strict (@private) |
| 133 | + Configuration for the model which forbids extra fields, is strict, and validates on assignment (@private) |
133 | 134 | """ |
134 | 135 |
|
135 | 136 |
|
@@ -352,12 +353,6 @@ class CourseConfig(_StrictBaseModel): |
352 | 353 | List of faculty names |
353 | 354 | """ |
354 | 355 |
|
355 | | - @model_validator(mode="after") |
356 | | - def _validate_references(self): |
357 | | - """Validate that all references exist in the parent SchedulerConfig. |
358 | | - This validator will be called by the parent SchedulerConfig.""" |
359 | | - return self |
360 | | - |
361 | 356 |
|
362 | 357 | class FacultyConfig(_StrictBaseModel): |
363 | 358 | """ |
@@ -505,10 +500,19 @@ class SchedulerConfig(_StrictBaseModel): |
505 | 500 | """ |
506 | 501 |
|
507 | 502 | @model_validator(mode="after") |
508 | | - def _validate_all_references(self): |
| 503 | + def validate_references(self): |
509 | 504 | """ |
510 | | - Validate that all `Faculty`, `Room`, `Lab`, and `Course` references exist. |
511 | | - Validate that all `Faculty`, `Room`, and `Lab` definitions are unique. |
| 505 | + Validate all cross-references between child models. |
| 506 | + This method can be called manually or is used by Pydantic validators. |
| 507 | +
|
| 508 | + **Usage:** |
| 509 | + ```python |
| 510 | + config.courses[0].room = ["NewRoom"] |
| 511 | + config.validate_references() # Validates all cross-references |
| 512 | + ``` |
| 513 | +
|
| 514 | + **Raises:** |
| 515 | + - ValueError: If any cross-reference validation fails |
512 | 516 | """ |
513 | 517 | # Validate uniqueness first |
514 | 518 | self._validate_uniqueness() |
@@ -577,6 +581,31 @@ def _validate_all_references(self): |
577 | 581 |
|
578 | 582 | return self |
579 | 583 |
|
| 584 | + @contextmanager |
| 585 | + def edit_mode(self): |
| 586 | + """ |
| 587 | + Context manager for making multiple changes with automatic rollback on validation failure. |
| 588 | +
|
| 589 | + **Usage:** |
| 590 | + ```python |
| 591 | + with config.edit_mode() as editable_config: |
| 592 | + editable_config.courses[0].room = ["NewRoom"] |
| 593 | + editable_config.courses[0].faculty = ["NewFaculty"] |
| 594 | + editable_config.rooms.append("AnotherRoom") |
| 595 | + # If validation fails, changes are automatically rolled back |
| 596 | + ``` |
| 597 | +
|
| 598 | + **Raises:** |
| 599 | + - ValueError: If any cross-reference validation fails (with automatic rollback) |
| 600 | + """ |
| 601 | + # Create a working copy for editing |
| 602 | + working_copy = self.model_copy(deep=True) |
| 603 | + yield working_copy |
| 604 | + # Validate the working copy |
| 605 | + working_copy.validate_references() |
| 606 | + # If validation passes, update the original object |
| 607 | + self.__dict__.update(working_copy.__dict__) |
| 608 | + |
580 | 609 | def _validate_business_logic(self, errors: list[str]) -> "SchedulerConfig": |
581 | 610 | """ |
582 | 611 | Validate business logic constraints. |
|
0 commit comments