-
Notifications
You must be signed in to change notification settings - Fork 1
TerminalLogicalMeasurement Validation #641
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 8 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
0c3c8f1
initial TerminalMeasurement check
johnzl-777 6dcf451
add invoke support
johnzl-777 46e3db7
Merge branch 'main' into 610-terminalmeasure-validation
johnzl-777 fa61fb4
fix incorrect keys type
johnzl-777 6aed711
use a boolean instead of counting, also have nicer implementation for…
johnzl-777 e58ca58
be more careful around getting measurement results
johnzl-777 d14fa56
just inherit from address analysis method tabl
johnzl-777 d65b7a2
pass on address analysis interpreter info instead of sifting through …
johnzl-777 699730f
Merge branch 'main' into 610-terminalmeasure-validation
weinbe58 9dfacf0
get rid of unnecessary assertion on measure lattice element length
johnzl-777 8fc4860
add validation error if measurement id analysis fails
johnzl-777 61e1438
add validation error if measurement id analysis fails
johnzl-777 76a73cd
add terminal measurement validation to verify
johnzl-777 dc6ffdd
get tests to work with added terminal measurement validation in kerne…
johnzl-777 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 4 additions & 0 deletions
4
src/bloqade/gemini/analysis/measurement_validation/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from . import impls as impls, analysis as analysis # NOTE: register methods | ||
| from .analysis import ( | ||
| GeminiTerminalMeasurementValidation as GeminiTerminalMeasurementValidation, | ||
| ) |
66 changes: 66 additions & 0 deletions
66
src/bloqade/gemini/analysis/measurement_validation/analysis.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| from typing import Any | ||
| from dataclasses import field, dataclass | ||
|
|
||
| from kirin import ir | ||
| from kirin.lattice import EmptyLattice | ||
| from kirin.analysis import Forward, ForwardFrame | ||
| from kirin.validation import ValidationPass | ||
|
|
||
| from bloqade.analysis import address, measure_id | ||
|
|
||
|
|
||
| @dataclass | ||
| class _GeminiTerminalMeasurementValidationAnalysis(Forward[EmptyLattice]): | ||
| keys = ("gemini.validate.terminal_measurement",) | ||
|
|
||
| measurement_analysis_results: ForwardFrame | ||
| unique_qubits_allocated: int = 0 | ||
| terminal_measurement_encountered: bool = False | ||
| lattice = EmptyLattice | ||
|
|
||
| # boilerplate, not really worried about these right now | ||
| def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): | ||
| return tuple(self.lattice.bottom() for _ in range(len(node.results))) | ||
|
|
||
| def method_self(self, method: ir.Method) -> EmptyLattice: | ||
| return self.lattice.bottom() | ||
|
|
||
|
|
||
| @dataclass | ||
| class GeminiTerminalMeasurementValidation(ValidationPass): | ||
|
|
||
| analysis_cache: dict = field(default_factory=dict) | ||
|
|
||
| def name(self) -> str: | ||
| return "Gemini Terminal Measurement Validation" | ||
|
|
||
| def get_required_analyses(self) -> list[type]: | ||
| return [measure_id.MeasurementIDAnalysis] | ||
|
|
||
| def set_analysis_cache(self, cache: dict[type, Any]) -> None: | ||
| self.analysis_cache.update(cache) | ||
|
|
||
| def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]: | ||
|
|
||
| # get the data out of the cache and forward it to the underlying analysis | ||
| measurement_analysis_results = self.analysis_cache.get( | ||
| measure_id.MeasurementIDAnalysis | ||
| ) | ||
|
|
||
| address_analysis = address.AddressAnalysis(dialects=method.dialects) | ||
| address_analysis.run(method) | ||
| unique_qubits_allocated = address_analysis.qubit_count | ||
|
|
||
| assert ( | ||
| measurement_analysis_results is not None | ||
| ), "Measurement ID analysis results not found in cache" | ||
|
|
||
| analysis = _GeminiTerminalMeasurementValidationAnalysis( | ||
| method.dialects, | ||
| measurement_analysis_results, | ||
| unique_qubits_allocated=unique_qubits_allocated, | ||
| ) | ||
|
|
||
| frame, _ = analysis.run(method) | ||
|
|
||
| return frame, analysis.get_validation_errors() |
88 changes: 88 additions & 0 deletions
88
src/bloqade/gemini/analysis/measurement_validation/impls.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| from kirin import ir, interp as _interp | ||
| from kirin.analysis import ForwardFrame | ||
| from kirin.dialects import func | ||
|
|
||
| from bloqade import qubit, gemini | ||
| from bloqade.analysis.address.impls import Func as AddressFuncMethodTable | ||
| from bloqade.analysis.measure_id.lattice import MeasureIdTuple | ||
|
|
||
| from .analysis import _GeminiTerminalMeasurementValidationAnalysis | ||
|
|
||
|
|
||
| @qubit.dialect.register(key="gemini.validate.terminal_measurement") | ||
| class __QubitGeminiMeasurementValidation(_interp.MethodTable): | ||
|
|
||
| # This is a non-logical measurement, can safely flag as invalid | ||
| @_interp.impl(qubit.stmts.Measure) | ||
| def measure( | ||
| self, | ||
| interp: _GeminiTerminalMeasurementValidationAnalysis, | ||
| frame: ForwardFrame, | ||
| stmt: qubit.stmts.Measure, | ||
| ): | ||
|
|
||
| interp.add_validation_error( | ||
| stmt, | ||
| ir.ValidationError( | ||
| stmt, | ||
| "Non-terminal measurements are not allowed in Gemini programs!", | ||
| ), | ||
| ) | ||
|
|
||
| return (interp.lattice.bottom(),) | ||
|
|
||
|
|
||
| @gemini.logical.dialect.register(key="gemini.validate.terminal_measurement") | ||
| class __GeminiLogicalMeasurementValidation(_interp.MethodTable): | ||
|
|
||
| # This is a logical terminal measurement, which is allowed | ||
| # but we impose the following restrictions: | ||
| # 1. All qubits spawned MUST be consumed | ||
| @_interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement) | ||
| def terminal_measure( | ||
| self, | ||
| interp: _GeminiTerminalMeasurementValidationAnalysis, | ||
| frame: ForwardFrame, | ||
| stmt: gemini.logical.stmts.TerminalLogicalMeasurement, | ||
| ): | ||
|
|
||
| # should only be one terminal measurement EVER | ||
| if not interp.terminal_measurement_encountered: | ||
| interp.terminal_measurement_encountered = True | ||
| else: | ||
| interp.add_validation_error( | ||
| stmt, | ||
| ir.ValidationError( | ||
| stmt, | ||
| "Multiple terminal measurements are not allowed in Gemini logical programs!", | ||
| ), | ||
| ) | ||
| return (interp.lattice.bottom(),) | ||
|
|
||
| measurement_analysis_results = interp.measurement_analysis_results | ||
| total_qubits_allocated = interp.unique_qubits_allocated | ||
|
|
||
| # could make these proper exceptions but would be tricky to communicate to user | ||
| # without revealing under-the-hood details | ||
| measure_lattice_element = measurement_analysis_results.get_values(stmt.results) | ||
| assert len(measure_lattice_element) == 1 | ||
| measure_lattice_element = measure_lattice_element[0] | ||
johnzl-777 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert isinstance(measure_lattice_element, MeasureIdTuple) | ||
johnzl-777 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if len(measure_lattice_element.data) != total_qubits_allocated: | ||
| interp.add_validation_error( | ||
| stmt, | ||
| ir.ValidationError( | ||
| stmt, | ||
| "The number of qubits in the terminal measurement does not match the number of total qubits allocated! " | ||
| + f"{total_qubits_allocated} qubits were allocated but only {len(measure_lattice_element.data)} were measured.", | ||
| ), | ||
| ) | ||
| return (interp.lattice.bottom(),) | ||
|
|
||
| return (interp.lattice.bottom(),) | ||
|
|
||
|
|
||
| @func.dialect.register(key="gemini.validate.terminal_measurement") | ||
| class Func(AddressFuncMethodTable): | ||
| pass | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To fix the circular import put this validation inside
logicalfolder since this is only a logical program validation.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By extension that would also mean that the current
logical_validationshould also go in there as well? Just as a heads up I'll be splitting up the rules inlogical_validationright after this PR (so the scf, gate, and func impls are separate validations you can compose)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gave it a shot, didn't play nice ): Willing to revisit in a subsequent PR though, that will involve shuffling more things around anyways