Skip to content

Commit c077269

Browse files
mudit2812mehrdad2m
andauthored
Add roundtrip testing capability to filecheck fixture (#8049)
### Before submitting Please complete the following checklist when submitting a PR: - [ ] All new features must include a unit test. If you've fixed a bug or added code that should be tested, add a test to the test directory! - [ ] All new functions and code must be clearly commented and documented. If you do make documentation changes, make sure that the docs build and render correctly by running `make docs`. - [ ] Ensure that the test suite passes, by running `make test`. - [ ] Add a new entry to the `doc/releases/changelog-dev.md` file, summarizing the change, and including a link back to the PR. - [ ] The PennyLane source code conforms to [PEP8 standards](https://www.python.org/dev/peps/pep-0008/). We check all of our code against [Pylint](https://www.pylint.org/). To lint modified files, simply `pip install pylint`, and then run `pylint pennylane/path/to/file.py`. When all the above are checked, delete everything above the dashed line and fill in the pull request template. ------------------------------------------------------------------------------------------------------------ **Context:** xDSL uses round-trip lit tests for testing its dialect definitions. This involves parsing an MLIR string to xDSL, printing in generic format, and then parsing the generic format string again into xDSL before running filecheck. We add the same ability here. **Description of the Change:** * Add a `roundtrip` boolean kwarg to `run_filecheck` to allow roundtrip testing. * Add a `verify` boolean kwarg to run `op.verify()` on the xDSL module being tested. **Benefits:** More robust lit tests. **Possible Drawbacks:** **Related GitHub Issues:** [sc-97491] --------- Co-authored-by: Mehrdad Malek <[email protected]>
1 parent 33debf0 commit c077269

File tree

6 files changed

+53
-37
lines changed

6 files changed

+53
-37
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,10 @@
523523

524524
<h3>Internal changes ⚙️</h3>
525525

526+
* Add capability for roundtrip testing and module verification to the Python compiler `run_filecheck` and
527+
`run_filecheck_qjit` fixtures.
528+
[(#8049)](https://github.com/PennyLaneAI/pennylane/pull/8049)
529+
526530
* Improve type hinting internally.
527531
[(#8086)](https://github.com/PennyLaneAI/pennylane/pull/8086)
528532

pennylane/compiler/python_compiler/jax_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def __init__(
7777

7878
extra_dialects = extra_dialects or ()
7979
for dialect in self.default_dialects + tuple(extra_dialects):
80-
self.ctx.load_dialect(dialect)
80+
if self.ctx.get_optional_dialect(dialect.name) is None:
81+
self.ctx.load_dialect(dialect)
8182

8283

8384
def _module_inline(func: JaxJittedFunction, *args, **kwargs) -> jModule:

tests/python_compiler/conftest.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from xdsl.context import Context
2929
from xdsl.dialects import test
3030
from xdsl.passes import PassPipeline
31+
from xdsl.printer import Printer
3132

3233
from pennylane.compiler.python_compiler import Compiler
3334
from pennylane.compiler.python_compiler.jax_utils import (
@@ -38,7 +39,7 @@
3839
deps_available = False
3940

4041

41-
def _run_filecheck_impl(program_str, pipeline=()):
42+
def _run_filecheck_impl(program_str, pipeline=(), verify=False, roundtrip=False):
4243
"""Run filecheck on an xDSL module, comparing it to a program string containing
4344
filecheck directives."""
4445
if not deps_available:
@@ -47,13 +48,27 @@ def _run_filecheck_impl(program_str, pipeline=()):
4748
ctx = Context(allow_unregistered=True)
4849
xdsl_module = QuantumParser(ctx, program_str, extra_dialects=(test.Test,)).parse_module()
4950

51+
if roundtrip:
52+
# Print generic format
53+
stream = StringIO()
54+
Printer(stream=stream, print_generic_format=True).print_op(xdsl_module)
55+
xdsl_module = QuantumParser(ctx, stream.getvalue()).parse_module()
56+
57+
if verify:
58+
xdsl_module.verify()
59+
5060
pipeline = PassPipeline(pipeline)
5161
pipeline.apply(ctx, xdsl_module)
5262

63+
if verify:
64+
xdsl_module.verify()
65+
66+
stream = StringIO()
67+
Printer(stream).print_op(xdsl_module)
5368
opts = parse_argv_options(["filecheck", __file__])
5469
matcher = Matcher(
5570
opts,
56-
FInput("no-name", str(xdsl_module)),
71+
FInput("no-name", stream.getvalue()),
5772
Parser(opts, StringIO(program_str), *pattern_for_opts(opts)),
5873
)
5974

@@ -63,7 +78,23 @@ def _run_filecheck_impl(program_str, pipeline=()):
6378

6479
@pytest.fixture(scope="function")
6580
def run_filecheck():
66-
"""Fixture to run filecheck on an xDSL module."""
81+
"""Fixture to run filecheck on an xDSL module.
82+
83+
This fixture uses FileCheck to verify the correctness of a parsed MLIR string. Testers
84+
can provide a pass pipeline to transform the IR, and verify correctness by including
85+
FileCheck directives as comments in the input program string.
86+
87+
Args:
88+
program_str (str): The MLIR string containing the input program and FileCheck directives
89+
pipeline (tuple[ModulePass]): A sequence containing all passes that should be applied
90+
before running FileCheck
91+
verify (bool): Whether or not to verify the IR after parsing and transforming.
92+
``False`` by default.
93+
roundtrip (bool): Whether or not to use round-trip testing. This is useful for dialect
94+
tests to verify that xDSL both parses and prints the IR correctly. If ``True``, we parse
95+
the program string into an xDSL module, print it in generic format, and then parse the
96+
generic program string back to an xDSL module. ``False`` by default.
97+
"""
6798
if not deps_available:
6899
pytest.skip("Cannot run lit tests without xDSL and filecheck.")
69100

@@ -90,7 +121,7 @@ def _get_filecheck_directives(qjit_fn):
90121
return "\n".join(filecheck_directives)
91122

92123

93-
def _run_filecheck_qjit_impl(qjit_fn):
124+
def _run_filecheck_qjit_impl(qjit_fn, verify=False):
94125
"""Run filecheck on a qjit-ed function, using FileCheck directives in its inline
95126
comments to assert correctness."""
96127
if not deps_available:
@@ -107,6 +138,9 @@ def _run_filecheck_qjit_impl(qjit_fn):
107138
)
108139
xdsl_module = parse_generic_to_xdsl_module(mod_str)
109140

141+
if verify:
142+
xdsl_module.verify()
143+
110144
opts = parse_argv_options(["filecheck", __file__])
111145
matcher = Matcher(
112146
opts,
@@ -127,6 +161,11 @@ def run_filecheck_qjit():
127161
output IR against FileCheck directives that may be present in the source
128162
function as inline comments.
129163
164+
Args:
165+
qjit_fn (Callable): The QJIT object on which we want to run lit tests
166+
verify (bool): Whether or not to verify the IR after parsing and transforming.
167+
``False`` by default.
168+
130169
An example showing how to use the fixture is shown below. We apply the
131170
``merge_rotations_pass`` and check that there is only one rotation in
132171
the final IR:

tests/python_compiler/dialects/test_catalyst_dialect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,4 @@ def test_assembly_format(run_filecheck):
105105
%callback_result = catalyst.callback_call @callback_func(%val) : (f64) -> f64
106106
"""
107107

108-
run_filecheck(program)
108+
run_filecheck(program, roundtrip=True)

tests/python_compiler/dialects/test_mbqc_dialect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_assembly_format(run_filecheck):
9898
%graph_reg = mbqc.graph_state_prep (%adj_matrix : tensor<6xi1>) [init "Hadamard", entangle "CZ"] : !quantum.reg
9999
"""
100100

101-
run_filecheck(program)
101+
run_filecheck(program, roundtrip=True)
102102

103103

104104
class TestMeasureInBasisOp:

tests/python_compiler/dialects/test_quantum_dialect.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
"""Unit test module for pennylane/compiler/python_compiler/dialects/quantum.py."""
1616

17-
import io
18-
1917
import pytest
2018

2119
# pylint: disable=wrong-import-position
@@ -110,7 +108,7 @@ def test_all_attributes_names(attr):
110108
assert attr.name == expected_name
111109

112110

113-
def test_assembly_format():
111+
def test_assembly_format(run_filecheck):
114112
program = """
115113
// CHECK: quantum.alloc(1) : !quantum.reg
116114
%qreg_alloc_static = quantum.alloc(1) : !quantum.reg
@@ -172,30 +170,4 @@ def test_assembly_format():
172170
%mres2, %out_qubit2 = quantum.measure %qubit postselect 0 : i1, !quantum.bit
173171
"""
174172

175-
ctx = xdsl.context.Context()
176-
from xdsl.dialects import builtin, func, test
177-
178-
ctx.load_dialect(builtin.Builtin)
179-
ctx.load_dialect(func.Func)
180-
ctx.load_dialect(test.Test)
181-
ctx.load_dialect(Quantum)
182-
183-
module = xdsl.parser.Parser(ctx, program).parse_module()
184-
185-
from filecheck.finput import FInput
186-
from filecheck.matcher import Matcher
187-
from filecheck.options import parse_argv_options
188-
from filecheck.parser import Parser, pattern_for_opts
189-
190-
opts = parse_argv_options(["filecheck", __file__])
191-
matcher = Matcher(
192-
opts,
193-
FInput("no-name", str(module)),
194-
Parser(opts, io.StringIO(program), *pattern_for_opts(opts)),
195-
)
196-
197-
assert matcher.run() == 0
198-
199-
200-
if __name__ == "__main__":
201-
pytest.main(["-x", __file__])
173+
run_filecheck(program, roundtrip=True)

0 commit comments

Comments
 (0)