-
Notifications
You must be signed in to change notification settings - Fork 1
Squin Operator Sites Analysis #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 17 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
d9c7bba
initial shape lattice
johnzl-777 13135a3
more analysis work
johnzl-777 9aa2964
getting there
johnzl-777 9ae811c
proper handling of Sized(1) trait operators
johnzl-777 0c20b95
Completed implementation but not quite working
johnzl-777 5e1c221
debugging empty registry problem
johnzl-777 d765f09
fixed missing registry + improper impl returns
johnzl-777 a12dae1
remove incorrect comment
johnzl-777 a1b0da0
remove incorrect comment
johnzl-777 32cc734
remove lingering print
johnzl-777 7af6fb0
remove incorrect comment
johnzl-777 accb460
rename to use sites instead of shape
johnzl-777 c0a0593
add unit tests, implement name changes suggested by Roger
johnzl-777 685e1c9
fix improper kron test
johnzl-777 22fc9e7
complete primitive op test
johnzl-777 39fdc84
Rename trait and type in lattice to avoid collision
johnzl-777 6f7c4b4
Try to choose trait names to avoid collision + be clear
johnzl-777 2da07d3
Remove playground file
johnzl-777 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| from kirin import ir, types | ||
| from kirin.passes import Fold | ||
| from kirin.dialects import py, func | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.squin.analysis import nsites | ||
|
|
||
|
|
||
| def as_int(value: int): | ||
| return py.constant.Constant(value=value) | ||
|
|
||
|
|
||
| squin_with_qasm_core = squin.groups.wired.add(py) | ||
|
|
||
| stmts: list[ir.Statement] = [ | ||
| (h0 := squin.op.stmts.H()), | ||
| (h1 := squin.op.stmts.H()), | ||
| (hh := squin.op.stmts.Kron(lhs=h1.result, rhs=h0.result)), | ||
| (chh := squin.op.stmts.Control(hh.result, n_controls=1)), | ||
| (factor := as_int(1)), | ||
| (schh := squin.op.stmts.Scale(chh.result, factor=factor.result)), | ||
| (func.Return(schh.result)), | ||
| ] | ||
|
|
||
| block = ir.Block(stmts) | ||
| block.args.append_from(types.MethodType[[], types.NoneType], "main_self") | ||
| func_wrapper = func.Function( | ||
| sym_name="main", | ||
| signature=func.Signature(inputs=(), output=squin.op.types.OpType), | ||
| body=ir.Region(blocks=block), | ||
| ) | ||
|
|
||
| constructed_method = ir.Method( | ||
| mod=None, | ||
| py_func=None, | ||
| sym_name="main", | ||
| dialects=squin_with_qasm_core, | ||
| code=func_wrapper, | ||
| arg_names=[], | ||
| ) | ||
|
|
||
| fold_pass = Fold(squin_with_qasm_core) | ||
| fold_pass(constructed_method) | ||
|
|
||
| """" | ||
| address_frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( | ||
| constructed_method, no_raise=False | ||
| ) | ||
|
|
||
|
|
||
| constructed_method.print(analysis=address_frame.entries) | ||
| """ | ||
|
|
||
| shape_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( | ||
| constructed_method, no_raise=False | ||
| ) | ||
|
|
||
|
|
||
| constructed_method.print(analysis=shape_frame.entries) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| # Need this for impl registration to work properly! | ||
| from . import impls as impls | ||
| from .lattice import ( | ||
| NoSites as NoSites, | ||
| AnySites as AnySites, | ||
| NumberSites as NumberSites, | ||
| ) | ||
| from .analysis import NSitesAnalysis as NSitesAnalysis |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| # from typing import cast | ||
|
|
||
| from kirin import ir | ||
| from kirin.analysis import Forward | ||
| from kirin.analysis.forward import ForwardFrame | ||
|
|
||
| from bloqade.squin.op.types import OpType | ||
| from bloqade.squin.op.traits import HasSites, FixedSites | ||
|
|
||
| from .lattice import Sites, NoSites, NumberSites | ||
|
|
||
|
|
||
| class NSitesAnalysis(Forward[Sites]): | ||
|
|
||
| keys = ["op.nsites"] | ||
| lattice = Sites | ||
|
|
||
| # Take a page from const prop in Kirin, | ||
| # I can get the data I want from the SizedTrait | ||
| # and go from there | ||
|
|
||
| ## This gets called before the registry look up | ||
| def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): | ||
| method = self.lookup_registry(frame, stmt) | ||
| if method is not None: | ||
| return method(self, frame, stmt) | ||
| elif stmt.has_trait(HasSites): | ||
| has_sites_trait = stmt.get_trait(HasSites) | ||
| sites = has_sites_trait.get_sites(stmt) | ||
| return (NumberSites(sites=sites),) | ||
| elif stmt.has_trait(FixedSites): | ||
| sites_trait = stmt.get_trait(FixedSites) | ||
| return (NumberSites(sites=sites_trait.data),) | ||
| else: | ||
| return (NoSites(),) | ||
|
|
||
| # For when no implementation is found for the statement | ||
| def eval_stmt_fallback( | ||
| self, frame: ForwardFrame[Sites], stmt: ir.Statement | ||
| ) -> tuple[Sites, ...]: # some form of Shape will go back into the frame | ||
| return tuple( | ||
| ( | ||
| self.lattice.top() | ||
| if result.type.is_subseteq(OpType) | ||
| else self.lattice.bottom() | ||
| ) | ||
| for result in stmt.results | ||
| ) | ||
|
|
||
| def run_method(self, method: ir.Method, args: tuple[Sites, ...]): | ||
| # NOTE: we do not support dynamic calls here, thus no need to propagate method object | ||
| return self.run_callable(method.code, (self.lattice.bottom(),) + args) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| from typing import cast | ||
|
|
||
| from kirin import ir, interp | ||
|
|
||
| from bloqade.squin import op | ||
|
|
||
| from .lattice import ( | ||
| NoSites, | ||
| NumberSites, | ||
| ) | ||
| from .analysis import NSitesAnalysis | ||
|
|
||
|
|
||
| @op.dialect.register(key="op.nsites") | ||
| class SquinOp(interp.MethodTable): | ||
|
|
||
| @interp.impl(op.stmts.Kron) | ||
| def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron): | ||
| lhs = frame.get(stmt.lhs) | ||
| rhs = frame.get(stmt.rhs) | ||
| if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites): | ||
| new_n_sites = lhs.sites + rhs.sites | ||
| return (NumberSites(sites=new_n_sites),) | ||
| else: | ||
| return (NoSites(),) | ||
|
|
||
| @interp.impl(op.stmts.Mult) | ||
| def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult): | ||
| lhs = frame.get(stmt.lhs) | ||
| rhs = frame.get(stmt.rhs) | ||
|
|
||
| if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites): | ||
| lhs_sites = lhs.sites | ||
| rhs_sites = rhs.sites | ||
| # I originally considered throwing an exception here | ||
| # but Xiu-zhe (Roger) Luo has pointed out it would be | ||
| # a much better UX to add a type element that | ||
| # could explicitly indicate the error. The downside | ||
| # is you'll have some added complexity in the type lattice. | ||
| if lhs_sites != rhs_sites: | ||
| return (NoSites(),) | ||
| else: | ||
| return (NumberSites(sites=lhs_sites + rhs_sites),) | ||
| else: | ||
| return (NoSites(),) | ||
|
|
||
| @interp.impl(op.stmts.Control) | ||
| def control( | ||
| self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Control | ||
| ): | ||
| op_sites = frame.get(stmt.op) | ||
|
|
||
| if isinstance(op_sites, NumberSites): | ||
| n_sites = op_sites.sites | ||
| n_controls_attr = stmt.get_attr_or_prop("n_controls") | ||
| n_controls = cast(ir.PyAttr[int], n_controls_attr).data | ||
| return (NumberSites(sites=n_sites + n_controls),) | ||
| else: | ||
| return (NoSites(),) | ||
|
|
||
| @interp.impl(op.stmts.Rot) | ||
| def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot): | ||
| op_sites = frame.get(stmt.axis) | ||
| return (op_sites,) | ||
|
|
||
| @interp.impl(op.stmts.Scale) | ||
| def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale): | ||
| op_sites = frame.get(stmt.op) | ||
| return (op_sites,) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| from typing import final | ||
| from dataclasses import dataclass | ||
|
|
||
| from kirin.lattice import ( | ||
| SingletonMeta, | ||
| BoundedLattice, | ||
| SimpleJoinMixin, | ||
| SimpleMeetMixin, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class Sites( | ||
| SimpleJoinMixin["Sites"], SimpleMeetMixin["Sites"], BoundedLattice["Sites"] | ||
| ): | ||
| @classmethod | ||
| def bottom(cls) -> "Sites": | ||
| return NoSites() | ||
|
|
||
| @classmethod | ||
| def top(cls) -> "Sites": | ||
| return AnySites() | ||
|
|
||
|
|
||
| @final | ||
| @dataclass | ||
| class NoSites(Sites, metaclass=SingletonMeta): | ||
|
|
||
| def is_subseteq(self, other: Sites) -> bool: | ||
| return True | ||
|
|
||
|
|
||
| @final | ||
| @dataclass | ||
| class AnySites(Sites, metaclass=SingletonMeta): | ||
|
|
||
| def is_subseteq(self, other: Sites) -> bool: | ||
| return isinstance(other, Sites) | ||
|
|
||
|
|
||
| @final | ||
| @dataclass | ||
| class NumberSites(Sites): | ||
| sites: int | ||
|
|
||
| def is_subseteq(self, other: Sites) -> bool: | ||
| if isinstance(other, NumberSites): | ||
| return self.sites == other.sites | ||
| return False | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you wanna delete the playground?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes! Thank you for catching that. I should probably do a local ignore for any
_playground.pyfiles moving forwardThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I usually use a
main.pyas my playground it's in the.gitignorealready IIRC