Skip to content

Commit 636083b

Browse files
committed
Update kernel validation
1 parent 9ff1d99 commit 636083b

File tree

4 files changed

+30
-18
lines changed

4 files changed

+30
-18
lines changed

src/bloqade/gemini/analysis/logical_validation/analysis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ class GeminiLogicalValidationAnalysis(ValidationAnalysis):
99

1010
first_gate = True
1111

12-
def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
13-
if isinstance(stmt, squin.gate.stmts.Gate):
12+
def eval_fallback(self, frame: ValidationFrame, node: ir.Statement):
13+
if isinstance(node, squin.gate.stmts.Gate):
1414
# NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
1515
self.first_gate = False
1616

17-
return super().eval_stmt_fallback(frame, stmt)
17+
return super().eval_fallback(frame, node)

src/bloqade/validation/analysis/analysis.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ class ValidationAnalysis(ForwardExtra[ValidationFrame, ErrorType], ABC):
2828

2929
lattice = ErrorType
3030

31-
def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]):
32-
return self.run_callable(method.code, (self.lattice.top(),) + args)
33-
34-
def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
31+
def eval_fallback(self, frame: ValidationFrame, node: ir.Statement):
3532
# NOTE: default to no errors
36-
return tuple(self.lattice.top() for _ in stmt.results)
33+
return tuple(self.lattice.top() for _ in node.results)
3734

3835
def initialize_frame(
39-
self, code: ir.Statement, *, has_parent_access: bool = False
36+
self, node: ir.Statement, *, has_parent_access: bool = False
4037
) -> ValidationFrame:
41-
return ValidationFrame(code, has_parent_access=has_parent_access)
38+
return ValidationFrame(node, has_parent_access=has_parent_access)
39+
40+
def method_self(self, method: ir.Method) -> ErrorType:
41+
return self.lattice.top()

src/bloqade/validation/kernel_validation.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,23 @@ class KernelValidation:
4848
validation_analysis_cls: type[ValidationAnalysis]
4949
"""The analysis that you want to run in order to validate the kernel."""
5050

51-
def run(self, mt: ir.Method, **kwargs) -> None:
51+
def run(self, mt: ir.Method, no_raise: bool = True) -> None:
52+
"""Run the kernel validation analysis and raise any errors found.
53+
54+
Args:
55+
mt (ir.Method): The method to validate
56+
no_raise (bool): Whether or not to raise errors when running the analysis.
57+
This is only to make sure the analysis works. Errors found during
58+
the analysis will be raised regardless of this setting. Defaults to `True`.
59+
60+
"""
61+
5262
validation_analysis = self.validation_analysis_cls(mt.dialects)
53-
validation_frame, _ = validation_analysis.run_analysis(mt, **kwargs)
63+
64+
if no_raise:
65+
validation_frame, _ = validation_analysis.run_no_raise(mt)
66+
else:
67+
validation_frame, _ = validation_analysis.run(mt)
5468

5569
errors = validation_frame.errors
5670

test/gemini/test_logical_validation.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,14 @@ def main():
3030
if m2:
3131
squin.y(q[2])
3232

33-
frame, _ = GeminiLogicalValidationAnalysis(main.dialects).run_analysis(
34-
main, no_raise=False
35-
)
33+
frame, _ = GeminiLogicalValidationAnalysis(main.dialects).run_no_raise(main)
3634

3735
main.print(analysis=frame.entries)
3836

3937
validator = KernelValidation(GeminiLogicalValidationAnalysis)
4038

4139
with pytest.raises(ValidationErrorGroup):
42-
validator.run(main)
40+
validator.run(main, no_raise=False)
4341

4442

4543
def test_for_loop():
@@ -104,8 +102,8 @@ def invalid():
104102
squin.cx(q[0], q[1])
105103
squin.u3(0.123, 0.253, 1.2, q[0])
106104

107-
frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_analysis(
108-
invalid, no_raise=False
105+
frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise(
106+
invalid
109107
)
110108

111109
invalid.print(analysis=frame.entries)

0 commit comments

Comments
 (0)