|
8 | 8 | from collections import defaultdict |
9 | 9 | from dataclasses import dataclass, field, replace |
10 | 10 | from typing import TYPE_CHECKING, Literal, overload |
| 11 | +from uuid import UUID |
11 | 12 |
|
12 | 13 | from typing_extensions import assert_never |
13 | 14 |
|
@@ -89,18 +90,140 @@ def copy(self) -> EGraphState: |
89 | 90 | cost_callables=self.cost_callables.copy(), |
90 | 91 | ) |
91 | 92 |
|
92 | | - def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule: |
| 93 | + def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command: |
| 94 | + """ |
| 95 | + Turn a run schedule into an egg command. |
| 96 | +
|
| 97 | + If there exists any custom schedulers in the schedule, it will be turned into a custom extract command otherwise |
| 98 | + will be a normal run command. |
| 99 | + """ |
| 100 | + processed_schedule = self._process_schedule(schedule) |
| 101 | + if processed_schedule is None: |
| 102 | + return bindings.RunSchedule(self._schedule_to_egg(schedule)) |
| 103 | + top_level_schedules = self._schedule_with_scheduler_to_egg(processed_schedule, []) |
| 104 | + if len(top_level_schedules) == 1: |
| 105 | + schedule_expr = top_level_schedules[0] |
| 106 | + else: |
| 107 | + schedule_expr = bindings.Call(span(), "seq", top_level_schedules) |
| 108 | + return bindings.UserDefined(span(), "run-schedule", [schedule_expr]) |
| 109 | + |
| 110 | + def _process_schedule(self, schedule: ScheduleDecl) -> ScheduleDecl | None: |
| 111 | + """ |
| 112 | + Processes a schedule to determine if it contains any custom schedulers. |
| 113 | +
|
| 114 | + If it does, it returns a new schedule with all the required let bindings added to the other scope. |
| 115 | + If not, returns none. |
| 116 | +
|
| 117 | + Also processes all rulesets in the schedule to make sure they are registered. |
| 118 | + """ |
| 119 | + bound_schedulers: list[UUID] = [] |
| 120 | + unbound_schedulers: list[BackOffDecl] = [] |
| 121 | + |
| 122 | + def helper(s: ScheduleDecl) -> None: |
| 123 | + match s: |
| 124 | + case LetSchedulerDecl(scheduler, inner): |
| 125 | + bound_schedulers.append(scheduler.id) |
| 126 | + return helper(inner) |
| 127 | + case RunDecl(ruleset_name, _, scheduler): |
| 128 | + self.ruleset_to_egg(ruleset_name) |
| 129 | + if scheduler and scheduler.id not in bound_schedulers: |
| 130 | + unbound_schedulers.append(scheduler) |
| 131 | + case SaturateDecl(inner) | RepeatDecl(inner, _): |
| 132 | + return helper(inner) |
| 133 | + case SequenceDecl(schedules): |
| 134 | + for sc in schedules: |
| 135 | + helper(sc) |
| 136 | + case _: |
| 137 | + assert_never(s) |
| 138 | + return None |
| 139 | + |
| 140 | + helper(schedule) |
| 141 | + if not bound_schedulers and not unbound_schedulers: |
| 142 | + return None |
| 143 | + for scheduler in unbound_schedulers: |
| 144 | + schedule = LetSchedulerDecl(scheduler, schedule) |
| 145 | + return schedule |
| 146 | + |
| 147 | + def _schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule: |
| 148 | + msg = "Should never reach this, let schedulers should be handled by custom scheduler" |
93 | 149 | match schedule: |
94 | 150 | case SaturateDecl(schedule): |
95 | | - return bindings.Saturate(span(), self.schedule_to_egg(schedule)) |
| 151 | + return bindings.Saturate(span(), self._schedule_to_egg(schedule)) |
96 | 152 | case RepeatDecl(schedule, times): |
97 | | - return bindings.Repeat(span(), times, self.schedule_to_egg(schedule)) |
| 153 | + return bindings.Repeat(span(), times, self._schedule_to_egg(schedule)) |
98 | 154 | case SequenceDecl(schedules): |
99 | | - return bindings.Sequence(span(), [self.schedule_to_egg(s) for s in schedules]) |
100 | | - case RunDecl(ruleset_name, until): |
101 | | - self.ruleset_to_egg(ruleset_name) |
| 155 | + return bindings.Sequence(span(), [self._schedule_to_egg(s) for s in schedules]) |
| 156 | + case RunDecl(ruleset_name, until, scheduler): |
| 157 | + if scheduler is not None: |
| 158 | + raise ValueError(msg) |
102 | 159 | config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until))) |
103 | 160 | return bindings.Run(span(), config) |
| 161 | + case LetSchedulerDecl(): |
| 162 | + raise ValueError(msg) |
| 163 | + case _: |
| 164 | + assert_never(schedule) |
| 165 | + |
| 166 | + def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912 |
| 167 | + self, schedule: ScheduleDecl, bound_schedulers: list[UUID] |
| 168 | + ) -> list[bindings._Expr]: |
| 169 | + """ |
| 170 | + Turns a scheduler into an egg expression, to be used with a custom extract command. |
| 171 | +
|
| 172 | + The bound_schedulers is a list of all the schedulers that have been bound. We can lookup their name as `_scheduler_{index}`. |
| 173 | + """ |
| 174 | + match schedule: |
| 175 | + case LetSchedulerDecl(BackOffDecl(id, match_limit, ban_length), inner): |
| 176 | + name = f"_scheduler_{len(bound_schedulers)}" |
| 177 | + bound_schedulers.append(id) |
| 178 | + args: list[bindings._Expr] = [] |
| 179 | + if match_limit is not None: |
| 180 | + args.append(bindings.Var(span(), ":match-limit")) |
| 181 | + args.append(bindings.Lit(span(), bindings.Int(match_limit))) |
| 182 | + if ban_length is not None: |
| 183 | + args.append(bindings.Var(span(), ":ban-length")) |
| 184 | + args.append(bindings.Lit(span(), bindings.Int(ban_length))) |
| 185 | + back_off_decl = bindings.Call(span(), "back-off", args) |
| 186 | + let_decl = bindings.Call(span(), "let-scheduler", [bindings.Var(span(), name), back_off_decl]) |
| 187 | + return [let_decl, *self._schedule_with_scheduler_to_egg(inner, bound_schedulers)] |
| 188 | + case RunDecl(ruleset_name, until, scheduler): |
| 189 | + args = [bindings.Var(span(), ruleset_name)] |
| 190 | + if scheduler: |
| 191 | + name = "run-with" |
| 192 | + scheduler_name = f"_scheduler_{bound_schedulers.index(scheduler.id)}" |
| 193 | + args.insert(0, bindings.Var(span(), scheduler_name)) |
| 194 | + else: |
| 195 | + name = "run" |
| 196 | + if until: |
| 197 | + if len(until) > 1: |
| 198 | + msg = "Can only have one until fact with custom scheduler" |
| 199 | + raise ValueError(msg) |
| 200 | + args.append(bindings.Var(span(), ":until")) |
| 201 | + fact_egg = self.fact_to_egg(until[0]) |
| 202 | + if isinstance(fact_egg, bindings.Eq): |
| 203 | + msg = "Cannot use equality fact with custom scheduler" |
| 204 | + raise ValueError(msg) |
| 205 | + args.append(fact_egg.expr) |
| 206 | + return [bindings.Call(span(), name, args)] |
| 207 | + case SaturateDecl(inner): |
| 208 | + return [ |
| 209 | + bindings.Call(span(), "saturate", self._schedule_with_scheduler_to_egg(inner, bound_schedulers)) |
| 210 | + ] |
| 211 | + case RepeatDecl(inner, times): |
| 212 | + return [ |
| 213 | + bindings.Call( |
| 214 | + span(), |
| 215 | + "repeat", |
| 216 | + [ |
| 217 | + bindings.Lit(span(), bindings.Int(times)), |
| 218 | + *self._schedule_with_scheduler_to_egg(inner, bound_schedulers), |
| 219 | + ], |
| 220 | + ) |
| 221 | + ] |
| 222 | + case SequenceDecl(schedules): |
| 223 | + res = [] |
| 224 | + for s in schedules: |
| 225 | + res.extend(self._schedule_with_scheduler_to_egg(s, bound_schedulers)) |
| 226 | + return res |
104 | 227 | case _: |
105 | 228 | assert_never(schedule) |
106 | 229 |
|
|
0 commit comments