Skip to content

Commit 1a3b36a

Browse files
authored
Adding tweezer task. (#16)
* updating interface * adding task function for tweezer calls inside Auto retion * adding new statement representing an unassigned tweezer function for inside Auto * allow for as well as invoke
1 parent c8edd3d commit 1a3b36a

File tree

7 files changed

+56
-3
lines changed

7 files changed

+56
-3
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# misc
2+
3+
debug/
4+
15
# Byte-compiled / optimized / DLL files
26
__pycache__/
37
*.py[cod]

src/bloqade/shuttle/dialects/schedule/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Auto as Auto,
77
ExecutableRegion as ExecutableRegion,
88
NewDeviceFunction as NewDeviceFunction,
9+
NewTweezerTask as NewTweezerTask,
910
Parallel as Parallel,
1011
Reverse as Reverse,
1112
)

src/bloqade/shuttle/dialects/schedule/_interface.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from typing import Any, ParamSpec
1+
from typing import Any, ContextManager, ParamSpec
22

33
from kirin import ir
44
from kirin.dialects import ilist
55
from kirin.lowering import wraps as _wraps
66

77
from .stmts import (
8+
Auto,
89
NewDeviceFunction,
10+
Parallel,
911
Reverse,
1012
)
1113
from .types import DeviceFunction, ReverseDeviceFunction
@@ -43,3 +45,11 @@ def reverse(
4345
4446
"""
4547
...
48+
49+
50+
@_wraps(Parallel)
51+
def parallel() -> ContextManager: ...
52+
53+
54+
@_wraps(Auto)
55+
def auto() -> ContextManager: ...

src/bloqade/shuttle/dialects/schedule/stmts.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111
NyTones = types.TypeVar("NyTones")
1212

1313

14+
@statement(dialect=dialect)
15+
class NewTweezerTask(ir.Statement):
16+
name = "tweezer_task"
17+
traits = frozenset({ir.Pure()})
18+
move_fn: ir.SSAValue = info.argument(types.MethodType)
19+
result: ir.ResultValue = info.result(DeviceFunctionType)
20+
21+
1422
@statement(dialect=dialect)
1523
class NewDeviceFunction(ir.Statement):
1624
name = "device_function"

src/bloqade/shuttle/passes/schedule2path.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from bloqade.shuttle.rewrite.schedule2path import (
1212
Canonicalize,
13+
RewriteAutoInvoke,
1314
RewriteDeviceCall,
1415
RewriteScheduleRegion,
1516
)
@@ -20,7 +21,11 @@ class ScheduleToPath(Pass):
2021

2122
def unsafe_run(self, mt: ir.Method):
2223
result = Fixpoint(Walk(Canonicalize())).rewrite(mt.code)
23-
result = Walk(RewriteDeviceCall()).rewrite(mt.code).join(result)
24+
result = (
25+
Walk(Chain(RewriteAutoInvoke(), RewriteDeviceCall()))
26+
.rewrite(mt.code)
27+
.join(result)
28+
)
2429
result = Walk(RewriteScheduleRegion()).rewrite(mt.code).join(result)
2530
result = (
2631
Fixpoint(

src/bloqade/shuttle/rewrite/auto_scheduler.py

Whitespace-only changes.

src/bloqade/shuttle/rewrite/schedule2path.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from kirin import ir
2-
from kirin.dialects import func
2+
from kirin.dialects import func, py
33
from kirin.rewrite import abc
44

55
from bloqade.shuttle.dialects import path, schedule
@@ -24,6 +24,31 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
2424
return abc.RewriteResult(has_done_something=True)
2525

2626

27+
class RewriteAutoInvoke(abc.RewriteRule):
28+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
29+
if not isinstance(node.parent_stmt, schedule.Auto):
30+
return abc.RewriteResult()
31+
32+
if isinstance(node, func.Invoke):
33+
(callee_stmt := py.Constant(node.callee)).insert_before(node)
34+
callee_ssa = callee_stmt.result
35+
elif isinstance(node, func.Call):
36+
callee_ssa = node.callee
37+
else:
38+
return abc.RewriteResult()
39+
40+
(tweezer_task := schedule.NewTweezerTask(move_fn=callee_ssa)).insert_before(
41+
node
42+
)
43+
(path.Gen(tweezer_task.result, node.inputs, kwargs=node.kwargs)).insert_before(
44+
node
45+
)
46+
47+
node.delete()
48+
49+
return abc.RewriteResult(has_done_something=True)
50+
51+
2752
class RewriteScheduleRegion(abc.RewriteRule):
2853
CLASSES = {
2954
schedule.Auto: path.Auto,

0 commit comments

Comments
 (0)