-
Notifications
You must be signed in to change notification settings - Fork 1
QASM2 to squin #644
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
QASM2 to squin #644
Changes from all commits
2f6e8a1
90d8578
509b9cc
006be62
1fa9865
6552782
d4f46bb
6b9e52f
cd53e5d
59dfcb0
99e66af
045fa00
5137de0
6390a23
0645d11
df35694
3468702
5f137de
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .qasm2_to_squin import QASM2ToSquin as QASM2ToSquin |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| from kirin import ir, passes | ||
| from kirin.rewrite import Walk, Chain | ||
| from kirin.dialects import func | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade.rewrite.passes import CallGraphPass | ||
| from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule | ||
|
|
||
| from ..rewrite import qasm2 as qasm2_rule | ||
|
|
||
|
|
||
| class QASM2GateFuncToKirinFunc(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
| from bloqade.qasm2.dialects.expr.stmts import GateFunction | ||
|
|
||
| if not isinstance(node, GateFunction): | ||
| return RewriteResult() | ||
|
|
||
| kirin_func = func.Function( | ||
| sym_name=node.sym_name, | ||
| signature=node.signature, | ||
| body=node.body, | ||
| slots=node.slots, | ||
| ) | ||
| node.replace_by(kirin_func) | ||
|
|
||
| return RewriteResult(has_done_something=True) | ||
|
|
||
|
|
||
| class QASM2GateFuncToSquinPass(passes.Pass): | ||
|
|
||
| def unsafe_run(self, mt: ir.Method) -> RewriteResult: | ||
| convert_to_kirin_func = CallGraphPass( | ||
| dialects=mt.dialects, rule=Walk(QASM2GateFuncToKirinFunc()) | ||
| ) | ||
| rewrite_result = convert_to_kirin_func(mt) | ||
|
|
||
| combined_qasm2_rules = Walk( | ||
| Chain( | ||
| QASM2ToPyRule(), | ||
| qasm2_rule.QASM2CoreToSquin(), | ||
| qasm2_rule.QASM2GlobParallelToSquin(), | ||
| qasm2_rule.QASM2NoiseToSquin(), | ||
| qasm2_rule.QASM2IdToSquin(), | ||
| qasm2_rule.QASM2UOp1QToSquin(), | ||
| qasm2_rule.QASM2ParametrizedUOp1QToSquin(), | ||
| qasm2_rule.QASM2UOp2QToSquin(), | ||
| ) | ||
| ) | ||
|
|
||
| body_conversion_pass = CallGraphPass( | ||
| dialects=mt.dialects, rule=combined_qasm2_rules | ||
| ) | ||
| rewrite_result = body_conversion_pass(mt).join(rewrite_result) | ||
|
|
||
| return rewrite_result |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| from dataclasses import dataclass | ||
|
|
||
| from kirin import ir | ||
| from kirin.passes import Fold, Pass, TypeInfer | ||
| from kirin.rewrite import Walk, Chain | ||
| from kirin.rewrite.abc import RewriteResult | ||
| from kirin.dialects.ilist.passes import IListDesugar | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.squin.rewrite.qasm2 import ( | ||
| QASM2IdToSquin, | ||
| QASM2CoreToSquin, | ||
| QASM2NoiseToSquin, | ||
| QASM2UOp1QToSquin, | ||
| QASM2UOp2QToSquin, | ||
| QASM2GlobParallelToSquin, | ||
| QASM2ParametrizedUOp1QToSquin, | ||
| ) | ||
|
|
||
| # There's a QASM2Py pass that only applies an _QASM2Py rewrite rule, | ||
| # I just want the rule here. | ||
| from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule | ||
|
|
||
| from .qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass | ||
|
|
||
|
|
||
| @dataclass | ||
| class QASM2ToSquin(Pass): | ||
|
|
||
| def unsafe_run(self, mt: ir.Method) -> RewriteResult: | ||
|
|
||
| # rewrite all QASM2 to squin first | ||
| rewrite_result = Walk( | ||
| Chain( | ||
| QASM2ToPyRule(), | ||
| QASM2CoreToSquin(), | ||
| QASM2GlobParallelToSquin(), | ||
| QASM2NoiseToSquin(), | ||
| QASM2IdToSquin(), | ||
| QASM2UOp1QToSquin(), | ||
| QASM2ParametrizedUOp1QToSquin(), | ||
| QASM2UOp2QToSquin(), | ||
| ) | ||
| ).rewrite(mt.code) | ||
|
|
||
| # go into subkernels | ||
| rewrite_result = ( | ||
| QASM2GateFuncToSquinPass(dialects=mt.dialects) | ||
| .unsafe_run(mt) | ||
| .join(rewrite_result) | ||
| ) | ||
|
|
||
| # kernel should be entirely in squin dialect now | ||
| mt.dialects = squin.kernel | ||
|
|
||
| # the rest is taken from the squin kernel | ||
| rewrite_result = Fold(dialects=mt.dialects).fixpoint(mt) | ||
| rewrite_result = ( | ||
| TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) | ||
| ) | ||
| rewrite_result = ( | ||
| IListDesugar(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) | ||
| ) | ||
| TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) | ||
|
|
||
| return rewrite_result | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| from .id_to_squin import QASM2IdToSquin as QASM2IdToSquin | ||
| from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin | ||
| from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin | ||
| from .uop_1q_to_squin import QASM2UOp1QToSquin as QASM2UOp1QToSquin | ||
| from .uop_2q_to_squin import QASM2UOp2QToSquin as QASM2UOp2QToSquin | ||
| from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin | ||
| from .parametrized_uop_1q_to_squin import ( | ||
| QASM2ParametrizedUOp1QToSquin as QASM2ParametrizedUOp1QToSquin, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| from kirin import ir | ||
| from kirin.dialects import py, func | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.qasm2.dialects.core import stmts as core_stmts | ||
|
|
||
| CORE_TO_SQUIN_MAP = { | ||
| core_stmts.QRegNew: squin.qubit.qalloc, | ||
| core_stmts.Reset: squin.qubit.reset, | ||
| } | ||
|
|
||
|
|
||
| class QASM2CoreToSquin(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| if isinstance(node, core_stmts.QRegGet): | ||
| py_get_item = py.GetItem( | ||
| obj=node.reg, | ||
| index=node.idx, | ||
| ) | ||
| node.replace_by(py_get_item) | ||
| return RewriteResult(has_done_something=True) | ||
|
|
||
| if isinstance(node, core_stmts.QRegNew): | ||
| args = (node.n_qubits,) | ||
| elif isinstance(node, core_stmts.Reset): | ||
| args = (node.qarg,) | ||
| else: | ||
| return RewriteResult() | ||
|
|
||
| new_stmt = func.Invoke( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like you could just do this within the |
||
| callee=CORE_TO_SQUIN_MAP[type(node)], | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. General question: Why do you rewrite to stdlib calls rather than statements?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my head it felt nicer to default to the stdlib versions because the output to the user would be on par with what they would see if they constructed the squin kernel through the interface we provide. Figured it makes debugging a bit nicer. That being said, I don't have a strong preference it has to be that way and if it ends up being the case we prefer the "unrolled"/flat form as the output I can easily do that (:
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. It's just a bit of style preference, I suppose. In my opinion, it's nice to fall back to the stdlib whenever you have stdlib functions but no matching statement in a dialect. For example, when rewriting from squin to native, there's a lot of statements in the more general squin dialect, but matching stdlib functions in native. Here though we basically have matching statements in squin for all statements in qasm2. The only exceptions I see are probably |
||
| inputs=args, | ||
| ) | ||
| node.replace_by(new_stmt) | ||
| return RewriteResult(has_done_something=True) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| from kirin import ir | ||
| from kirin.dialects import func | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.qasm2.dialects import glob, parallel | ||
|
|
||
| GLOBAL_PARALLEL_TO_SQUIN_MAP = { | ||
| glob.UGate: squin.broadcast.u3, | ||
| parallel.UGate: squin.broadcast.u3, | ||
| parallel.RZ: squin.broadcast.rz, | ||
| } | ||
|
|
||
|
|
||
| class QASM2GlobParallelToSquin(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| if isinstance(node, glob.UGate): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, this seems a bit redundant: you might as well just assign
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The dictionary was something that @weinbe58 recommended, I could be doing the pattern wrong here but in cases where I do see a dictionary used it only turns out nice if the attribute you're accessing exits across all the statements. Like for arithmetic operation conversion, you'll always have an I actually realize if I wanted to be clever I could add some strings to the values in the dictionary and then Wish I could still do pattern matching but I'm told the performance would take a hit |
||
| args = (node.theta, node.phi, node.lam, node.registers) | ||
| elif isinstance(node, parallel.UGate): | ||
| args = (node.theta, node.phi, node.lam, node.qargs) | ||
| elif isinstance(node, parallel.RZ): | ||
| args = (node.theta, node.qargs) | ||
| else: | ||
| return RewriteResult() | ||
|
|
||
| squin_equivalent_stmt = GLOBAL_PARALLEL_TO_SQUIN_MAP[type(node)] | ||
| invoke_stmt = func.Invoke( | ||
| callee=squin_equivalent_stmt, | ||
| inputs=args, | ||
| ) | ||
| node.replace_by(invoke_stmt) | ||
| return RewriteResult(has_done_something=True) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| from kirin import ir | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| import bloqade.qasm2.dialects.uop.stmts as uop_stmts | ||
|
|
||
|
|
||
| class QASM2IdToSquin(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| if not isinstance(node, uop_stmts.Id): | ||
| return RewriteResult() | ||
|
|
||
| node.delete() | ||
| return RewriteResult(has_done_something=True) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| from kirin import ir | ||
| from kirin.dialects import py, func | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.qasm2.dialects.noise import stmts as noise_stmts | ||
|
|
||
| NOISE_TO_SQUIN_MAP = { | ||
| noise_stmts.AtomLossChannel: squin.broadcast.qubit_loss, | ||
| noise_stmts.PauliChannel: squin.broadcast.single_qubit_pauli_channel, | ||
| } | ||
|
|
||
|
|
||
| def num_to_py_constant( | ||
| values: list[int | float], stmt_to_insert_before: ir.Statement | ||
| ) -> list[ir.SSAValue]: | ||
|
|
||
| py_const_ssa_vals = [] | ||
| for value in values: | ||
| const_form = py.Constant(value=value) | ||
| const_form.insert_before(stmt_to_insert_before) | ||
| py_const_ssa_vals.append(const_form.result) | ||
|
|
||
| return py_const_ssa_vals | ||
|
|
||
|
|
||
| class QASM2NoiseToSquin(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| if isinstance(node, noise_stmts.AtomLossChannel): | ||
| qargs = node.qargs | ||
| prob = node.prob | ||
| prob_ssas = num_to_py_constant([prob], stmt_to_insert_before=node) | ||
| elif isinstance(node, noise_stmts.PauliChannel): | ||
| qargs = node.qargs | ||
| p_x = node.px | ||
| p_y = node.py | ||
| p_z = node.pz | ||
| prob_ssas = num_to_py_constant([p_x, p_y, p_z], stmt_to_insert_before=node) | ||
| elif isinstance(node, noise_stmts.CZPauliChannel): | ||
| return self.rewrite_CZPauliChannel(node) | ||
| else: | ||
| return RewriteResult() | ||
|
|
||
| squin_noise_stmt = NOISE_TO_SQUIN_MAP[type(node)] | ||
| invoke_stmt = func.Invoke( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you know that you are going for the broadcast version, why not just rewrite to the statement directly instead of adding an invoke to the stdlib? |
||
| callee=squin_noise_stmt, | ||
| inputs=(*prob_ssas, qargs), | ||
| ) | ||
| node.replace_by(invoke_stmt) | ||
| return RewriteResult(has_done_something=True) | ||
|
|
||
| def rewrite_CZPauliChannel(self, stmt: noise_stmts.CZPauliChannel) -> RewriteResult: | ||
|
|
||
| ctrls = stmt.ctrls | ||
| qargs = stmt.qargs | ||
|
|
||
| px_ctrl = stmt.px_ctrl | ||
| py_ctrl = stmt.py_ctrl | ||
| pz_ctrl = stmt.pz_ctrl | ||
| px_qarg = stmt.px_qarg | ||
| py_qarg = stmt.py_qarg | ||
| pz_qarg = stmt.pz_qarg | ||
|
|
||
| error_probs = [px_ctrl, py_ctrl, pz_ctrl, px_qarg, py_qarg, pz_qarg] | ||
| # first half of entries for control qubits, other half for targets | ||
|
|
||
| error_prob_ssas = num_to_py_constant(error_probs, stmt_to_insert_before=stmt) | ||
|
|
||
| ctrl_pauli_channel_invoke = func.Invoke( | ||
| callee=squin.broadcast.single_qubit_pauli_channel, | ||
| inputs=( | ||
| *error_prob_ssas[:3], | ||
| ctrls, | ||
| ), | ||
| ) | ||
|
|
||
| qarg_pauli_channel_invoke = func.Invoke( | ||
| callee=squin.broadcast.single_qubit_pauli_channel, | ||
| inputs=( | ||
| *error_prob_ssas[3:], | ||
| qargs, | ||
| ), | ||
| ) | ||
|
|
||
| ctrl_pauli_channel_invoke.insert_before(stmt) | ||
| qarg_pauli_channel_invoke.insert_before(stmt) | ||
| stmt.delete() | ||
|
|
||
| return RewriteResult(has_done_something=True) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| from math import pi | ||
|
|
||
| from kirin import ir | ||
| from kirin.dialects import py, func | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.qasm2.dialects.uop import stmts as uop_stmts | ||
|
|
||
| PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP = { | ||
| uop_stmts.UGate: squin.u3, | ||
| uop_stmts.U1: squin.u3, | ||
| uop_stmts.U2: squin.u3, | ||
| uop_stmts.RZ: squin.rz, | ||
| uop_stmts.RX: squin.rx, | ||
| uop_stmts.RY: squin.ry, | ||
| } | ||
|
|
||
|
|
||
| class QASM2ParametrizedUOp1QToSquin(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| if isinstance(node, (uop_stmts.RX, uop_stmts.RY, uop_stmts.RZ)): | ||
| args = (node.theta, node.qarg) | ||
| elif isinstance(node, (uop_stmts.UGate)): | ||
| args = (node.theta, node.phi, node.lam, node.qarg) | ||
| elif isinstance(node, (uop_stmts.U1)): | ||
| zero_stmt = py.Constant(value=0.0) | ||
| zero_stmt.insert_before(node) | ||
| args = (zero_stmt.result, zero_stmt.result, node.lam, node.qarg) | ||
| elif isinstance(node, (uop_stmts.U2)): | ||
| half_pi_stmt = py.Constant(value=pi / 2) | ||
| half_pi_stmt.insert_before(node) | ||
| args = (half_pi_stmt.result, node.phi, node.lam, node.qarg) | ||
| else: | ||
| return RewriteResult() | ||
|
|
||
| squin_equivalent_stmt = PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP[type(node)] | ||
| invoke_stmt = func.Invoke( | ||
| callee=squin_equivalent_stmt, | ||
| inputs=args, | ||
| ) | ||
| node.replace_by(invoke_stmt) | ||
|
|
||
| return RewriteResult(has_done_something=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure we should track the result of type inference. @weinbe58 didn't you recently fix a bug by removing a similar
.join(result)in another pass?