Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 77 additions & 41 deletions src/scheduler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ConfigDict,
Field,
PositiveInt,
ValidationError,
field_validator,
model_validator,
)
Expand Down Expand Up @@ -133,6 +134,34 @@ class _StrictBaseModel(BaseModel):
Configuration for the model which forbids extra fields, is strict, and validates on assignment (@private)
"""

@contextmanager
def edit_mode(self):
"""
Context manager for making multiple changes with automatic rollback on validation failure.

**Usage:**
```python
with config.edit_mode() as editable_config:
editable_config.some_field = "new_value"
editable_config.another_field.append("item")
# If validation fails, changes are automatically rolled back
```

**Raises:**
- ValueError: If any configuration validation fails (with automatic rollback)
"""
# Create a working copy for editing
working_copy = self.model_copy(deep=True)
yield working_copy
# Validate the working copy by creating a new instance
try:
validated_copy = self.__class__(**working_copy.model_dump())
# If validation passes, update the original object
self.__dict__.update(validated_copy.__dict__)
except ValidationError as e:
# Validation failed, rollback is automatic (working_copy is discarded)
raise e


class TimeBlock(_StrictBaseModel):
"""
Expand All @@ -154,17 +183,18 @@ class TimeBlock(_StrictBaseModel):
End time of the time block
"""

@field_validator("end")
@field_validator("start", "end")
@classmethod
def _validate_end_after_start(cls, v, info):
"""
Validate that the end time is after the start time
"""
if "start" in info.data:
if "start" in info.data and "end" in info.data:
start_time = info.data["start"]
end_time = info.data["end"]
# Convert time strings to minutes for comparison
start_minutes = int(start_time.split(":")[0]) * 60 + int(start_time.split(":")[1])
end_minutes = int(v.split(":")[0]) * 60 + int(v.split(":")[1])
end_minutes = int(end_time.split(":")[0]) * 60 + int(end_time.split(":")[1])

if end_minutes <= start_minutes:
raise ValueError("End time must be after start time")
Expand Down Expand Up @@ -317,6 +347,39 @@ class TimeSlotConfig(_StrictBaseModel):
Minimum time overlap between time slots (default: 45)
"""

@model_validator(mode="after")
def validate(self):
"""
Validate that time slot config is consistent and complete.
"""
errors = []

# Check that all days in time_slot_config are valid
valid_days = {"MON", "TUE", "WED", "THU", "FRI"}
for day in self.times:
if day not in valid_days:
errors.append(f"Invalid day '{day}' in time slot configuration")

# Check that there are time blocks for each day
for day in valid_days:
if day not in self.times or not self.times[day]:
errors.append(f"No time blocks defined for {day}")

# Check that class patterns are reasonable
if not self.classes:
errors.append("At least one class pattern must be defined")

# Check for disabled patterns
disabled_patterns = [p for p in self.classes if p.disabled]
if len(disabled_patterns) == len(self.classes):
errors.append("All class patterns are disabled")

if errors:
error_message = "Time slot configuration errors:\n" + "\n".join(f" - {error}" for error in errors)
raise ValueError(error_message)

return self


class CourseConfig(_StrictBaseModel):
"""
Expand Down Expand Up @@ -436,9 +499,9 @@ def _convert_time_strings(cls, v):
return v

@model_validator(mode="after")
def _validate_credit_consistency(self):
def validate(self):
"""
Validate that minimum and maximum credits are consistent
Validate the model state.
"""
if self.minimum_credits > self.maximum_credits:
raise ValueError(
Expand Down Expand Up @@ -500,15 +563,15 @@ class SchedulerConfig(_StrictBaseModel):
"""

@model_validator(mode="after")
def validate_references(self):
def validate(self):
"""
Validate all cross-references between child models.
This method can be called manually or is used by Pydantic validators.

**Usage:**
```python
config.courses[0].room = ["NewRoom"]
config.validate_references() # Validates all cross-references
config.validate() # Validates all cross-references
```

**Raises:**
Expand Down Expand Up @@ -722,7 +785,13 @@ class CombinedConfig(_StrictBaseModel):
time_slot_config: TimeSlotConfig = Field(
description="Time slot configuration",
example=TimeSlotConfig(
times={"MON": [{"start": "10:00", "spacing": 60, "end": "12:00"}]},
times={
"MON": [{"start": "10:00", "spacing": 60, "end": "12:00"}],
"TUE": [{"start": "10:00", "spacing": 60, "end": "12:00"}],
"WED": [{"start": "10:00", "spacing": 60, "end": "12:00"}],
"THU": [{"start": "10:00", "spacing": 60, "end": "12:00"}],
"FRI": [{"start": "10:00", "spacing": 60, "end": "12:00"}],
},
classes=[{"credits": 3, "meetings": [{"day": "MON", "duration": 150, "lab": False}]}],
),
)
Expand Down Expand Up @@ -761,36 +830,3 @@ def _convert_optimizer_flags(cls, v):
if isinstance(v, list):
return [OptimizerFlags(flag) if isinstance(flag, str) else flag for flag in v]
return v

@model_validator(mode="after")
def _validate_time_slot_config_consistency(self):
"""
Validate that time slot config is consistent with scheduler config
"""
errors = []

# Check that all days in time_slot_config are valid
valid_days = {"MON", "TUE", "WED", "THU", "FRI"}
for day in self.time_slot_config.times:
if day not in valid_days:
errors.append(f"Invalid day '{day}' in time slot configuration")

# Check that there are time blocks for each day
for day in valid_days:
if day not in self.time_slot_config.times or not self.time_slot_config.times[day]:
errors.append(f"No time blocks defined for {day}")

# Check that class patterns are reasonable
if not self.time_slot_config.classes:
errors.append("At least one class pattern must be defined")

# Check for disabled patterns
disabled_patterns = [p for p in self.time_slot_config.classes if p.disabled]
if len(disabled_patterns) == len(self.time_slot_config.classes):
errors.append("All class patterns are disabled")

if errors:
error_message = "Time slot configuration errors:\n" + "\n".join(f" - {error}" for error in errors)
raise ValueError(error_message)

return self