-
Notifications
You must be signed in to change notification settings - Fork 1
Implement pyqrack interpreter methods for squin dialect #207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 48 commits
Commits
Show all changes
49 commits
Select commit
Hold shift + click to select a range
1b6d42d
Fix identity wrapper argument
david-pl eed5a98
Fix measure wrapper typing
david-pl ecc641f
Implement methods for qubit statements
david-pl 2b9da28
Start implementing operator runtime
david-pl 1dc73e5
Change operator runtime implementation
david-pl 61df76a
Implement mult runtime -- TODO: don't wrap mult
david-pl b475a76
Implement kron runtime
david-pl a750374
Scale runtime
david-pl 3153f5b
Rework impl of control to work with all supported operators
david-pl 22dbc74
Runtime for phase and shift operators
david-pl cb0f0a6
Implement Sp/Sn runtime
david-pl d77422a
Factor 1/2 in Sn and Sp
david-pl fd08501
Fix method name for adjoints
david-pl c328d5e
Implement simple rotations about X, Y, Z
david-pl 8727acc
Fix test and bug
david-pl e86c267
Implement (somewhat strange) control apply for Krons
david-pl c99bb4d
Runtime for broadcast
david-pl 3280b47
Implement runtime for U3
david-pl b358747
Implement runtime for CliffordString
david-pl ca94e55
Draft implementation for wire dialect
david-pl 1d7bfbd
Check whether qubits are active before applying operator
david-pl f1e6efd
Check is_active in multiple places
david-pl 0b99977
Fix phase runtime and test
david-pl 7906068
Fix measure impl
david-pl fc6e9f8
Impl for MeasureAny stmt
david-pl b5c4fba
Raise InterpreterError instead of RuntimeError
david-pl a815901
Fix wrong enum value for axis
david-pl b6e7dbf
Add test that would have caught the wrong enum value
david-pl 5f4f145
"Fix" CI by marking wired as xfail
david-pl b774433
Add tests for projectors
david-pl a872f41
Test actual adjoint code path
david-pl b559f84
Improve test for control and add test with control(adjoint)
david-pl f464c6c
Merge branch 'main' into david/185-pyqrack-squin
weinbe58 e445064
fixing broken tests from merging main
weinbe58 c77c205
cast qubit measure to bool
david-pl 6dd413c
Update apply signature for single qubits
david-pl 69b42c5
Address PR comments
david-pl d298078
MeasuremeQubitList returns IList
david-pl a35fb61
Properly check for atom loss in apply methods
david-pl fffd492
Fix tests
david-pl f5eaa8e
Fix typing of measure_and_reset
david-pl fba5a2f
Properly account for operator size in runtime
david-pl e488852
Properly deal with nested adjoints
david-pl 59f52ab
Fix broadcasting for operators with size larger 1
david-pl 636637f
Add a test for CXX gate
david-pl 711e484
Rename n_qubits to n_sites
david-pl df4160f
Fix applied factor in scale
david-pl 96b298c
Add some more info to docstrings
david-pl 0b93fd0
Remove MeasureAny impl
david-pl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| from kirin import interp | ||
|
|
||
| from bloqade.squin import op | ||
| from bloqade.pyqrack.base import PyQrackInterpreter | ||
|
|
||
| from .runtime import ( | ||
| SnRuntime, | ||
| SpRuntime, | ||
| U3Runtime, | ||
| RotRuntime, | ||
| KronRuntime, | ||
| MultRuntime, | ||
| ScaleRuntime, | ||
| AdjointRuntime, | ||
| ControlRuntime, | ||
| PhaseOpRuntime, | ||
| IdentityRuntime, | ||
| OperatorRuntime, | ||
| ProjectorRuntime, | ||
| OperatorRuntimeABC, | ||
| PauliStringRuntime, | ||
| ) | ||
|
|
||
|
|
||
| @op.dialect.register(key="pyqrack") | ||
| class PyQrackMethods(interp.MethodTable): | ||
|
|
||
| @interp.impl(op.stmts.Kron) | ||
| def kron( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Kron | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| lhs = frame.get(stmt.lhs) | ||
| rhs = frame.get(stmt.rhs) | ||
| return (KronRuntime(lhs, rhs),) | ||
|
|
||
| @interp.impl(op.stmts.Mult) | ||
| def mult( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Mult | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| lhs = frame.get(stmt.lhs) | ||
| rhs = frame.get(stmt.rhs) | ||
| return (MultRuntime(lhs, rhs),) | ||
|
|
||
| @interp.impl(op.stmts.Adjoint) | ||
| def adjoint( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Adjoint | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| op = frame.get(stmt.op) | ||
| return (AdjointRuntime(op),) | ||
|
|
||
| @interp.impl(op.stmts.Scale) | ||
| def scale( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Scale | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| op = frame.get(stmt.op) | ||
| factor = frame.get(stmt.factor) | ||
| return (ScaleRuntime(op, factor),) | ||
|
|
||
| @interp.impl(op.stmts.Control) | ||
| def control( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Control | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| op = frame.get(stmt.op) | ||
| n_controls = stmt.n_controls | ||
| rt = ControlRuntime( | ||
| op=op, | ||
| n_controls=n_controls, | ||
| ) | ||
| return (rt,) | ||
|
|
||
| @interp.impl(op.stmts.Rot) | ||
| def rot( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Rot | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| axis = frame.get(stmt.axis) | ||
| angle = frame.get(stmt.angle) | ||
| return (RotRuntime(axis, angle),) | ||
|
|
||
| @interp.impl(op.stmts.Identity) | ||
| def identity( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Identity | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| return (IdentityRuntime(sites=stmt.sites),) | ||
|
|
||
| @interp.impl(op.stmts.PhaseOp) | ||
| @interp.impl(op.stmts.ShiftOp) | ||
| def phaseop( | ||
| self, | ||
| interp: PyQrackInterpreter, | ||
| frame: interp.Frame, | ||
| stmt: op.stmts.PhaseOp | op.stmts.ShiftOp, | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| theta = frame.get(stmt.theta) | ||
| global_ = isinstance(stmt, op.stmts.PhaseOp) | ||
| return (PhaseOpRuntime(theta, global_=global_),) | ||
|
|
||
| @interp.impl(op.stmts.X) | ||
| @interp.impl(op.stmts.Y) | ||
| @interp.impl(op.stmts.Z) | ||
| @interp.impl(op.stmts.H) | ||
| @interp.impl(op.stmts.S) | ||
| @interp.impl(op.stmts.T) | ||
| def operator( | ||
| self, | ||
| interp: PyQrackInterpreter, | ||
| frame: interp.Frame, | ||
| stmt: ( | ||
| op.stmts.X | op.stmts.Y | op.stmts.Z | op.stmts.H | op.stmts.S | op.stmts.T | ||
| ), | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| return (OperatorRuntime(method_name=stmt.name.lower()),) | ||
|
|
||
| @interp.impl(op.stmts.P0) | ||
| @interp.impl(op.stmts.P1) | ||
| def projector( | ||
| self, | ||
| interp: PyQrackInterpreter, | ||
| frame: interp.Frame, | ||
| stmt: op.stmts.P0 | op.stmts.P1, | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| state = isinstance(stmt, op.stmts.P1) | ||
| return (ProjectorRuntime(to_state=state),) | ||
|
|
||
| @interp.impl(op.stmts.Sp) | ||
| def sp( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sp | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| return (SpRuntime(),) | ||
|
|
||
| @interp.impl(op.stmts.Sn) | ||
| def sn( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Sn | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| return (SnRuntime(),) | ||
|
|
||
| @interp.impl(op.stmts.U3) | ||
| def u3( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.U3 | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| theta = frame.get(stmt.theta) | ||
| phi = frame.get(stmt.phi) | ||
| lam = frame.get(stmt.lam) | ||
| return (U3Runtime(theta, phi, lam),) | ||
|
|
||
| @interp.impl(op.stmts.PauliString) | ||
| def clifford_string( | ||
| self, | ||
| interp: PyQrackInterpreter, | ||
| frame: interp.Frame, | ||
| stmt: op.stmts.PauliString, | ||
| ) -> tuple[OperatorRuntimeABC]: | ||
| string = stmt.string | ||
| ops = [OperatorRuntime(method_name=name.lower()) for name in stmt.string] | ||
| return (PauliStringRuntime(string, ops),) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| from typing import Any | ||
|
|
||
| from kirin import interp | ||
| from kirin.dialects import ilist | ||
| from kirin.interp.exceptions import InterpreterError | ||
|
|
||
| from bloqade.squin import qubit | ||
| from bloqade.pyqrack.reg import QubitState, PyQrackQubit | ||
| from bloqade.pyqrack.base import PyQrackInterpreter | ||
|
|
||
| from .runtime import OperatorRuntimeABC | ||
|
|
||
|
|
||
| @qubit.dialect.register(key="pyqrack") | ||
| class PyQrackMethods(interp.MethodTable): | ||
| @interp.impl(qubit.New) | ||
| def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New): | ||
| n_qubits: int = frame.get(stmt.n_qubits) | ||
| qreg = ilist.IList( | ||
| [ | ||
| PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active) | ||
| for i in interp.memory.allocate(n_qubits=n_qubits) | ||
| ] | ||
| ) | ||
| return (qreg,) | ||
|
|
||
| @interp.impl(qubit.Apply) | ||
| def apply(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Apply): | ||
| qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) | ||
| operator: OperatorRuntimeABC = frame.get(stmt.operator) | ||
| operator.apply(*qubits) | ||
|
|
||
| @interp.impl(qubit.Broadcast) | ||
| def broadcast( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Broadcast | ||
| ): | ||
| operator: OperatorRuntimeABC = frame.get(stmt.operator) | ||
| qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) | ||
| operator.broadcast_apply(qubits) | ||
|
|
||
| def _measure_qubit(self, qbit: PyQrackQubit): | ||
| if qbit.is_active(): | ||
| return bool(qbit.sim_reg.m(qbit.addr)) | ||
|
|
||
| @interp.impl(qubit.MeasureQubitList) | ||
| def measure_qubit_list( | ||
| self, | ||
| interp: PyQrackInterpreter, | ||
| frame: interp.Frame, | ||
| stmt: qubit.MeasureQubitList, | ||
| ): | ||
| qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) | ||
| result = ilist.IList([self._measure_qubit(qbit) for qbit in qubits]) | ||
| return (result,) | ||
|
|
||
| @interp.impl(qubit.MeasureQubit) | ||
| def measure_qubit( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasureQubit | ||
| ): | ||
| qbit: PyQrackQubit = frame.get(stmt.qubit) | ||
| result = self._measure_qubit(qbit) | ||
| return (result,) | ||
|
|
||
| @interp.impl(qubit.MeasureAny) | ||
| def measure_any( | ||
| self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasureAny | ||
| ): | ||
| input = frame.get(stmt.input) | ||
|
|
||
| if isinstance(input, PyQrackQubit) and input.is_active(): | ||
| result = self._measure_qubit | ||
| elif isinstance(input, ilist.IList): | ||
| result = [] | ||
| for qbit in input: | ||
| if not isinstance(qbit, PyQrackQubit): | ||
| raise InterpreterError(f"Cannot measure {type(qbit).__name__}") | ||
|
|
||
| result.append(self._measure_qubit(qbit)) | ||
| else: | ||
| raise InterpreterError(f"Cannot measure {type(input).__name__}") | ||
|
|
||
| return (result,) | ||
|
|
||
| @interp.impl(qubit.MeasureAndReset) | ||
| def measure_and_reset( | ||
| self, | ||
| interp: PyQrackInterpreter, | ||
| frame: interp.Frame, | ||
| stmt: qubit.MeasureAndReset, | ||
| ): | ||
| qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) | ||
| result = [] | ||
| for qbit in qubits: | ||
| if qbit.is_active(): | ||
| result.append(qbit.sim_reg.m(qbit.addr)) | ||
| else: | ||
| result.append(None) | ||
| qbit.sim_reg.force_m(qbit.addr, 0) | ||
|
|
||
| return (ilist.IList(result),) | ||
|
|
||
| @interp.impl(qubit.Reset) | ||
| def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Reset): | ||
| qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) | ||
| for qbit in qubits: | ||
| qbit.sim_reg.force_m(qbit.addr, 0) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This statement is only for syntax sugar not for runtime. can you remove this?