Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/dialects/uop/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TwoQubitCtrlGate(ir.Statement):

@statement(dialect=dialect)
class CX(TwoQubitCtrlGate):
"""Alias for the CNOT or CH gate operations."""
"""Alias for the CNOT or CX gate operations."""

name = "CX" # Note this is capitalized

Expand Down
1 change: 1 addition & 0 deletions src/bloqade/squin/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .qasm2_to_squin import QASM2ToSquin as QASM2ToSquin
57 changes: 57 additions & 0 deletions src/bloqade/squin/passes/qasm2_gate_func_to_squin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from kirin import ir, passes
from kirin.rewrite import Walk, Chain
from kirin.dialects import func
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.rewrite.passes import CallGraphPass
from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule

from ..rewrite import qasm2 as qasm2_rule


class QASM2GateFuncToKirinFunc(RewriteRule):

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
from bloqade.qasm2.dialects.expr.stmts import GateFunction

if not isinstance(node, GateFunction):
return RewriteResult()

kirin_func = func.Function(
sym_name=node.sym_name,
signature=node.signature,
body=node.body,
slots=node.slots,
)
node.replace_by(kirin_func)

return RewriteResult(has_done_something=True)


class QASM2GateFuncToSquinPass(passes.Pass):

def unsafe_run(self, mt: ir.Method) -> RewriteResult:
convert_to_kirin_func = CallGraphPass(
dialects=mt.dialects, rule=Walk(QASM2GateFuncToKirinFunc())
)
rewrite_result = convert_to_kirin_func(mt)

combined_qasm2_rules = Walk(
Chain(
QASM2ToPyRule(),
qasm2_rule.QASM2CoreToSquin(),
qasm2_rule.QASM2GlobParallelToSquin(),
qasm2_rule.QASM2NoiseToSquin(),
qasm2_rule.QASM2IdToSquin(),
qasm2_rule.QASM2UOp1QToSquin(),
qasm2_rule.QASM2ParametrizedUOp1QToSquin(),
qasm2_rule.QASM2UOp2QToSquin(),
)
)

body_conversion_pass = CallGraphPass(
dialects=mt.dialects, rule=combined_qasm2_rules
)
rewrite_result = body_conversion_pass(mt).join(rewrite_result)

return rewrite_result
66 changes: 66 additions & 0 deletions src/bloqade/squin/passes/qasm2_to_squin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from dataclasses import dataclass

from kirin import ir
from kirin.passes import Fold, Pass, TypeInfer
from kirin.rewrite import Walk, Chain
from kirin.rewrite.abc import RewriteResult
from kirin.dialects.ilist.passes import IListDesugar

from bloqade import squin
from bloqade.squin.rewrite.qasm2 import (
QASM2IdToSquin,
QASM2CoreToSquin,
QASM2NoiseToSquin,
QASM2UOp1QToSquin,
QASM2UOp2QToSquin,
QASM2GlobParallelToSquin,
QASM2ParametrizedUOp1QToSquin,
)

# There's a QASM2Py pass that only applies an _QASM2Py rewrite rule,
# I just want the rule here.
from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule

from .qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass


@dataclass
class QASM2ToSquin(Pass):

def unsafe_run(self, mt: ir.Method) -> RewriteResult:

# rewrite all QASM2 to squin first
rewrite_result = Walk(
Chain(
QASM2ToPyRule(),
QASM2CoreToSquin(),
QASM2GlobParallelToSquin(),
QASM2NoiseToSquin(),
QASM2IdToSquin(),
QASM2UOp1QToSquin(),
QASM2ParametrizedUOp1QToSquin(),
QASM2UOp2QToSquin(),
)
).rewrite(mt.code)

# go into subkernels
rewrite_result = (
QASM2GateFuncToSquinPass(dialects=mt.dialects)
.unsafe_run(mt)
.join(rewrite_result)
)

# kernel should be entirely in squin dialect now
mt.dialects = squin.kernel

# the rest is taken from the squin kernel
rewrite_result = Fold(dialects=mt.dialects).fixpoint(mt)
rewrite_result = (
TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result)
)
rewrite_result = (
IListDesugar(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result)
)
TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should track the result of type inference. @weinbe58 didn't you recently fix a bug by removing a similar .join(result) in another pass?


return rewrite_result
9 changes: 9 additions & 0 deletions src/bloqade/squin/rewrite/qasm2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .id_to_squin import QASM2IdToSquin as QASM2IdToSquin
from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin
from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin
from .uop_1q_to_squin import QASM2UOp1QToSquin as QASM2UOp1QToSquin
from .uop_2q_to_squin import QASM2UOp2QToSquin as QASM2UOp2QToSquin
from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin
from .parametrized_uop_1q_to_squin import (
QASM2ParametrizedUOp1QToSquin as QASM2ParametrizedUOp1QToSquin,
)
38 changes: 38 additions & 0 deletions src/bloqade/squin/rewrite/qasm2/core_to_squin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from kirin import ir
from kirin.dialects import py, func
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade import squin
from bloqade.qasm2.dialects.core import stmts as core_stmts

CORE_TO_SQUIN_MAP = {
core_stmts.QRegNew: squin.qubit.qalloc,
core_stmts.Reset: squin.qubit.reset,
}


class QASM2CoreToSquin(RewriteRule):

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

if isinstance(node, core_stmts.QRegGet):
py_get_item = py.GetItem(
obj=node.reg,
index=node.idx,
)
node.replace_by(py_get_item)
return RewriteResult(has_done_something=True)

if isinstance(node, core_stmts.QRegNew):
args = (node.n_qubits,)
elif isinstance(node, core_stmts.Reset):
args = (node.qarg,)
else:
return RewriteResult()

new_stmt = func.Invoke(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like you could just do this within the if above, since the args have to match the node type anyway.

callee=CORE_TO_SQUIN_MAP[type(node)],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General question: Why do you rewrite to stdlib calls rather than statements?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my head it felt nicer to default to the stdlib versions because the output to the user would be on par with what they would see if they constructed the squin kernel through the interface we provide. Figured it makes debugging a bit nicer.

That being said, I don't have a strong preference it has to be that way and if it ends up being the case we prefer the "unrolled"/flat form as the output I can easily do that (:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. It's just a bit of style preference, I suppose.

In my opinion, it's nice to fall back to the stdlib whenever you have stdlib functions but no matching statement in a dialect. For example, when rewriting from squin to native, there's a lot of statements in the more general squin dialect, but matching stdlib functions in native.

Here though we basically have matching statements in squin for all statements in qasm2. The only exceptions I see are probably U1 and U2.

inputs=args,
)
node.replace_by(new_stmt)
return RewriteResult(has_done_something=True)
34 changes: 34 additions & 0 deletions src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from kirin import ir
from kirin.dialects import func
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade import squin
from bloqade.qasm2.dialects import glob, parallel

GLOBAL_PARALLEL_TO_SQUIN_MAP = {
glob.UGate: squin.broadcast.u3,
parallel.UGate: squin.broadcast.u3,
parallel.RZ: squin.broadcast.rz,
}


class QASM2GlobParallelToSquin(RewriteRule):

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

if isinstance(node, glob.UGate):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this seems a bit redundant: you might as well just assign squin_equivalent_stmt in each case here rather than using the map above.

Copy link
Contributor Author

@johnzl-777 johnzl-777 Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dictionary was something that @weinbe58 recommended, I could be doing the pattern wrong here but in cases where I do see a dictionary used it only turns out nice if the attribute you're accessing exits across all the statements.

Like for arithmetic operation conversion, you'll always have an lhs and rhs attribute. Here the number and kinds of attributes change.

I actually realize if I wanted to be clever I could add some strings to the values in the dictionary and then __getattribute__ things which would resolve the clunkiness at the expense of making things a little uglier.

Wish I could still do pattern matching but I'm told the performance would take a hit

args = (node.theta, node.phi, node.lam, node.registers)
elif isinstance(node, parallel.UGate):
args = (node.theta, node.phi, node.lam, node.qargs)
elif isinstance(node, parallel.RZ):
args = (node.theta, node.qargs)
else:
return RewriteResult()

squin_equivalent_stmt = GLOBAL_PARALLEL_TO_SQUIN_MAP[type(node)]
invoke_stmt = func.Invoke(
callee=squin_equivalent_stmt,
inputs=args,
)
node.replace_by(invoke_stmt)
return RewriteResult(has_done_something=True)
15 changes: 15 additions & 0 deletions src/bloqade/squin/rewrite/qasm2/id_to_squin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from kirin import ir
from kirin.rewrite.abc import RewriteRule, RewriteResult

import bloqade.qasm2.dialects.uop.stmts as uop_stmts


class QASM2IdToSquin(RewriteRule):

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

if not isinstance(node, uop_stmts.Id):
return RewriteResult()

node.delete()
return RewriteResult(has_done_something=True)
80 changes: 80 additions & 0 deletions src/bloqade/squin/rewrite/qasm2/noise_to_squin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from kirin import ir
from kirin.dialects import func
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade import squin
from bloqade.qasm2.dialects.noise import stmts as noise_stmts

from .util import num_to_py_constant

NOISE_TO_SQUIN_MAP = {
noise_stmts.AtomLossChannel: squin.broadcast.qubit_loss,
noise_stmts.PauliChannel: squin.broadcast.single_qubit_pauli_channel,
}


class QASM2NoiseToSquin(RewriteRule):

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

if isinstance(node, noise_stmts.AtomLossChannel):
qargs = node.qargs
prob = node.prob
prob_ssas = num_to_py_constant([prob], stmt_to_insert_before=node)
elif isinstance(node, noise_stmts.PauliChannel):
qargs = node.qargs
p_x = node.px
p_y = node.py
p_z = node.pz
prob_ssas = num_to_py_constant([p_x, p_y, p_z], stmt_to_insert_before=node)
elif isinstance(node, noise_stmts.CZPauliChannel):
return self.rewrite_CZPauliChannel(node)
else:
return RewriteResult()

squin_noise_stmt = NOISE_TO_SQUIN_MAP[type(node)]
invoke_stmt = func.Invoke(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you know that you are going for the broadcast version, why not just rewrite to the statement directly instead of adding an invoke to the stdlib?

callee=squin_noise_stmt,
inputs=(*prob_ssas, qargs),
)
node.replace_by(invoke_stmt)
return RewriteResult(has_done_something=True)

def rewrite_CZPauliChannel(self, stmt: noise_stmts.CZPauliChannel) -> RewriteResult:

ctrls = stmt.ctrls
qargs = stmt.qargs

px_ctrl = stmt.px_ctrl
py_ctrl = stmt.py_ctrl
pz_ctrl = stmt.pz_ctrl
px_qarg = stmt.px_qarg
py_qarg = stmt.py_qarg
pz_qarg = stmt.pz_qarg

error_probs = [px_ctrl, py_ctrl, pz_ctrl, px_qarg, py_qarg, pz_qarg]
# first half of entries for control qubits, other half for targets

error_prob_ssas = num_to_py_constant(error_probs, stmt_to_insert_before=stmt)

ctrl_pauli_channel_invoke = func.Invoke(
callee=squin.broadcast.single_qubit_pauli_channel,
inputs=(
*error_prob_ssas[:3],
ctrls,
),
)

qarg_pauli_channel_invoke = func.Invoke(
callee=squin.broadcast.single_qubit_pauli_channel,
inputs=(
*error_prob_ssas[3:],
qargs,
),
)

ctrl_pauli_channel_invoke.insert_before(stmt)
qarg_pauli_channel_invoke.insert_before(stmt)
stmt.delete()

return RewriteResult(has_done_something=True)
46 changes: 46 additions & 0 deletions src/bloqade/squin/rewrite/qasm2/parametrized_uop_1q_to_squin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from math import pi

from kirin import ir
from kirin.dialects import py, func
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade import squin
from bloqade.qasm2.dialects.uop import stmts as uop_stmts

PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP = {
uop_stmts.UGate: squin.u3,
uop_stmts.U1: squin.u3,
uop_stmts.U2: squin.u3,
uop_stmts.RZ: squin.rz,
uop_stmts.RX: squin.rx,
uop_stmts.RY: squin.ry,
}


class QASM2ParametrizedUOp1QToSquin(RewriteRule):

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

if isinstance(node, (uop_stmts.RX, uop_stmts.RY, uop_stmts.RZ)):
args = (node.theta, node.qarg)
elif isinstance(node, (uop_stmts.UGate)):
args = (node.theta, node.phi, node.lam, node.qarg)
elif isinstance(node, (uop_stmts.U1)):
zero_stmt = py.Constant(value=0.0)
zero_stmt.insert_before(node)
args = (zero_stmt.result, zero_stmt.result, node.lam, node.qarg)
elif isinstance(node, (uop_stmts.U2)):
half_pi_stmt = py.Constant(value=pi / 2)
half_pi_stmt.insert_before(node)
args = (half_pi_stmt.result, node.phi, node.lam, node.qarg)
else:
return RewriteResult()

squin_equivalent_stmt = PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP[type(node)]
invoke_stmt = func.Invoke(
callee=squin_equivalent_stmt,
inputs=args,
)
node.replace_by(invoke_stmt)

return RewriteResult(has_done_something=True)
Loading
Loading