diff --git a/doc/catalyst-cli/catalyst-cli.rst b/doc/catalyst-cli/catalyst-cli.rst index e3ac99f1b0..dcc417788a 100644 --- a/doc/catalyst-cli/catalyst-cli.rst +++ b/doc/catalyst-cli/catalyst-cli.rst @@ -98,6 +98,12 @@ intermediate files are saved. Keep intermediate files after each pipeline in the compilation. By default, no intermediate files are saved. Using ``--keep-intermediate`` is equivalent to using ``--save-ir-after-each=pipeline``. +``--use-nameloc-as-prefix[=]`` +"""""""""""""""""""""""""""""""""""""""""" + +Print SSA IDs using their name location, if provided, as prefix. By default, name location information is not used. +Name location, or named source location, is a type of source location information that allows attaching a name to a child location. + ``--{passname}`` """"""""""""""" diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index a3a20f92d1..09bb92fe18 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -306,7 +306,7 @@ def _quantum_opt(*args, stdin=None): return _catalyst(("--tool", "opt"), *args, stdin=stdin) -def canonicalize(*args, stdin=None): +def canonicalize(*args, stdin=None, options: Optional[CompileOptions] = None): """Run opt with canonicalization echo ${stdin} | catalyst --tool=opt \ @@ -316,7 +316,11 @@ def canonicalize(*args, stdin=None): Returns stdout string """ - return _quantum_opt(("--pass-pipeline", "builtin.module(canonicalize)"), *args, stdin=stdin) + opts = ["--pass-pipeline", "builtin.module(canonicalize)"] + if options and options.use_nameloc: + opts.append("--use-nameloc-as-prefix") + + return _quantum_opt(*opts, *args, stdin=stdin) def _options_to_cli_flags(options): @@ -349,6 +353,9 @@ def _options_to_cli_flags(options): extra_args += ["--save-ir-after-each=pass"] extra_args += ["--dump-module-scope"] + if options.use_nameloc: + extra_args += ["--use-nameloc-as-prefix"] + if options.verbose: extra_args += ["--verbose"] diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index 7dc1382593..cd9a84a121 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -53,12 +53,13 @@ @debug_logger -def jaxpr_to_mlir(func_name, jaxpr): +def jaxpr_to_mlir(jaxpr, func_name, arg_names): """Lower a Jaxpr into an MLIR module. Args: - func_name(str): function name jaxpr(Jaxpr): Jaxpr code to lower + func_name(str): function name + arg_names(list[str]): list of argument names Returns: module: the MLIR module corresponding to ``func`` @@ -81,6 +82,7 @@ def jaxpr_to_mlir(func_name, jaxpr): platform="cpu", axis_context=axis_context, name_stack=name_stack, + arg_names=arg_names, ) return module, context @@ -97,6 +99,7 @@ def custom_lower_jaxpr_to_module( axis_context: AxisContext, name_stack, replicated_args=None, + arg_names=None, arg_shardings=None, result_shardings=None, ): @@ -149,6 +152,7 @@ def custom_lower_jaxpr_to_module( effects, public=True, replicated_args=replicated_args, + arg_names=arg_names, arg_shardings=arg_shardings, result_shardings=result_shardings, name_stack=name_stack, diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 790bfab192..e90ae9d0fc 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -656,12 +656,13 @@ def trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs, debug_in @debug_logger -def lower_jaxpr_to_mlir(jaxpr, func_name): +def lower_jaxpr_to_mlir(jaxpr, func_name, arg_names): """Lower a JAXPR to MLIR. Args: ClosedJaxpr: the JAXPR to lower to MLIR func_name: a name to use for the MLIR function + arg_names: list of parameter names for the MLIR function Returns: ir.Module: the MLIR module coontaining the JAX program @@ -671,7 +672,7 @@ def lower_jaxpr_to_mlir(jaxpr, func_name): MemrefCallable.clearcache() with transient_jax_config({"jax_dynamic_shapes": True}): - mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr) + mlir_module, ctx = jaxpr_to_mlir(jaxpr, func_name, arg_names) return mlir_module, ctx diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 9c19a68a5d..dea842f16d 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -44,6 +44,7 @@ from catalyst.tracing.type_signatures import ( filter_static_args, get_abstract_signature, + get_arg_names, get_type_annotations, merge_static_argname_into_argnum, merge_static_args, @@ -79,6 +80,7 @@ def qjit( async_qnodes=False, target="binary", keep_intermediate=False, + use_nameloc=False, verbose=False, logfile=None, pipelines=None, @@ -121,6 +123,8 @@ def qjit( - :attr:`~.QJIT.mlir`: MLIR representation after canonicalization - :attr:`~.QJIT.mlir_opt`: MLIR representation after optimization - :attr:`~.QJIT.qir`: QIR in LLVM IR form + use_nameloc (bool): If ``True``, function parameter names are added to the IR as name + locations. verbose (bool): If ``True``, the tools and flags used by Catalyst behind the scenes are printed out. logfile (Optional[TextIOWrapper]): File object to write verbose messages to (default - @@ -517,7 +521,6 @@ class QJIT(CatalystCallable): :ivar jaxpr: This attribute stores the Jaxpr compiled from the function as a string. :ivar mlir: This attribute stores the MLIR compiled from the function as a string. :ivar qir: This attribute stores the QIR in LLVM IR form compiled from the function as a string. - """ @debug_logger_init @@ -562,20 +565,26 @@ def __init__(self, fn, compile_options): @property def mlir(self): - """obtain the MLIR representation after canonicalization""" + """Obtain the MLIR representation after canonicalization""" # Canonicalize the MLIR since there can be a lot of redundancy coming from JAX. if not self.mlir_module: return None - return canonicalize(stdin=str(self.mlir_module)) + stdin = self.mlir_module.operation.get_asm( + enable_debug_info=self.compile_options.use_nameloc + ) + return canonicalize(stdin=stdin, options=self.compile_options) @property def mlir_opt(self): - """obtain the MLIR representation after optimization""" + """Obtain the MLIR representation after optimization""" if not self.mlir_module: return None - return to_mlir_opt(stdin=str(self.mlir_module), options=self.compile_options) + stdin = self.mlir_module.operation.get_asm( + enable_debug_info=self.compile_options.use_nameloc + ) + return to_mlir_opt(stdin=stdin, options=self.compile_options) @debug_logger def __call__(self, *args, **kwargs): @@ -604,7 +613,6 @@ def __call__(self, *args, **kwargs): @debug_logger def aot_compile(self): """Compile Python function on initialization using the type hint signature.""" - self.workspace = self._get_workspace() # TODO: awkward, refactor or redesign the target feature @@ -643,7 +651,6 @@ def jit_compile(self, args, **kwargs): bool: whether the provided arguments will require promotion to be used with the compiled function """ - cached_fn, requires_promotion = self.fn_cache.lookup(args) if cached_fn is None: @@ -774,8 +781,9 @@ def generate_ir(self): Returns: Tuple[ir.Module, str]: the in-memory MLIR module and its string representation """ - - mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__) + mlir_module, ctx = lower_jaxpr_to_mlir( + self.jaxpr, self.__name__, get_arg_names(self.jaxpr.in_avals, self.original_function) + ) # Inject Runtime Library-specific functions (e.g. setup/teardown). inject_functions(mlir_module, ctx, self.compile_options.seed) @@ -790,7 +798,6 @@ def compile(self): Returns: Tuple[CompiledFunction, str]: the compilation result and LLVMIR """ - # WARNING: assumption is that the first function is the entry point to the compiled program. entry_point_func = self.mlir_module.body.operations[0] restype = entry_point_func.type.results @@ -833,7 +840,6 @@ def run(self, args, kwargs): Returns: Any: results of the execution arranged into the original function's output PyTrees """ - results = self.compiled_function(*args, **kwargs) # TODO: Move this to the compiled function object. @@ -853,7 +859,6 @@ def _validate_configuration(self): def _get_workspace(self): """Get or create a workspace to use for compilation.""" - workspace_name = self.__name__ preferred_workspace_dir = os.getcwd() if self.use_cwd_for_workspace else None diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 328a4d6623..d35517b4c2 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -84,6 +84,8 @@ class CompileOptions: - ``False`` or ``0`` or ``"none"`` (default): No intermediate files are kept. - ``True`` or ``1`` or ``"pipeline"``: Intermediate files are saved after each pipeline. - ``2`` or ``"pass"``: Intermediate files are saved after each pass. + use_nameloc (Optional[bool]): If ``True``, add function parameter names to the IR as name + locations. pipelines (Optional[List[Tuple[str,List[str]]]]): A list of tuples. The first entry of the tuple corresponds to the name of a pipeline. The second entry of the tuple corresponds to a list of MLIR passes. @@ -115,6 +117,7 @@ class CompileOptions: logfile: Optional[TextIOWrapper] = sys.stderr target: Optional[str] = "binary" keep_intermediate: Optional[Union[str, int, bool, KeepIntermediateLevel]] = False + use_nameloc: Optional[bool] = False pipelines: Optional[List[Any]] = None autograph: Optional[bool] = False autograph_include: Optional[Iterable[str]] = () diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index b4090001c4..d0ef9903a9 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -22,7 +22,7 @@ from typing import Callable import jax -from jax._src.core import shaped_abstractify +from jax._src.core import DShapedArray, shaped_abstractify from jax._src.interpreters.partial_eval import infer_lambda_input_type from jax._src.pjit import _flat_axes_specs from jax.core import AbstractValue @@ -324,3 +324,28 @@ def promote_arguments(target_signature, args): promoted_args.append(promoted_arg) return tree_unflatten(treedef, promoted_args) + + +def get_arg_names(qjit_jaxpr_in_avals: tuple[AbstractValue, ...], qjit_original_function: Callable): + """Construct a list of argument names, with the size of qjit_jaxpr_in_avals, and fill it with + the names of the parameters of the original function signature. + The number of parameters of the original function could be different to the number of + elements in qjit_jaxpr_in_avals. For example, if a function with one parameter is invoked with a + dynamic argument, qjit_jaxpr_in_avals will contain two elements (a dynamically-shaped array, and + its type). + + Args: + qjit_jaxpr_in_avals: list of abstract values that represent the inputs to the QJIT's JAXPR + qjit_original_function: QJIT's original function + + Returns: + A list of argument names with the same number of elements than qjit_jaxpr_in_avals. + The argument names are assigned from the list of parameters of the original function, + in order, and until that list is empty. Then left to empty strings. + """ + arg_names = [""] * len(qjit_jaxpr_in_avals) + param_values = [p.name for p in inspect.signature(qjit_original_function).parameters.values()] + for in_aval_index, in_aval in enumerate(qjit_jaxpr_in_avals): + if len(param_values) > 0 and type(in_aval) != DShapedArray: + arg_names[in_aval_index] = param_values.pop(0) + return arg_names diff --git a/frontend/test/lit/test_option_use_nameloc.py b/frontend/test/lit/test_option_use_nameloc.py new file mode 100644 index 0000000000..e64f0c425d --- /dev/null +++ b/frontend/test/lit/test_option_use_nameloc.py @@ -0,0 +1,54 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for use name location option.""" + +# RUN: %PYTHON %s | FileCheck %s + +from utils import print_mlir, print_mlir_opt + +from catalyst import qjit + + +# CHECK-LABEL: @jit_f +@qjit(use_nameloc=True) +def f(x: float, y: float): + """Check that MLIR module contains name location information, and MLIR code uses that name + location information. + """ + # CHECK: %x: tensor, %y: tensor + return x * y + + +assert str(f.mlir_module.body.operations[0].arguments[0].location) == 'loc("x")' +assert str(f.mlir_module.body.operations[0].arguments[1].location) == 'loc("y")' + +print_mlir(f, 0.3, 0.4) + + +# CHECK-LABEL: @jit_f_opt +@qjit(use_nameloc=True) +def f_opt(x: float, y: float): + """Check that MLIR module contains name location information, and MLIR code uses that name + location information. + Same test as before, but now we exercise mlir_opt property. + """ + # CHECK: %x: tensor, %y: tensor + return x * y + + +assert str(f_opt.mlir_module.body.operations[0].arguments[0].location) == 'loc("x")' +assert str(f_opt.mlir_module.body.operations[0].arguments[1].location) == 'loc("y")' + +print_mlir_opt(f_opt, 0.3, 0.4) diff --git a/frontend/test/lit/utils.py b/frontend/test/lit/utils.py index 59af0f2494..778c3c3d66 100644 --- a/frontend/test/lit/utils.py +++ b/frontend/test/lit/utils.py @@ -37,3 +37,8 @@ def print_jaxpr(f, *args, **kwargs): def print_mlir(f, *args, **kwargs): """Print mlir code of a function""" return print_attr(f, "mlir", *args, **kwargs) + + +def print_mlir_opt(f, *args, **kwargs): + """Print mlir code of a function""" + return print_attr(f, "mlir_opt", *args, **kwargs) diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py index 93376c4e05..805b18c522 100644 --- a/frontend/test/pytest/test_debug.py +++ b/frontend/test/pytest/test_debug.py @@ -554,6 +554,13 @@ def test_option_dialect_plugin_tuple(self): assert ("--load-dialect-plugin", path) in flags assert isinstance(options.dialect_plugins, set) + def test_option_use_nameloc(self): + """Test use name location option""" + + options = CompileOptions(use_nameloc=True) + flags = _options_to_cli_flags(options) + assert "--use-nameloc-as-prefix" in flags + def test_option_not_lower_to_llvm(self): """Test not lower to llvm""" options = CompileOptions(lower_to_llvm=False) diff --git a/frontend/test/pytest/test_get_arg_names.py b/frontend/test/pytest/test_get_arg_names.py new file mode 100644 index 0000000000..9a78f68e5d --- /dev/null +++ b/frontend/test/pytest/test_get_arg_names.py @@ -0,0 +1,86 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for get_arg_names function.""" + +import jax.numpy as jnp +import pennylane as qml +from jax.core import ShapedArray + +from catalyst import qjit + + +@qjit +def f_of_empty(): + """Check empty list of arguments""" + return True + + +f_of_empty.jit_compile([]) +assert f_of_empty.get_arg_names() == [] + + +@qjit +def f_of_a_b(a: float, b: float): + """Check two float arguments""" + return a * b + + +f_of_a_b.jit_compile([0.3, 0.4]) +assert f_of_a_b.get_arg_names() == ["a", "b"] + + +@qjit(abstracted_axes={0: "n"}) +def f_of_dynamic_argument(a): + """Check dynamic argument""" + return a + + +f_of_dynamic_argument.jit_compile([jnp.array([1, 2, 3])]) +assert f_of_dynamic_argument.get_arg_names() == ["a", ""] + + +@qjit(abstracted_axes={0: "n"}) +def f_of_qnode_with_dynamic_argument(a): + """Check QNode argument with dynamic argument""" + + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def _circuit(b): + return b + + return _circuit(a) + + +f_of_qnode_with_dynamic_argument.jit_compile([jnp.array([1, 2, 3])]) +assert f_of_dynamic_argument.get_arg_names() == ["a", ""] + + +@qjit +def f_of_a_with_dynamic_result(a): + """Check dynamic result""" + return jnp.ones((a + 1,), dtype=float) + + +f_of_a_with_dynamic_result.jit_compile([3]) +assert f_of_a_with_dynamic_result.get_arg_names() == ["a"] + + +@qjit(abstracted_axes={0: "n", 2: "m"}) +def f_of_shaped_array(a: ShapedArray([1, 3, 1], dtype=float)): + """Check ShapedArray argument""" + return a + + +f_of_shaped_array.aot_compile() +assert f_of_shaped_array.get_arg_names() == ["a", "", ""] diff --git a/frontend/test/pytest/test_tracing.py b/frontend/test/pytest/test_tracing.py index 9b7b84a38f..17fee8d504 100644 --- a/frontend/test/pytest/test_tracing.py +++ b/frontend/test/pytest/test_tracing.py @@ -27,7 +27,7 @@ def f(): return 0 jaxpr = jax.make_jaxpr(f)() - result, _ = lower_jaxpr_to_mlir(jaxpr, "test_fn") + result, _ = lower_jaxpr_to_mlir(jaxpr, "test_fn", []) assert "@jit_test_fn() -> tensor" in str(result) diff --git a/mlir/include/Driver/CompilerDriver.h b/mlir/include/Driver/CompilerDriver.h index f81b996cfd..407e6c9922 100644 --- a/mlir/include/Driver/CompilerDriver.h +++ b/mlir/include/Driver/CompilerDriver.h @@ -63,6 +63,8 @@ struct CompilerOptions { SaveTemps keepIntermediate; /// If true, the compiler will dump the module scope when saving intermediate files. bool dumpModuleScope; + /// Print SSA IDs using their name location, if provided, as prefix. + bool useNameLocAsPrefix; /// If true, the llvm.coroutine will be lowered. bool asyncQnodes; /// Sets the verbosity level to use when printing messages. @@ -114,7 +116,8 @@ mlir::LogicalResult QuantumDriverMain(const catalyst::driver::CompilerOptions &o int QuantumDriverMainFromCL(int argc, char **argv); int QuantumDriverMainFromArgs(const std::string &source, const std::string &workspace, const std::string &moduleName, bool keepIntermediate, - bool asyncQNodes, bool verbose, bool lowerToLLVM, + bool useNamelocAsPrefix, bool asyncQNodes, bool verbose, + bool lowerToLLVM, const std::vector &passPipelines, const std::string &checkpointStage, catalyst::driver::CompilerOutput &output); diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index 27bf1f31d8..6c0e6d381b 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -661,6 +661,11 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & { using timer = catalyst::utils::Timer; + OpPrintingFlags opPrintingFlags{}; + if (options.useNameLocAsPrefix) { + opPrintingFlags.printNameLocAsPrefix(); + } + MLIRContext ctx(registry); ctx.printOpOnDiagnostic(true); ctx.printStackTraceOnDiagnostic(options.verbosity >= Verbosity::Debug); @@ -738,7 +743,7 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & } output.outIR.clear(); if (options.keepIntermediate) { - outIRStream << *mlirModule; + mlirModule->print(outIRStream, opPrintingFlags); } optTiming.stop(); } @@ -855,7 +860,7 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & // already handled } else if (output.outputFilename == "-" && mlirModule) { - outfile->os() << *mlirModule; + mlirModule->print(outfile->os(), opPrintingFlags); outfile->keep(); } @@ -944,6 +949,9 @@ int QuantumDriverMainFromCL(int argc, char **argv) "keep-intermediate", cl::desc("Keep intermediate files"), cl::init(false), cl::callback([&](const bool &) { SaveAfterEach.setValue(SaveTemps::AfterPipeline); }), cl::cat(CatalystCat)); + cl::opt UseNameLocAsPrefix("use-nameloc-as-prefix", + cl::desc("Use name location as prefix"), cl::init(false), + cl::cat(CatalystCat)); cl::opt AsyncQNodes("async-qnodes", cl::desc("Enable asynchronous QNodes"), cl::init(false), cl::cat(CatalystCat)); cl::opt Verbose("verbose", cl::desc("Set verbose"), cl::init(false), @@ -1014,6 +1022,7 @@ int QuantumDriverMainFromCL(int argc, char **argv) .diagnosticStream = errStream, .keepIntermediate = SaveAfterEach, .dumpModuleScope = DumpModuleScope, + .useNameLocAsPrefix = UseNameLocAsPrefix, .asyncQnodes = AsyncQNodes, .verbosity = Verbose ? Verbosity::All : Verbosity::Urgent, .pipelinesCfg = parsePipelines(CatalystPipeline),