Skip to content

Commit 98faefd

Browse files
authored
Rewrite gemini.logicalInitialize to place.Initialize. (#50)
* rename rewrite and flattening qubit ids for CZ * removing helper function * Adding merge for initialize * Renaming test file * Adding test for initialize * removing cached property to prevent potential bugs
1 parent 4d91bac commit 98faefd

File tree

4 files changed

+105
-39
lines changed

4 files changed

+105
-39
lines changed

src/bloqade/lanes/dialects/place.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,15 @@ class Initialize(QuantumStmt):
4141

4242
@statement(dialect=dialect)
4343
class CZ(QuantumStmt):
44-
targets: tuple[int, ...] = info.attribute()
45-
controls: tuple[int, ...] = info.attribute()
44+
qubits: tuple[int, ...] = info.attribute()
45+
46+
@property
47+
def controls(self) -> tuple[int, ...]:
48+
return self.qubits[: len(self.qubits) // 2]
49+
50+
@property
51+
def targets(self) -> tuple[int, ...]:
52+
return self.qubits[len(self.qubits) // 2 :]
4653

4754

4855
@statement(dialect=dialect)

src/bloqade/lanes/rewrite/native2place.py renamed to src/bloqade/lanes/rewrite/circuit2place.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,20 @@ class RewritePlaceOperations(abc.RewriteRule):
1818
This is a placeholder for the actual implementation.
1919
"""
2020

21-
def default_(self, node: ir.Statement) -> abc.RewriteResult:
22-
return abc.RewriteResult()
23-
2421
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
22+
if not isinstance(
23+
node,
24+
(
25+
gemini_stmts.TerminalLogicalMeasurement,
26+
gemini_stmts.Initialize,
27+
gate.CZ,
28+
gate.R,
29+
gate.Rz,
30+
),
31+
):
32+
return abc.RewriteResult()
2533
rewrite_method_name = f"rewrite_{type(node).__name__}"
26-
rewrite_method = getattr(self, rewrite_method_name, self.default_)
34+
rewrite_method = getattr(self, rewrite_method_name)
2735
return rewrite_method(node)
2836

2937
def prep_region(self) -> tuple[ir.Region, ir.Block, ir.SSAValue]:
@@ -46,6 +54,25 @@ def construct_execute(
4654

4755
return place.StaticPlacement(qubits=qubits, body=body)
4856

57+
def rewrite_Initialize(self, node: gemini_stmts.Initialize) -> abc.RewriteResult:
58+
if not isinstance(args_list := node.qubits.owner, ilist.New):
59+
return abc.RewriteResult()
60+
61+
inputs = args_list.values
62+
body, block, entry_state = self.prep_region()
63+
gate_stmt = place.Initialize(
64+
entry_state,
65+
phi=node.phi,
66+
theta=node.theta,
67+
lam=node.lam,
68+
qubits=tuple(range(len(inputs))),
69+
)
70+
node.replace_by(
71+
self.construct_execute(gate_stmt, qubits=inputs, body=body, block=block)
72+
)
73+
74+
return abc.RewriteResult(has_done_something=True)
75+
4976
def rewrite_TerminalLogicalMeasurement(
5077
self, node: gemini_stmts.TerminalLogicalMeasurement
5178
) -> abc.RewriteResult:
@@ -82,13 +109,11 @@ def rewrite_CZ(self, node: gate.CZ) -> abc.RewriteResult:
82109
return abc.RewriteResult()
83110

84111
all_qubits = tuple(range(len(targets) + len(controls)))
85-
n_controls = len(controls)
86112

87113
body, block, entry_state = self.prep_region()
88114
stmt = place.CZ(
89115
entry_state,
90-
controls=all_qubits[:n_controls],
91-
targets=all_qubits[n_controls:],
116+
qubits=all_qubits,
92117
)
93118

94119
node.replace_by(
@@ -155,27 +180,6 @@ class MergePlacementRegions(abc.RewriteRule):
155180
merge_heuristic: Callable[[ir.Region, ir.Region], bool] = _default_merge_heuristic
156181
"""Heuristic function to decide whether to merge two circuit regions."""
157182

158-
def remap_qubits(
159-
self,
160-
curr_state: ir.SSAValue,
161-
node: place.R | place.Rz | place.CZ | place.EndMeasure,
162-
input_map: dict[int, int],
163-
):
164-
if isinstance(node, place.CZ):
165-
return place.CZ(
166-
curr_state,
167-
targets=tuple(input_map[i] for i in node.targets),
168-
controls=tuple(input_map[i] for i in node.controls),
169-
)
170-
else:
171-
return node.from_stmt(
172-
node,
173-
args=(curr_state, *node.args[1:]),
174-
attributes={
175-
"qubits": ir.PyAttr(tuple(input_map[i] for i in node.qubits))
176-
},
177-
)
178-
179183
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
180184
if not (
181185
isinstance(node, place.StaticPlacement)
@@ -206,8 +210,18 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
206210
curr_yield.delete()
207211

208212
for stmt in next_node.body.blocks[0].stmts:
209-
if isinstance(stmt, (place.R, place.Rz, place.CZ, place.EndMeasure)):
210-
remapped_stmt = self.remap_qubits(curr_state, stmt, new_input_map)
213+
if isinstance(
214+
stmt, (place.R, place.Rz, place.CZ, place.EndMeasure, place.Initialize)
215+
):
216+
remapped_stmt = stmt.from_stmt(
217+
stmt,
218+
args=(curr_state, *stmt.args[1:]),
219+
attributes={
220+
"qubits": ir.PyAttr(
221+
tuple(new_input_map[i] for i in stmt.qubits)
222+
)
223+
},
224+
)
211225
curr_state = remapped_stmt.results[0]
212226
new_block.stmts.append(remapped_stmt)
213227
for old_result, new_result in zip(

src/bloqade/lanes/upstream.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from bloqade.lanes.analysis import layout, placement
1111
from bloqade.lanes.dialects import move, place
1212
from bloqade.lanes.passes.canonicalize import CanonicalizeNative
13-
from bloqade.lanes.rewrite import native2place, place2move
13+
from bloqade.lanes.rewrite import circuit2place, place2move
1414

1515

1616
def default_merge_heuristic(region_a: ir.Region, region_b: ir.Region) -> bool:
@@ -29,11 +29,11 @@ def emit(self, mt: Method, no_raise: bool = True):
2929
AggressiveUnroll(out.dialects, no_raise=no_raise).fixpoint(out)
3030
CanonicalizeNative(out.dialects, no_raise=no_raise).fixpoint(out)
3131
rewrite.Walk(
32-
native2place.RewritePlaceOperations(),
32+
circuit2place.RewritePlaceOperations(),
3333
).rewrite(out.code)
3434

3535
rewrite.Fixpoint(
36-
rewrite.Walk(native2place.MergePlacementRegions(self.merge_heuristic))
36+
rewrite.Walk(circuit2place.MergePlacementRegions(self.merge_heuristic))
3737
).rewrite(out.code)
3838
passes.TypeInfer(out.dialects)(out)
3939

test/rewrite/test_native2circuit.py renamed to test/rewrite/test_circuit2place.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from bloqade.lanes import types
88
from bloqade.lanes.dialects import place
9-
from bloqade.lanes.rewrite.native2place import (
9+
from bloqade.lanes.rewrite.circuit2place import (
1010
MergePlacementRegions,
1111
RewritePlaceOperations,
1212
)
@@ -33,9 +33,7 @@ def test_cz():
3333
)
3434

3535
entry_state = block.args.append_from(types.StateType, name="entry_state")
36-
block.stmts.append(
37-
gate_stmt := place.CZ(entry_state, controls=(0, 1), targets=(2, 3))
38-
)
36+
block.stmts.append(gate_stmt := place.CZ(entry_state, qubits=(0, 1, 2, 3)))
3937
block.stmts.append(place.Yield(gate_stmt.state_after))
4038

4139
rule = rewrite.Walk(RewritePlaceOperations())
@@ -153,6 +151,53 @@ def test_measurement():
153151
assert_nodes(test_block, expected_block)
154152

155153

154+
def test_initialize():
155+
test_block = ir.Block(
156+
[
157+
qubits := ilist.New(
158+
values=(
159+
q0 := ir.TestValue(),
160+
q1 := ir.TestValue(),
161+
q2 := ir.TestValue(),
162+
)
163+
),
164+
gemini_stmts.Initialize(
165+
theta := ir.TestValue(),
166+
phi := ir.TestValue(),
167+
lam := ir.TestValue(),
168+
qubits=qubits.result,
169+
),
170+
],
171+
)
172+
173+
expected_block = ir.Block(
174+
[
175+
qubits := ilist.New(values=(q0, q1, q2)),
176+
]
177+
)
178+
179+
block = ir.Block()
180+
181+
entry_state = block.args.append_from(types.StateType, name="entry_state")
182+
block.stmts.append(
183+
gate_stmt := place.Initialize(
184+
entry_state,
185+
theta=theta,
186+
phi=phi,
187+
lam=lam,
188+
qubits=(0, 1, 2),
189+
)
190+
)
191+
block.stmts.append(place.Yield(gate_stmt.state_after))
192+
expected_block.stmts.append(
193+
place.StaticPlacement(qubits=(q0, q1, q2), body=ir.Region(block))
194+
)
195+
rule = rewrite.Walk(RewritePlaceOperations())
196+
197+
rule.rewrite(test_block)
198+
assert_nodes(test_block, expected_block)
199+
200+
156201
def test_merge_regions():
157202

158203
qubits = tuple(ir.TestValue() for _ in range(10))

0 commit comments

Comments
 (0)