Skip to content

Commit cd9f7d1

Browse files
authored
feat: add better validation for SchedulerConfig with edit_mode() (#8)
Signed-off-by: Will Killian <william.killian@outlook.com>
1 parent 6979b89 commit cd9f7d1

File tree

1 file changed

+40
-11
lines changed

1 file changed

+40
-11
lines changed

src/scheduler/config.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from contextlib import contextmanager
12
from enum import StrEnum
23
from typing import Annotated, Literal
34

@@ -127,9 +128,9 @@ class _StrictBaseModel(BaseModel):
127128
- model_config: Configuration for the model
128129
"""
129130

130-
model_config = ConfigDict(extra="forbid", strict=True)
131+
model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True)
131132
"""
132-
Configuration for the model which forbids extra fields and is strict (@private)
133+
Configuration for the model which forbids extra fields, is strict, and validates on assignment (@private)
133134
"""
134135

135136

@@ -352,12 +353,6 @@ class CourseConfig(_StrictBaseModel):
352353
List of faculty names
353354
"""
354355

355-
@model_validator(mode="after")
356-
def _validate_references(self):
357-
"""Validate that all references exist in the parent SchedulerConfig.
358-
This validator will be called by the parent SchedulerConfig."""
359-
return self
360-
361356

362357
class FacultyConfig(_StrictBaseModel):
363358
"""
@@ -505,10 +500,19 @@ class SchedulerConfig(_StrictBaseModel):
505500
"""
506501

507502
@model_validator(mode="after")
508-
def _validate_all_references(self):
503+
def validate_references(self):
509504
"""
510-
Validate that all `Faculty`, `Room`, `Lab`, and `Course` references exist.
511-
Validate that all `Faculty`, `Room`, and `Lab` definitions are unique.
505+
Validate all cross-references between child models.
506+
This method can be called manually or is used by Pydantic validators.
507+
508+
**Usage:**
509+
```python
510+
config.courses[0].room = ["NewRoom"]
511+
config.validate_references() # Validates all cross-references
512+
```
513+
514+
**Raises:**
515+
- ValueError: If any cross-reference validation fails
512516
"""
513517
# Validate uniqueness first
514518
self._validate_uniqueness()
@@ -577,6 +581,31 @@ def _validate_all_references(self):
577581

578582
return self
579583

584+
@contextmanager
585+
def edit_mode(self):
586+
"""
587+
Context manager for making multiple changes with automatic rollback on validation failure.
588+
589+
**Usage:**
590+
```python
591+
with config.edit_mode() as editable_config:
592+
editable_config.courses[0].room = ["NewRoom"]
593+
editable_config.courses[0].faculty = ["NewFaculty"]
594+
editable_config.rooms.append("AnotherRoom")
595+
# If validation fails, changes are automatically rolled back
596+
```
597+
598+
**Raises:**
599+
- ValueError: If any cross-reference validation fails (with automatic rollback)
600+
"""
601+
# Create a working copy for editing
602+
working_copy = self.model_copy(deep=True)
603+
yield working_copy
604+
# Validate the working copy
605+
working_copy.validate_references()
606+
# If validation passes, update the original object
607+
self.__dict__.update(working_copy.__dict__)
608+
580609
def _validate_business_logic(self, errors: list[str]) -> "SchedulerConfig":
581610
"""
582611
Validate business logic constraints.

0 commit comments

Comments
 (0)