Skip to content

Commit 539f844

Browse files
authored
extend edit_mode to more config classes (#9)
* feat: extend `edit_mode` to more config classes Signed-off-by: Will Killian <william.killian@outlook.com>
1 parent cd9f7d1 commit 539f844

File tree

1 file changed

+77
-41
lines changed

1 file changed

+77
-41
lines changed

src/scheduler/config.py

Lines changed: 77 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
ConfigDict,
88
Field,
99
PositiveInt,
10+
ValidationError,
1011
field_validator,
1112
model_validator,
1213
)
@@ -133,6 +134,34 @@ class _StrictBaseModel(BaseModel):
133134
Configuration for the model which forbids extra fields, is strict, and validates on assignment (@private)
134135
"""
135136

137+
@contextmanager
138+
def edit_mode(self):
139+
"""
140+
Context manager for making multiple changes with automatic rollback on validation failure.
141+
142+
**Usage:**
143+
```python
144+
with config.edit_mode() as editable_config:
145+
editable_config.some_field = "new_value"
146+
editable_config.another_field.append("item")
147+
# If validation fails, changes are automatically rolled back
148+
```
149+
150+
**Raises:**
151+
- ValueError: If any configuration validation fails (with automatic rollback)
152+
"""
153+
# Create a working copy for editing
154+
working_copy = self.model_copy(deep=True)
155+
yield working_copy
156+
# Validate the working copy by creating a new instance
157+
try:
158+
validated_copy = self.__class__(**working_copy.model_dump())
159+
# If validation passes, update the original object
160+
self.__dict__.update(validated_copy.__dict__)
161+
except ValidationError as e:
162+
# Validation failed, rollback is automatic (working_copy is discarded)
163+
raise e
164+
136165

137166
class TimeBlock(_StrictBaseModel):
138167
"""
@@ -154,17 +183,18 @@ class TimeBlock(_StrictBaseModel):
154183
End time of the time block
155184
"""
156185

157-
@field_validator("end")
186+
@field_validator("start", "end")
158187
@classmethod
159188
def _validate_end_after_start(cls, v, info):
160189
"""
161190
Validate that the end time is after the start time
162191
"""
163-
if "start" in info.data:
192+
if "start" in info.data and "end" in info.data:
164193
start_time = info.data["start"]
194+
end_time = info.data["end"]
165195
# Convert time strings to minutes for comparison
166196
start_minutes = int(start_time.split(":")[0]) * 60 + int(start_time.split(":")[1])
167-
end_minutes = int(v.split(":")[0]) * 60 + int(v.split(":")[1])
197+
end_minutes = int(end_time.split(":")[0]) * 60 + int(end_time.split(":")[1])
168198

169199
if end_minutes <= start_minutes:
170200
raise ValueError("End time must be after start time")
@@ -317,6 +347,39 @@ class TimeSlotConfig(_StrictBaseModel):
317347
Minimum time overlap between time slots (default: 45)
318348
"""
319349

350+
@model_validator(mode="after")
351+
def validate(self):
352+
"""
353+
Validate that time slot config is consistent and complete.
354+
"""
355+
errors = []
356+
357+
# Check that all days in time_slot_config are valid
358+
valid_days = {"MON", "TUE", "WED", "THU", "FRI"}
359+
for day in self.times:
360+
if day not in valid_days:
361+
errors.append(f"Invalid day '{day}' in time slot configuration")
362+
363+
# Check that there are time blocks for each day
364+
for day in valid_days:
365+
if day not in self.times or not self.times[day]:
366+
errors.append(f"No time blocks defined for {day}")
367+
368+
# Check that class patterns are reasonable
369+
if not self.classes:
370+
errors.append("At least one class pattern must be defined")
371+
372+
# Check for disabled patterns
373+
disabled_patterns = [p for p in self.classes if p.disabled]
374+
if len(disabled_patterns) == len(self.classes):
375+
errors.append("All class patterns are disabled")
376+
377+
if errors:
378+
error_message = "Time slot configuration errors:\n" + "\n".join(f" - {error}" for error in errors)
379+
raise ValueError(error_message)
380+
381+
return self
382+
320383

321384
class CourseConfig(_StrictBaseModel):
322385
"""
@@ -436,9 +499,9 @@ def _convert_time_strings(cls, v):
436499
return v
437500

438501
@model_validator(mode="after")
439-
def _validate_credit_consistency(self):
502+
def validate(self):
440503
"""
441-
Validate that minimum and maximum credits are consistent
504+
Validate the model state.
442505
"""
443506
if self.minimum_credits > self.maximum_credits:
444507
raise ValueError(
@@ -500,15 +563,15 @@ class SchedulerConfig(_StrictBaseModel):
500563
"""
501564

502565
@model_validator(mode="after")
503-
def validate_references(self):
566+
def validate(self):
504567
"""
505568
Validate all cross-references between child models.
506569
This method can be called manually or is used by Pydantic validators.
507570
508571
**Usage:**
509572
```python
510573
config.courses[0].room = ["NewRoom"]
511-
config.validate_references() # Validates all cross-references
574+
config.validate() # Validates all cross-references
512575
```
513576
514577
**Raises:**
@@ -722,7 +785,13 @@ class CombinedConfig(_StrictBaseModel):
722785
time_slot_config: TimeSlotConfig = Field(
723786
description="Time slot configuration",
724787
example=TimeSlotConfig(
725-
times={"MON": [{"start": "10:00", "spacing": 60, "end": "12:00"}]},
788+
times={
789+
"MON": [{"start": "10:00", "spacing": 60, "end": "12:00"}],
790+
"TUE": [{"start": "10:00", "spacing": 60, "end": "12:00"}],
791+
"WED": [{"start": "10:00", "spacing": 60, "end": "12:00"}],
792+
"THU": [{"start": "10:00", "spacing": 60, "end": "12:00"}],
793+
"FRI": [{"start": "10:00", "spacing": 60, "end": "12:00"}],
794+
},
726795
classes=[{"credits": 3, "meetings": [{"day": "MON", "duration": 150, "lab": False}]}],
727796
),
728797
)
@@ -761,36 +830,3 @@ def _convert_optimizer_flags(cls, v):
761830
if isinstance(v, list):
762831
return [OptimizerFlags(flag) if isinstance(flag, str) else flag for flag in v]
763832
return v
764-
765-
@model_validator(mode="after")
766-
def _validate_time_slot_config_consistency(self):
767-
"""
768-
Validate that time slot config is consistent with scheduler config
769-
"""
770-
errors = []
771-
772-
# Check that all days in time_slot_config are valid
773-
valid_days = {"MON", "TUE", "WED", "THU", "FRI"}
774-
for day in self.time_slot_config.times:
775-
if day not in valid_days:
776-
errors.append(f"Invalid day '{day}' in time slot configuration")
777-
778-
# Check that there are time blocks for each day
779-
for day in valid_days:
780-
if day not in self.time_slot_config.times or not self.time_slot_config.times[day]:
781-
errors.append(f"No time blocks defined for {day}")
782-
783-
# Check that class patterns are reasonable
784-
if not self.time_slot_config.classes:
785-
errors.append("At least one class pattern must be defined")
786-
787-
# Check for disabled patterns
788-
disabled_patterns = [p for p in self.time_slot_config.classes if p.disabled]
789-
if len(disabled_patterns) == len(self.time_slot_config.classes):
790-
errors.append("All class patterns are disabled")
791-
792-
if errors:
793-
error_message = "Time slot configuration errors:\n" + "\n".join(f" - {error}" for error in errors)
794-
raise ValueError(error_message)
795-
796-
return self

0 commit comments

Comments
 (0)