Skip to content

Commit f88969e

Browse files
authored
Updating names and minor bug fixes. (#46)
* rename circuit to place * Fixing issue with interpreter * updating name for rewrite rule * renaming modules * Renaming a rewrite * testing compiler pipeline with measurement * fixing demo * fixing dialect name
1 parent 1fd2a95 commit f88969e

File tree

9 files changed

+149
-136
lines changed

9 files changed

+149
-136
lines changed

demo/ghz_moves_demo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from bloqade.lanes import visualize
77
from bloqade.lanes.arch.gemini import logical
88
from bloqade.lanes.heuristics import fixed
9-
from bloqade.lanes.upstream import CircuitToMove, NativeToCircuit
9+
from bloqade.lanes.upstream import NativeToPlace, PlaceToMove
1010

1111

1212
@squin.kernel(typeinfer=True, fold=True)
@@ -44,8 +44,8 @@ def compile_and_visualize(mt: ir.Method, interactive=True):
4444
# Compile to move dialect
4545

4646
mt = SquinToNative().emit(mt)
47-
mt = NativeToCircuit().emit(mt)
48-
mt = CircuitToMove(
47+
mt = NativeToPlace().emit(mt)
48+
mt = PlaceToMove(
4949
fixed.LogicalLayoutHeuristic(),
5050
fixed.LogicalPlacementStrategy(),
5151
fixed.LogicalMoveScheduler(),

demo/pipeline_demo.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,38 @@
1+
from bloqade.gemini.dialects import logical as gemini_logical
12
from bloqade.native.upstream import SquinToNative
2-
from kirin.dialects import ilist
33

44
from bloqade import qubit, squin
5-
from bloqade.lanes.arch.gemini.logical import SpecializeGemini
65
from bloqade.lanes.heuristics import fixed
7-
from bloqade.lanes.upstream import CircuitToMove, NativeToCircuit
6+
from bloqade.lanes.upstream import NativeToPlace, PlaceToMove
87

8+
kernel = squin.kernel.add(gemini_logical.dialect)
9+
kernel.run_pass = squin.kernel.run_pass
910

10-
@squin.kernel(typeinfer=True, fold=True)
11-
def log_depth_ghz():
11+
12+
@kernel(typeinfer=True, fold=True)
13+
def ghz_optimal():
1214
size = 10
13-
q0 = qubit.new()
14-
squin.h(q0)
15-
reg = ilist.IList([q0])
16-
for i in range(size):
17-
current = len(reg)
18-
missing = size - current
19-
if missing > current:
20-
num_alloc = current
21-
else:
22-
num_alloc = missing
15+
qs = qubit.qalloc(size)
16+
squin.h(qs[0])
17+
squin.cx(qs[0], qs[1])
18+
squin.broadcast.cx(qs[:2], qs[2:4])
19+
squin.cx(qs[3], qs[4])
20+
squin.broadcast.cx(qs[:5], qs[5:])
2321

24-
if num_alloc > 0:
25-
new_qubits = qubit.qalloc(num_alloc)
26-
squin.broadcast.cx(reg[-num_alloc:], new_qubits)
27-
reg = reg + new_qubits
22+
return gemini_logical.terminal_measure(qs)
2823

2924

30-
log_depth_ghz.print()
25+
ghz_optimal.print()
3126

32-
log_depth_ghz = SquinToNative().emit(log_depth_ghz)
33-
log_depth_ghz.print()
27+
ghz_optimal = SquinToNative().emit(ghz_optimal)
28+
ghz_optimal.print()
3429

35-
log_depth_ghz = NativeToCircuit().emit(log_depth_ghz)
36-
log_depth_ghz.print()
30+
ghz_optimal = NativeToPlace().emit(ghz_optimal)
31+
ghz_optimal.print()
3732

38-
log_depth_ghz = CircuitToMove(
33+
ghz_optimal = PlaceToMove(
3934
fixed.LogicalLayoutHeuristic(),
4035
fixed.LogicalPlacementStrategy(),
4136
fixed.LogicalMoveScheduler(),
42-
).emit(log_depth_ghz)
43-
log_depth_ghz.print()
44-
45-
out = SpecializeGemini().emit(log_depth_ghz)
46-
out.print()
37+
).emit(ghz_optimal)
38+
ghz_optimal.print()

src/bloqade/lanes/analysis/placement/analysis.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ def sq_placements(
4040
) -> AtomState:
4141
pass
4242

43+
def measure_placements(
44+
self,
45+
state: AtomState,
46+
qubits: tuple[int, ...],
47+
) -> AtomState:
48+
return state
49+
4350

4451
@dataclass
4552
class PlacementAnalysis(Forward[AtomState]):
Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from bloqade.lanes.analysis.placement import AtomState, ConcreteState, PlacementAnalysis
1111
from bloqade.lanes.types import StateType
1212

13-
dialect = ir.Dialect(name="lanes.circuit")
13+
dialect = ir.Dialect(name="lanes.place")
1414

1515

1616
@statement(init=False)
@@ -96,13 +96,13 @@ def __init__(self, final_state: ir.SSAValue, *classical_results: ir.SSAValue):
9696

9797

9898
@statement(dialect=dialect)
99-
class StaticCircuit(ir.Statement):
99+
class StaticPlacement(ir.Statement):
100100
"""This statement represents A static circuit to be executed on the hardware.
101101
102102
The body region contains the low-level instructions to be executed.
103103
The inputs are the squin qubits to be used in the execution.
104104
105-
The region always terminates with an ExitRegion statement, which provides the
105+
The region always terminates with an Yield statement, which provides the
106106
the measurement results for the qubits depending on which low-level code was executed.
107107
"""
108108

@@ -176,26 +176,43 @@ def impl_single_qubit_gate(
176176
def impl_yield(
177177
self, _interp: PlacementAnalysis, frame: ForwardFrame[AtomState], stmt: Yield
178178
):
179-
return interp.YieldValue((frame.get(stmt.final_state),))
179+
return interp.YieldValue(frame.get_values(stmt.args))
180180

181-
@interp.impl(StaticCircuit)
181+
@interp.impl(StaticPlacement)
182182
def impl_static_circuit(
183183
self,
184184
_interp: PlacementAnalysis,
185185
frame: ForwardFrame[AtomState],
186-
stmt: StaticCircuit,
186+
stmt: StaticPlacement,
187187
):
188188
initial_state = _interp.get_inintial_state(stmt.qubits)
189189

190-
final_state = _interp.frame_call_region(frame, stmt, stmt.body, initial_state)
190+
frame_call_result = _interp.frame_call_region(
191+
frame, stmt, stmt.body, initial_state
192+
)
191193

192-
ret = (AtomState.bottom(),) * len(stmt.results)
194+
match frame_call_result:
195+
case (ConcreteState() as final_state, *ret):
196+
for qid, qubit in enumerate(stmt.qubits):
197+
_interp.move_count[qubit] += final_state.move_count[qid]
193198

194-
if isinstance(final_state, ConcreteState):
195-
for qid, qubit in enumerate(stmt.qubits):
196-
_interp.move_count[qubit] += final_state.move_count[qid]
199+
return tuple(ret)
200+
case _:
201+
raise interp.InterpreterError(
202+
"StaticPlacement body did not return a ConcreteState"
203+
)
197204

198-
return ret
205+
@interp.impl(EndMeasure)
206+
def end_measure(
207+
self,
208+
_interp: PlacementAnalysis,
209+
frame: ForwardFrame[AtomState],
210+
stmt: EndMeasure,
211+
):
212+
new_state = _interp.placement_strategy.measure_placements(
213+
frame.get(stmt.state_before), stmt.qubits
214+
)
215+
return (new_state,) + (EmptyLattice.bottom(),) * len(stmt.qubits)
199216

200217

201218
@dialect.register(key="circuit.layout")
@@ -212,12 +229,12 @@ def cz(
212229

213230
return ()
214231

215-
@interp.impl(StaticCircuit)
232+
@interp.impl(StaticPlacement)
216233
def static_circuit(
217234
self,
218235
_interp: LayoutAnalysis,
219236
frame: ForwardFrame[EmptyLattice],
220-
stmt: StaticCircuit,
237+
stmt: StaticPlacement,
221238
):
222239
initial_addresses = tuple(
223240
_interp.address_entries[qubit] for qubit in stmt.qubits

src/bloqade/lanes/heuristics/fixed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
SiteLaneAddress,
1616
WordLaneAddress,
1717
)
18-
from bloqade.lanes.rewrite.circuit2move import MoveSchedulerABC
18+
from bloqade.lanes.rewrite.place2move import MoveSchedulerABC
1919

2020

2121
@dataclass

src/bloqade/lanes/rewrite/native2circuit.py renamed to src/bloqade/lanes/rewrite/native2place.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
from kirin.dialects import ilist
88
from kirin.rewrite import abc
99

10-
from bloqade.lanes.dialects import circuit
10+
from bloqade.lanes.dialects import place
1111
from bloqade.lanes.types import StateType
1212

1313

1414
@dataclass
15-
class RewriteLowLevelCircuit(abc.RewriteRule):
15+
class RewritePlaceOperations(abc.RewriteRule):
1616
"""
17-
Rewrite rule to convert native operations to placement operations.
17+
Rewrite rule to convert native operations to place operations.
1818
This is a placeholder for the actual implementation.
1919
"""
2020

@@ -33,18 +33,18 @@ def prep_region(self) -> tuple[ir.Region, ir.Block, ir.SSAValue]:
3333

3434
def construct_execute(
3535
self,
36-
quantum_stmt: circuit.QuantumStmt,
36+
quantum_stmt: place.QuantumStmt,
3737
*,
3838
qubits: tuple[ir.SSAValue, ...],
3939
body: ir.Region,
4040
block: ir.Block,
41-
) -> circuit.StaticCircuit:
41+
) -> place.StaticPlacement:
4242
block.stmts.append(quantum_stmt)
4343
block.stmts.append(
44-
circuit.Yield(quantum_stmt.state_after, *quantum_stmt.results[1:])
44+
place.Yield(quantum_stmt.state_after, *quantum_stmt.results[1:])
4545
)
4646

47-
return circuit.StaticCircuit(qubits=qubits, body=body)
47+
return place.StaticPlacement(qubits=qubits, body=body)
4848

4949
def rewrite_TerminalLogicalMeasurement(
5050
self, node: gemini_stmts.TerminalLogicalMeasurement
@@ -54,7 +54,7 @@ def rewrite_TerminalLogicalMeasurement(
5454

5555
inputs = args_list.values
5656
body, block, entry_state = self.prep_region()
57-
gate_stmt = circuit.EndMeasure(
57+
gate_stmt = place.EndMeasure(
5858
entry_state,
5959
qubits=tuple(range(len(inputs))),
6060
)
@@ -63,7 +63,7 @@ def rewrite_TerminalLogicalMeasurement(
6363
)
6464
new_node.insert_before(node)
6565
node.replace_by(
66-
circuit.ConvertToPhysicalMeasurements(
66+
place.ConvertToPhysicalMeasurements(
6767
tuple(new_node.results),
6868
)
6969
)
@@ -85,7 +85,7 @@ def rewrite_CZ(self, node: gate.CZ) -> abc.RewriteResult:
8585
n_controls = len(controls)
8686

8787
body, block, entry_state = self.prep_region()
88-
stmt = circuit.CZ(
88+
stmt = place.CZ(
8989
entry_state,
9090
controls=all_qubits[:n_controls],
9191
targets=all_qubits[n_controls:],
@@ -106,7 +106,7 @@ def rewrite_R(self, node: gate.R) -> abc.RewriteResult:
106106
inputs = args_list.values
107107

108108
body, block, entry_state = self.prep_region()
109-
gate_stmt = circuit.R(
109+
gate_stmt = place.R(
110110
entry_state,
111111
qubits=tuple(range(len(inputs))),
112112
axis_angle=node.axis_angle,
@@ -127,7 +127,7 @@ def rewrite_Rz(self, node: gate.Rz) -> abc.RewriteResult:
127127
body = ir.Region(block := ir.Block())
128128
entry_state = block.args.append_from(StateType, name="entry_state")
129129

130-
gate_stmt = circuit.Rz(
130+
gate_stmt = place.Rz(
131131
entry_state,
132132
qubits=tuple(range(len(inputs))),
133133
rotation_angle=node.rotation_angle,
@@ -158,11 +158,11 @@ class MergePlacementRegions(abc.RewriteRule):
158158
def remap_qubits(
159159
self,
160160
curr_state: ir.SSAValue,
161-
node: circuit.R | circuit.Rz | circuit.CZ | circuit.EndMeasure,
161+
node: place.R | place.Rz | place.CZ | place.EndMeasure,
162162
input_map: dict[int, int],
163163
):
164-
if isinstance(node, circuit.CZ):
165-
return circuit.CZ(
164+
if isinstance(node, place.CZ):
165+
return place.CZ(
166166
curr_state,
167167
targets=tuple(input_map[i] for i in node.targets),
168168
controls=tuple(input_map[i] for i in node.controls),
@@ -178,8 +178,8 @@ def remap_qubits(
178178

179179
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
180180
if not (
181-
isinstance(node, circuit.StaticCircuit)
182-
and isinstance(next_node := node.next_stmt, circuit.StaticCircuit)
181+
isinstance(node, place.StaticPlacement)
182+
and isinstance(next_node := node.next_stmt, place.StaticPlacement)
183183
):
184184
return abc.RewriteResult()
185185

@@ -199,16 +199,14 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
199199
new_block = new_body.blocks[0]
200200

201201
curr_yield = new_block.last_stmt
202-
assert isinstance(curr_yield, circuit.Yield)
202+
assert isinstance(curr_yield, place.Yield)
203203

204204
curr_state = curr_yield.final_state
205205
current_yields = list(curr_yield.classical_results)
206206
curr_yield.delete()
207207

208208
for stmt in next_node.body.blocks[0].stmts:
209-
if isinstance(
210-
stmt, (circuit.R, circuit.Rz, circuit.CZ, circuit.EndMeasure)
211-
):
209+
if isinstance(stmt, (place.R, place.Rz, place.CZ, place.EndMeasure)):
212210
remapped_stmt = self.remap_qubits(curr_state, stmt, new_input_map)
213211
curr_state = remapped_stmt.results[0]
214212
new_block.stmts.append(remapped_stmt)
@@ -218,14 +216,14 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
218216
old_result.replace_by(new_result)
219217
current_yields.append(new_result)
220218

221-
new_yield = circuit.Yield(
219+
new_yield = place.Yield(
222220
curr_state,
223221
*current_yields,
224222
)
225223
new_block.stmts.append(new_yield)
226224

227225
# create the new static circuit
228-
new_static_circuit = circuit.StaticCircuit(
226+
new_static_circuit = place.StaticPlacement(
229227
new_qubits,
230228
new_body,
231229
)

0 commit comments

Comments
 (0)