Skip to content

Commit c8d24a3

Browse files
Merge pull request #350 from egraphs-good/schedule
Add `back_off` scheduler
2 parents c7ba30d + 642ea3e commit c8d24a3

File tree

10 files changed

+382
-22
lines changed

10 files changed

+382
-22
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Add `back_off` scheduler [#350](https://github.com/egraphs-good/egglog-python/pull/350)
78
## 11.2.0 (2025-09-03)
89

910
- Add support for `set_cost` action to have row level costs for extraction [#343](https://github.com/egraphs-good/egglog-python/pull/343)

docs/reference/egglog-translation.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,62 @@ step_egraph.check(left(i64(10)), right(i64(9)))
459459
step_egraph.check_fail(left(i64(11)), right(i64(10)))
460460
```
461461

462+
#### Custom Schedulers
463+
464+
Custom backoff scheduler from egglog-experimental is supported. Create a custom backoff scheduler with `bo = BackOff(match_limit: None | int=None, ban_length: None | int=None)`, then run using `run(ruleset, *facts, scheduler=bo)`:
465+
466+
- `match_limit`: per-rule threshold of matches allowed in a single scheduler iteration. If a rule produces more matches than the threshold, that rule is temporarily banned.
467+
- `ban_length`: initial ban duration (in scheduler iterations). While banned, that rule is skipped.
468+
- Exponential backoff: each time a rule is banned, both the threshold and ban length double for that rule (threshold = match_limit << times_banned; ban = ban_length << times_banned).
469+
- Fast-forwarding: when any rule is banned, the scheduler fast-forwards by the minimum remaining ban to unban at least one rule before checking for termination again.
470+
- Defaults: match_limit defaults to 1000; ban_length defaults to 5.
471+
472+
For example, this egglog code:
473+
474+
```
475+
(run-schedule
476+
(let-scheduler bo (back-off :match-limit 10))
477+
(repeat 10 (run-with bo step_right)))
478+
```
479+
480+
Is translated as:
481+
482+
```{code-cell} python
483+
step_egraph.run(
484+
run(step_right, scheduler=back_off(match_limit=10)) * 10
485+
)
486+
```
487+
488+
By default the scheduler will be created before any other schedules are run.
489+
To control where is instantiated explicitly, use `bo.scope(<schedule>)`, where it will be created before everything in `<schedule>`.
490+
491+
So the previous is equivalent to:
492+
493+
```{code-cell} python
494+
bo = back_off(match_limit=10)
495+
step_egraph.run(
496+
bo.scope(run(step_right, scheduler=bo) * 10)
497+
)
498+
```
499+
500+
If you wanted to create the scheduler inside the repeated schedule, you can do:
501+
502+
```{code-cell} python
503+
bo = back_off(match_limit=10)
504+
step_egraph.run(
505+
bo.scope(run(step_right, scheduler=bo)) * 10
506+
)
507+
```
508+
509+
This would be equivalent to this egglog:
510+
511+
```
512+
(run-schedule
513+
(repeat 10
514+
(let-scheduler bo (back-off :match-limit 10))
515+
(run-with bo step_right)))
516+
```
517+
462518
## Check
463519

464520
The `(check ...)` command to verify that some facts are true, can be translated to Python with the `egraph.check` function:

python/egglog/declarations.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dataclasses import dataclass, field
1010
from functools import cached_property
1111
from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable
12+
from uuid import UUID
1213
from weakref import WeakValueDictionary
1314

1415
from typing_extensions import Self, assert_never
@@ -20,6 +21,7 @@
2021
__all__ = [
2122
"ActionCommandDecl",
2223
"ActionDecl",
24+
"BackOffDecl",
2325
"BiRewriteDecl",
2426
"CallDecl",
2527
"CallableDecl",
@@ -52,6 +54,7 @@
5254
"JustTypeRef",
5355
"LetDecl",
5456
"LetRefDecl",
57+
"LetSchedulerDecl",
5558
"LitDecl",
5659
"LitType",
5760
"MethodRef",
@@ -790,9 +793,24 @@ class SequenceDecl:
790793
class RunDecl:
791794
ruleset: str
792795
until: tuple[FactDecl, ...] | None
796+
scheduler: BackOffDecl | None = None
793797

794798

795-
ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl
799+
@dataclass(frozen=True)
800+
class LetSchedulerDecl:
801+
scheduler: BackOffDecl
802+
inner: ScheduleDecl
803+
804+
805+
ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl | LetSchedulerDecl
806+
807+
808+
@dataclass(frozen=True)
809+
class BackOffDecl:
810+
id: UUID
811+
match_limit: int | None
812+
ban_length: int | None
813+
796814

797815
##
798816
# Facts

python/egglog/egraph.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
get_type_hints,
2424
overload,
2525
)
26+
from uuid import uuid4
2627
from warnings import warn
2728

2829
import graphviz
@@ -45,6 +46,7 @@
4546

4647
__all__ = [
4748
"Action",
49+
"BackOff",
4850
"BaseExpr",
4951
"BuiltinExpr",
5052
"Command",
@@ -63,6 +65,7 @@
6365
"_RewriteBuilder",
6466
"_SetBuilder",
6567
"_UnionBuilder",
68+
"back_off",
6669
"birewrite",
6770
"check",
6871
"check_eq",
@@ -905,8 +908,8 @@ def run(
905908

906909
def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
907910
self._add_decls(schedule)
908-
egg_schedule = self._state.schedule_to_egg(schedule.schedule)
909-
(command_output,) = self._egraph.run_program(bindings.RunSchedule(egg_schedule))
911+
cmd = self._state.run_schedule_to_egg(schedule.schedule)
912+
(command_output,) = self._egraph.run_program(cmd)
910913
assert isinstance(command_output, bindings.RunScheduleOutput)
911914
return command_output.report
912915

@@ -1786,17 +1789,51 @@ def to_runtime_expr(expr: BaseExpr) -> RuntimeExpr:
17861789
return expr
17871790

17881791

1789-
def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule:
1792+
def run(ruleset: Ruleset | None = None, *until: FactLike, scheduler: BackOff | None = None) -> Schedule:
17901793
"""
17911794
Create a run configuration.
17921795
"""
17931796
facts = _fact_likes(until)
17941797
return Schedule(
17951798
Thunk.fn(Declarations.create, ruleset, *facts),
1796-
RunDecl(ruleset.__egg_name__ if ruleset else "", tuple(f.fact for f in facts) or None),
1799+
RunDecl(
1800+
ruleset.__egg_name__ if ruleset else "",
1801+
tuple(f.fact for f in facts) or None,
1802+
scheduler.scheduler if scheduler else None,
1803+
),
17971804
)
17981805

17991806

1807+
def back_off(match_limit: None | int = None, ban_length: None | int = None) -> BackOff:
1808+
"""
1809+
Create a backoff scheduler configuration.
1810+
1811+
```python
1812+
schedule = run(analysis_ruleset).saturate() + run(ruleset, scheduler=back_off(match_limit=1000, ban_length=5)) * 10
1813+
```
1814+
This will run the `analysis_ruleset` until saturation, then run `ruleset` 10 times, using a backoff scheduler.
1815+
"""
1816+
return BackOff(BackOffDecl(id=uuid4(), match_limit=match_limit, ban_length=ban_length))
1817+
1818+
1819+
@dataclass(frozen=True)
1820+
class BackOff:
1821+
scheduler: BackOffDecl
1822+
1823+
def scope(self, schedule: Schedule) -> Schedule:
1824+
"""
1825+
Defines the scheduler to be created directly before the inner schedule, instead of the default which is at the
1826+
most outer scope.
1827+
"""
1828+
return Schedule(schedule.__egg_decls_thunk__, LetSchedulerDecl(self.scheduler, schedule.schedule))
1829+
1830+
def __str__(self) -> str:
1831+
return pretty_decl(Declarations(), self.scheduler)
1832+
1833+
def __repr__(self) -> str:
1834+
return str(self)
1835+
1836+
18001837
def seq(*schedules: Schedule) -> Schedule:
18011838
"""
18021839
Run a sequence of schedules.

python/egglog/egraph_state.py

Lines changed: 129 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections import defaultdict
99
from dataclasses import dataclass, field, replace
1010
from typing import TYPE_CHECKING, Literal, overload
11+
from uuid import UUID
1112

1213
from typing_extensions import assert_never
1314

@@ -89,18 +90,140 @@ def copy(self) -> EGraphState:
8990
cost_callables=self.cost_callables.copy(),
9091
)
9192

92-
def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
93+
def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command:
94+
"""
95+
Turn a run schedule into an egg command.
96+
97+
If there exists any custom schedulers in the schedule, it will be turned into a custom extract command otherwise
98+
will be a normal run command.
99+
"""
100+
processed_schedule = self._process_schedule(schedule)
101+
if processed_schedule is None:
102+
return bindings.RunSchedule(self._schedule_to_egg(schedule))
103+
top_level_schedules = self._schedule_with_scheduler_to_egg(processed_schedule, [])
104+
if len(top_level_schedules) == 1:
105+
schedule_expr = top_level_schedules[0]
106+
else:
107+
schedule_expr = bindings.Call(span(), "seq", top_level_schedules)
108+
return bindings.UserDefined(span(), "run-schedule", [schedule_expr])
109+
110+
def _process_schedule(self, schedule: ScheduleDecl) -> ScheduleDecl | None:
111+
"""
112+
Processes a schedule to determine if it contains any custom schedulers.
113+
114+
If it does, it returns a new schedule with all the required let bindings added to the other scope.
115+
If not, returns none.
116+
117+
Also processes all rulesets in the schedule to make sure they are registered.
118+
"""
119+
bound_schedulers: list[UUID] = []
120+
unbound_schedulers: list[BackOffDecl] = []
121+
122+
def helper(s: ScheduleDecl) -> None:
123+
match s:
124+
case LetSchedulerDecl(scheduler, inner):
125+
bound_schedulers.append(scheduler.id)
126+
return helper(inner)
127+
case RunDecl(ruleset_name, _, scheduler):
128+
self.ruleset_to_egg(ruleset_name)
129+
if scheduler and scheduler.id not in bound_schedulers:
130+
unbound_schedulers.append(scheduler)
131+
case SaturateDecl(inner) | RepeatDecl(inner, _):
132+
return helper(inner)
133+
case SequenceDecl(schedules):
134+
for sc in schedules:
135+
helper(sc)
136+
case _:
137+
assert_never(s)
138+
return None
139+
140+
helper(schedule)
141+
if not bound_schedulers and not unbound_schedulers:
142+
return None
143+
for scheduler in unbound_schedulers:
144+
schedule = LetSchedulerDecl(scheduler, schedule)
145+
return schedule
146+
147+
def _schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
148+
msg = "Should never reach this, let schedulers should be handled by custom scheduler"
93149
match schedule:
94150
case SaturateDecl(schedule):
95-
return bindings.Saturate(span(), self.schedule_to_egg(schedule))
151+
return bindings.Saturate(span(), self._schedule_to_egg(schedule))
96152
case RepeatDecl(schedule, times):
97-
return bindings.Repeat(span(), times, self.schedule_to_egg(schedule))
153+
return bindings.Repeat(span(), times, self._schedule_to_egg(schedule))
98154
case SequenceDecl(schedules):
99-
return bindings.Sequence(span(), [self.schedule_to_egg(s) for s in schedules])
100-
case RunDecl(ruleset_name, until):
101-
self.ruleset_to_egg(ruleset_name)
155+
return bindings.Sequence(span(), [self._schedule_to_egg(s) for s in schedules])
156+
case RunDecl(ruleset_name, until, scheduler):
157+
if scheduler is not None:
158+
raise ValueError(msg)
102159
config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
103160
return bindings.Run(span(), config)
161+
case LetSchedulerDecl():
162+
raise ValueError(msg)
163+
case _:
164+
assert_never(schedule)
165+
166+
def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
167+
self, schedule: ScheduleDecl, bound_schedulers: list[UUID]
168+
) -> list[bindings._Expr]:
169+
"""
170+
Turns a scheduler into an egg expression, to be used with a custom extract command.
171+
172+
The bound_schedulers is a list of all the schedulers that have been bound. We can lookup their name as `_scheduler_{index}`.
173+
"""
174+
match schedule:
175+
case LetSchedulerDecl(BackOffDecl(id, match_limit, ban_length), inner):
176+
name = f"_scheduler_{len(bound_schedulers)}"
177+
bound_schedulers.append(id)
178+
args: list[bindings._Expr] = []
179+
if match_limit is not None:
180+
args.append(bindings.Var(span(), ":match-limit"))
181+
args.append(bindings.Lit(span(), bindings.Int(match_limit)))
182+
if ban_length is not None:
183+
args.append(bindings.Var(span(), ":ban-length"))
184+
args.append(bindings.Lit(span(), bindings.Int(ban_length)))
185+
back_off_decl = bindings.Call(span(), "back-off", args)
186+
let_decl = bindings.Call(span(), "let-scheduler", [bindings.Var(span(), name), back_off_decl])
187+
return [let_decl, *self._schedule_with_scheduler_to_egg(inner, bound_schedulers)]
188+
case RunDecl(ruleset_name, until, scheduler):
189+
args = [bindings.Var(span(), ruleset_name)]
190+
if scheduler:
191+
name = "run-with"
192+
scheduler_name = f"_scheduler_{bound_schedulers.index(scheduler.id)}"
193+
args.insert(0, bindings.Var(span(), scheduler_name))
194+
else:
195+
name = "run"
196+
if until:
197+
if len(until) > 1:
198+
msg = "Can only have one until fact with custom scheduler"
199+
raise ValueError(msg)
200+
args.append(bindings.Var(span(), ":until"))
201+
fact_egg = self.fact_to_egg(until[0])
202+
if isinstance(fact_egg, bindings.Eq):
203+
msg = "Cannot use equality fact with custom scheduler"
204+
raise ValueError(msg)
205+
args.append(fact_egg.expr)
206+
return [bindings.Call(span(), name, args)]
207+
case SaturateDecl(inner):
208+
return [
209+
bindings.Call(span(), "saturate", self._schedule_with_scheduler_to_egg(inner, bound_schedulers))
210+
]
211+
case RepeatDecl(inner, times):
212+
return [
213+
bindings.Call(
214+
span(),
215+
"repeat",
216+
[
217+
bindings.Lit(span(), bindings.Int(times)),
218+
*self._schedule_with_scheduler_to_egg(inner, bound_schedulers),
219+
],
220+
)
221+
]
222+
case SequenceDecl(schedules):
223+
res = []
224+
for s in schedules:
225+
res.extend(self._schedule_with_scheduler_to_egg(s, bound_schedulers))
226+
return res
104227
case _:
105228
assert_never(schedule)
106229

python/egglog/examples/jointree.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,3 @@ def _rules(s: String, a: JoinTree, b: JoinTree, c: JoinTree, asize: i64, bsize:
6262
egraph.run(1000)
6363
print(egraph.extract(query))
6464
print(egraph.extract(query.size))
65-
66-
67-
egraph

0 commit comments

Comments
 (0)