diff --git a/.dep-versions b/.dep-versions index d755906e6d..849b989020 100644 --- a/.dep-versions +++ b/.dep-versions @@ -2,9 +2,9 @@ # To update JAX version alongside compatible dependency tags, run the following script: # python3 .github/workflows/set_dep_versions.py {JAX_version} jax=0.6.2 -mhlo=617a9361d186199480c080c9e8c474a5e30c22d1 -llvm=179d30f8c3fddd3c85056fd2b8e877a4a8513158 -enzyme=v0.0.180 +mhlo=1dd2e71331014ae0373f6bf900ce6be393357190 +llvm=f8cb7987c64dcffb72414a40560055cb717dbf74 +enzyme=v0.0.186 # Always remove custom PL/LQ versions before release. diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml index df53cd5880..ef4b3d04ce 100644 --- a/.github/workflows/build-wheel-linux-arm64.yaml +++ b/.github/workflows/build-wheel-linux-arm64.yaml @@ -108,6 +108,12 @@ jobs: ref: ${{ needs.constants.outputs.llvm_version }} path: ${{ github.workspace }}/mlir/llvm-project + - name: Patch LLVM Source + if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + run: | + cd $GITHUB_WORKSPACE/mlir/llvm-project + git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch + - name: Clone MHLO Submodule if: steps.cache-mhlo-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 @@ -122,6 +128,7 @@ jobs: cd $GITHUB_WORKSPACE/mlir/mlir-hlo git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch + git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' @@ -134,9 +141,8 @@ jobs: - name: Patch Enzyme Source if: steps.cache-enzyme-source.outputs.cache-hit != 'true' run: | - export TARGET_FILE=$GITHUB_WORKSPACE/mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp - export PATCH_FILE=$GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch - patch -p1 $TARGET_FILE $PATCH_FILE + cd 
$GITHUB_WORKSPACE/mlir/Enzyme + git apply $GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch # Cache external project builds - name: Restore LLVM Build diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml index 35be761c99..bde5f9cbbc 100644 --- a/.github/workflows/build-wheel-linux-x86_64.yaml +++ b/.github/workflows/build-wheel-linux-x86_64.yaml @@ -127,6 +127,12 @@ jobs: ref: ${{ needs.constants.outputs.llvm_version }} path: ${{ github.workspace }}/mlir/llvm-project + - name: Patch LLVM Source + if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + run: | + cd $GITHUB_WORKSPACE/mlir/llvm-project + git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch + - name: Clone MHLO Submodule if: steps.cache-mhlo-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 @@ -141,6 +147,7 @@ jobs: cd $GITHUB_WORKSPACE/mlir/mlir-hlo git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch + git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' @@ -153,9 +160,8 @@ jobs: - name: Patch Enzyme Source if: steps.cache-enzyme-source.outputs.cache-hit != 'true' run: | - export TARGET_FILE=$GITHUB_WORKSPACE/mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp - export PATCH_FILE=$GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch - patch -p1 $TARGET_FILE $PATCH_FILE + cd $GITHUB_WORKSPACE/mlir/Enzyme + git apply $GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch # Cache external project builds - name: Restore LLVM Build diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml index d6d2e9d215..fdc101e3a8 100644 --- a/.github/workflows/build-wheel-macos-arm64.yaml +++ b/.github/workflows/build-wheel-macos-arm64.yaml @@ -113,6 
+113,12 @@ jobs: ref: ${{ needs.constants.outputs.llvm_version }} path: ${{ github.workspace }}/mlir/llvm-project + - name: Patch LLVM Source + if: steps.cache-mhlo-source.outputs.cache-hit != 'true' + run: | + cd $GITHUB_WORKSPACE/mlir/llvm-project + git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch + - name: Clone MHLO Submodule if: steps.cache-mhlo-source.outputs.cache-hit != 'true' uses: actions/checkout@v4 @@ -127,6 +133,7 @@ jobs: cd $GITHUB_WORKSPACE/mlir/mlir-hlo git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch + git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch - name: Clone Enzyme Submodule if: steps.cache-enzyme-source.outputs.cache-hit != 'true' @@ -139,9 +146,8 @@ jobs: - name: Patch Enzyme Source if: steps.cache-enzyme-source.outputs.cache-hit != 'true' run: | - export TARGET_FILE=$GITHUB_WORKSPACE/mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp - export PATCH_FILE=$GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch - patch -p1 $TARGET_FILE $PATCH_FILE + cd $GITHUB_WORKSPACE/mlir/Enzyme + git apply $GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch # Cache external project builds - name: Restore LLVM Build diff --git a/Makefile b/Makefile index e3309167f0..68d81ae519 100644 --- a/Makefile +++ b/Makefile @@ -282,6 +282,9 @@ clean-plugin: clean-llvm: $(MAKE) -C mlir clean-llvm +reset-llvm: + $(MAKE) -C mlir reset-llvm + clean-mhlo: $(MAKE) -C mlir clean-mhlo diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 32688e2dab..2c03b31a80 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -12,6 +12,16 @@ * The JAX version used by Catalyst is updated to 0.6.2. [(#1897)](https://github.com/PennyLaneAI/catalyst/pull/1897) +* The version of LLVM, mlir-hlo, and Enzyme used by Catalyst has been updated. 
+ [(#1916)](https://github.com/PennyLaneAI/catalyst/pull/1916) + + The LLVM version has been updated to + [commit f8cb798](https://github.com/llvm/llvm-project/tree/f8cb7987c64dcffb72414a40560055cb717dbf74). + The mlir-hlo version has been updated to + [commit 1dd2e71](https://github.com/tensorflow/mlir-hlo/tree/1dd2e71331014ae0373f6bf900ce6be393357190). + The Enzyme version has been updated to + [v0.0.186](https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.186). +

Deprecations 👋

Bug fixes 🐛

diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index d2bfbc74bb..7dc1382593 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -16,6 +16,7 @@ from __future__ import annotations import logging +import textwrap import jax from jax._src.dispatch import jaxpr_replicas @@ -38,6 +39,7 @@ import catalyst from catalyst.logging import debug_logger +from catalyst.utils.exceptions import CompileError from catalyst.utils.patching import Patcher # pylint: disable=protected-access @@ -165,3 +167,57 @@ def custom_lower_jaxpr_to_module( worklist += [*op.body.operations] return ctx.module, ctx.context + + +def get_mlir_attribute_from_pyval(value): + """ + Given a value of any type, construct an mlir attribute of corresponding type. + + We set up the context and location outside because recursive calls to this function + will segfault if multiple `Context()`s are instantiated. + """ + + attr = None + match value: + case bool(): + attr = ir.BoolAttr.get(value) + + case int(): + if -9223372036854775808 <= value < 0: # = -2**63 + attr = ir.IntegerAttr.get(ir.IntegerType.get_signed(64), value) + elif 0 <= value < 18446744073709551616: # = 2**64 + attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) + else: + raise CompileError( + textwrap.dedent( + """ + Large interger attributes currently not supported in MLIR, + see https://github.com/llvm/llvm-project/issues/128072 + """ + ) + ) + + case float(): + attr = ir.FloatAttr.get(ir.F64Type.get(), value) + + case str(): + attr = ir.StringAttr.get(value) + + case list() | tuple(): + element_attrs = [get_mlir_attribute_from_pyval(elem) for elem in value] + attr = ir.ArrayAttr.get(element_attrs) + + case dict(): + named_attrs = {} + for k, v in value.items(): + if not isinstance(k, str): + raise CompileError( + f"Dictionary keys for MLIR DictionaryAttr must be strings, got: {type(k)}" + ) + named_attrs[k] = 
get_mlir_attribute_from_pyval(v) + attr = ir.DictAttr.get(named_attrs) + + case _: + raise CompileError(f"Cannot convert Python type {type(value)} to an MLIR attribute.") + + return attr diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index cd2d27eb53..effa498b45 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -280,7 +280,11 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin for _pass in pipeline: options = _pass.get_options() apply_registered_pass_op = ApplyRegisteredPassOp( - result=transform_mod_type, target=target, pass_name=_pass.name, options=options + result=transform_mod_type, + target=target, + pass_name=_pass.name, + options=options, + dynamic_options={}, ) target = apply_registered_pass_op.result transform_yield_op = YieldOp(operands_=[]) # pylint: disable=unused-variable diff --git a/frontend/catalyst/passes/pass_api.py b/frontend/catalyst/passes/pass_api.py index 423f8c5269..b0ed5db380 100644 --- a/frontend/catalyst/passes/pass_api.py +++ b/frontend/catalyst/passes/pass_api.py @@ -19,6 +19,7 @@ import pennylane as qml +from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval from catalyst.tracing.contexts import EvaluationContext PipelineDict: TypeAlias = dict[str, dict[str, str]] @@ -286,23 +287,18 @@ def __init__(self, name: str, *options: list[str], **valued_options: dict[str, s def get_options(self): """ - Stringify options according to what mlir-opt expects. - - ApplyRegisteredPassOp expects options to be a single StringAttr - which follows the same format as the one used with mlir-opt. - - https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop - - Options passed to a pass are specified via the syntax {option1=value1 option2=value2 ...}, - i.e., use space-separated key=value pairs for each option. 
+ Build a dictionary mapping option names to MLIR attributes. + ApplyRegisteredPassOp expects options to be a dictionary from strings to attributes. + See https://github.com/llvm/llvm-project/pull/143159 + """ + options_dict = {} + for option in self.options: + options_dict[str(option)] = get_mlir_attribute_from_pyval(True) - https://mlir.llvm.org/docs/Tutorials/MlirOpt/#running-a-pass-with-options + for option, value in self.valued_options.items(): + options_dict[str(option)] = get_mlir_attribute_from_pyval(value) - Experimentally we found that single-options also work without values. - """ - retval = " ".join(f"{str(option)}" for option in self.options) - retval2 = " ".join(f"{str(key)}={str(value)}" for key, value in self.valued_options.items()) - return " ".join([retval, retval2]).strip() + return options_dict def __repr__(self): return ( diff --git a/frontend/test/lit/test_mlir_plugin.py b/frontend/test/lit/test_mlir_plugin.py index 0be8d5cbb1..09d519c7ef 100644 --- a/frontend/test/lit/test_mlir_plugin.py +++ b/frontend/test/lit/test_mlir_plugin.py @@ -106,7 +106,7 @@ def test_pass_options(): """Is the option in the generated MLIR?""" @qjit(target="mlir") - # CHECK: options = "an-option maxValue=1" + # CHECK: options = {"an-option" = true, "maxValue" = 1 : i64} @catalyst.passes.apply_pass("some-pass", "an-option", maxValue=1) @qml.qnode(qml.device("null.qubit", wires=1)) def example(): diff --git a/frontend/test/pytest/test_jax_integration.py b/frontend/test/pytest/test_jax_integration.py index 99b78f29b6..37e2d0762a 100644 --- a/frontend/test/pytest/test_jax_integration.py +++ b/frontend/test/pytest/test_jax_integration.py @@ -14,15 +14,19 @@ """Test QJIT compatibility with JAX transformations such as jax.jit and jax.grad.""" +import textwrap from functools import partial import jax import jax.numpy as jnp import pennylane as qml import pytest +from jax.interpreters.mlir import ir from catalyst import for_loop, measure, qjit +from catalyst.jax_extras.lowering 
import get_mlir_attribute_from_pyval from catalyst.jit import JAX_QJIT +from catalyst.utils.exceptions import CompileError class TestJAXJIT: @@ -490,5 +494,122 @@ def ansatz(i, x): jax.grad(circuit, argnums=0)(params, 3) +ctx = ir.Context() +loc = ir.Location.unknown(ctx) + + +class TestJAXMLIRAttributeGetter: + """ + Test catalyst.jax_extras.lowering.get_mlir_attribute_from_pyval + """ + + def test_bool_attr(self): + """ + Test bool attribute. + """ + with ctx, loc: + attr = get_mlir_attribute_from_pyval(True) + assert isinstance(attr, ir.BoolAttr) + assert attr.value == True + + def test_str_attr(self): + """ + Test string attribute. + """ + with ctx, loc: + attr = get_mlir_attribute_from_pyval("hello catalyst!") + assert isinstance(attr, ir.StringAttr) + assert attr.value == "hello catalyst!" + + @pytest.mark.parametrize("number", (37, -37)) + def test_int_attr(self, number): + """ + Test integer attribute. + """ + with ctx, loc: + attr = get_mlir_attribute_from_pyval(number) + assert isinstance(attr, ir.IntegerAttr) + assert attr.value == number + + @pytest.mark.parametrize("number", (3.7, -3.7)) + def test_float_attr(self, number): + """ + Test float attribute. + """ + with ctx, loc: + attr = get_mlir_attribute_from_pyval(number) + assert isinstance(attr, ir.FloatAttr) + assert attr.value == number + + @pytest.mark.parametrize("array", ([1, 2, 3], (4, 5, 6))) + def test_array_attr(self, array): + """ + Test array attribute. + """ + with ctx, loc: + attr = get_mlir_attribute_from_pyval(array) + assert isinstance(attr, ir.ArrayAttr) + assert len(attr) == len(array) + + for attr_val, py_val in zip(attr, array): + assert isinstance(attr_val, ir.IntegerAttr) + assert attr_val.value == py_val + + def test_dict_attr(self): + """ + Test dictionary attribute. 
+ """ + with ctx, loc: + attr = get_mlir_attribute_from_pyval( + {"device": "lightning.qubit", "wire_capacity": 100} + ) + assert isinstance(attr, ir.DictAttr) + + assert isinstance(attr["device"], ir.StringAttr) + assert attr["device"].value == "lightning.qubit" + + assert isinstance(attr["wire_capacity"], ir.IntegerAttr) + assert attr["wire_capacity"].value == 100 + + def test_dict_attr_with_bad_keys(self): + """ + Test dictionary attribute with non-string keys. + """ + with pytest.raises( + CompileError, match="Dictionary keys for MLIR DictionaryAttr must be strings" + ): + with ctx, loc: + _ = get_mlir_attribute_from_pyval({37: 42}) + + def test_bad_type(self): + """ + Test an error is correctly raised on a python type not convertible to mlir attribute. + """ + + # pylint: disable=missing-class-docstring + class Foo: + pass + + with pytest.raises(CompileError, match="Cannot convert Python type"): + with ctx, loc: + _ = get_mlir_attribute_from_pyval(Foo()) + + def test_int_attr_overflow(self): + """ + Test int attribute with overflow correctly raises error. 
+ """ + with pytest.raises( + CompileError, + match=textwrap.dedent( + """ + Large interger attributes currently not supported in MLIR, + see https://github.com/llvm/llvm-project/issues/128072 + """ + ), + ): + with ctx, loc: + _ = get_mlir_attribute_from_pyval(2**100) + + if __name__ == "__main__": pytest.main(["-x", __file__]) diff --git a/frontend/test/pytest/test_mlir_plugin_interface.py b/frontend/test/pytest/test_mlir_plugin_interface.py index 30f7c040b9..afde4c920f 100644 --- a/frontend/test/pytest/test_mlir_plugin_interface.py +++ b/frontend/test/pytest/test_mlir_plugin_interface.py @@ -19,6 +19,7 @@ import pennylane as qml import pytest +from jax.interpreters.mlir import ir import catalyst from catalyst import qjit @@ -73,24 +74,26 @@ def test_get_options(): """ Test get_options from Pass - ApplyRegisteredPassOp expects options to be a single StringAttr - which follows the same format as the one used with mlir-opt. - - https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop - - Options passed to a pass are specified via the syntax {option1=value1 option2=value2 ...}, - i.e., use space-separated key=value pairs for each option. - - https://mlir.llvm.org/docs/Tutorials/MlirOpt/#running-a-pass-with-options - - However, experimentally we found that single-options also work without values. + ApplyRegisteredPassOp expects options to be a dictionary from strings to attributes. 
+ See https://github.com/llvm/llvm-project/pull/143159 """ - assert catalyst.passes.Pass("example-pass", "single-option").get_options() == "single-option" - assert ( - catalyst.passes.Pass("example-pass", "an-option", "bn-option").get_options() - == "an-option bn-option" - ) - assert catalyst.passes.Pass("example-pass", option=True).get_options() == "option=True" + with ir.Context(), ir.Location.unknown(): + options = catalyst.passes.Pass("example-pass", "single-option").get_options() + assert isinstance(options, dict) + assert isinstance(options["single-option"], ir.BoolAttr) + assert options["single-option"].value == True + + options = catalyst.passes.Pass("example-pass", "an-option", "bn-option").get_options() + assert isinstance(options, dict) + assert isinstance(options["an-option"], ir.BoolAttr) + assert options["an-option"].value == True + assert isinstance(options["bn-option"], ir.BoolAttr) + assert options["bn-option"].value == True + + options = catalyst.passes.Pass("example-pass", option=True).get_options() + assert isinstance(options, dict) + assert isinstance(options["option"], ir.BoolAttr) + assert options["option"].value == True @pytest.mark.skip(reason="xdsl not installed in ci cd yet") diff --git a/mlir/Enzyme b/mlir/Enzyme index db0181320d..8c1a596158 160000 --- a/mlir/Enzyme +++ b/mlir/Enzyme @@ -1 +1 @@ -Subproject commit db0181320d6e425ee963bd496ed0d8dbb615be18 +Subproject commit 8c1a596158f6194f10e8ffd56a1660a61c54337e diff --git a/mlir/Makefile b/mlir/Makefile index 5b3dc53f53..0a41527370 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -57,6 +57,14 @@ all: llvm mhlo enzyme dialects plugin .PHONY: llvm llvm: @echo "build LLVM and MLIR enabling Python bindings" + + # Patch mlir one shot bufferization segfault + # Remove patch after bug is resolved upstream + # https://github.com/llvm/llvm-project/issues/150441 + @if cd llvm-project; git apply --check $(MK_DIR)/patches/llvm-bufferization-segfault.patch; then \ + git apply 
$(MK_DIR)/patches/llvm-bufferization-segfault.patch; \ + fi + cmake -G Ninja -S llvm-project/llvm -B $(LLVM_BUILD_DIR) \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ -DLLVM_BUILD_EXAMPLES=OFF \ @@ -97,6 +105,11 @@ mhlo: @if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; then \ git apply $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; \ fi + + # Patch a MHLO bug with std::sort + @if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-rename-sort.patch; then \ + git apply $(MK_DIR)/patches/mhlo-rename-sort.patch; \ + fi cmake -G Ninja -S mlir-hlo -B $(MHLO_BUILD_DIR) \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ -DLLVM_ENABLE_ASSERTIONS=ON \ @@ -121,8 +134,8 @@ enzyme: PATCH_FILE := $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch enzyme: @echo "build enzyme" # Patch enzyme's dependency on nvidia fabs llvm intrinsics - @if patch --dry-run -p1 -N $(TARGET_FILE) $(PATCH_FILE) > /dev/null 2>&1; then \ - patch -p1 $(TARGET_FILE) $(PATCH_FILE); \ + @if cd Enzyme; git apply --check $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch; then \ + git apply $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch; \ fi cmake -G Ninja -S Enzyme/enzyme -B $(ENZYME_BUILD_DIR) \ -DENZYME_STATIC_LIB=ON \ @@ -204,6 +217,11 @@ clean-dialects: clean-llvm: @echo "clean llvm/mlir build files" rm -rf $(LLVM_BUILD_DIR) + cd llvm-project && git clean -fd && git checkout . + +reset-llvm: + @echo "reset llvm git state to the commit tracked in .dep-versions without deleting llvm builds" + cd llvm-project && git clean -fd && git checkout . clean-mhlo: @echo "clean HLO dialect build files" @@ -213,6 +231,7 @@ clean-mhlo: clean-enzyme: @echo "clean enzyme build files" rm -rf $(ENZYME_BUILD_DIR) + cd Enzyme && git clean -fd && git checkout . 
clean-plugin: @echo "clean plugin" diff --git a/mlir/include/Catalyst/Transforms/AsyncUtils.h b/mlir/include/Catalyst/Transforms/AsyncUtils.h index 7be6815235..98df1f3972 100644 --- a/mlir/include/Catalyst/Transforms/AsyncUtils.h +++ b/mlir/include/Catalyst/Transforms/AsyncUtils.h @@ -63,13 +63,13 @@ bool hasAbortInBlock(Block *block); bool hasPutsInBlock(Block *block); // Helper function for creating function declarations -LLVM::LLVMFuncOp lookupOrCreatePersonality(ModuleOp moduleOp); -LLVM::LLVMFuncOp lookupOrCreateAbort(ModuleOp moduleOp); -LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetValueError(ModuleOp moduleOp); -LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetTokenError(ModuleOp moduleOp); -LLVM::LLVMFuncOp lookupOrCreateUnrecoverableError(ModuleOp moduleOp); -LLVM::LLVMFuncOp lookupOrCreateAwaitTokenName(ModuleOp); -LLVM::LLVMFuncOp lookupOrCreateAwaitValueName(ModuleOp); -LLVM::LLVMFuncOp lookupOrCreateDropRef(ModuleOp); +LLVM::LLVMFuncOp lookupOrCreatePersonality(OpBuilder &b, ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreateAbort(OpBuilder &b, ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetValueError(OpBuilder &b, ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetTokenError(OpBuilder &b, ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreateUnrecoverableError(OpBuilder &b, ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreateAwaitTokenName(OpBuilder &b, ModuleOp); +LLVM::LLVMFuncOp lookupOrCreateAwaitValueName(OpBuilder &b, ModuleOp); +LLVM::LLVMFuncOp lookupOrCreateDropRef(OpBuilder &b, ModuleOp); }; // namespace AsyncUtils diff --git a/mlir/include/Gradient/Utils/DestinationPassingStyle.h b/mlir/include/Gradient/Utils/DestinationPassingStyle.h index 920cfc45f3..7a6ae53415 100644 --- a/mlir/include/Gradient/Utils/DestinationPassingStyle.h +++ b/mlir/include/Gradient/Utils/DestinationPassingStyle.h @@ -17,5 +17,6 @@ namespace catalyst { /// Convert every MemRef-typed return value in `callee` to writing 
to a new argument in /// destination-passing style. -void convertToDestinationPassingStyle(mlir::func::FuncOp callee, mlir::OpBuilder &builder); +llvm::LogicalResult convertToDestinationPassingStyle(mlir::func::FuncOp callee, + mlir::OpBuilder &builder); } // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp index 76cee871d4..ad3ec61a72 100644 --- a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp +++ b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp @@ -128,81 +128,84 @@ LLVM::LLVMFuncOp AsyncUtils::getCaller(LLVM::CallOp callOp) return callOp->getParentOfType(); } -LLVM::LLVMFuncOp AsyncUtils::lookupOrCreatePersonality(ModuleOp moduleOp) +LLVM::LLVMFuncOp AsyncUtils::lookupOrCreatePersonality(OpBuilder &b, ModuleOp moduleOp) { MLIRContext *ctx = moduleOp.getContext(); auto i32Ty = IntegerType::get(ctx, 32); bool isVarArg = true; - return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::personalityName, {}, i32Ty, - isVarArg) + return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::personalityName, {}, + i32Ty, isVarArg) .value(); } -LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAbort(ModuleOp moduleOp) +LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAbort(OpBuilder &b, ModuleOp moduleOp) { MLIRContext *ctx = moduleOp.getContext(); auto voidTy = LLVM::LLVMVoidType::get(ctx); - return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::abortName, {}, voidTy) + return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::abortName, {}, voidTy) .value(); } -LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitTokenName(ModuleOp moduleOp) +LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitTokenName(OpBuilder &b, ModuleOp moduleOp) { MLIRContext *ctx = moduleOp.getContext(); Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext()); auto voidTy = LLVM::LLVMVoidType::get(ctx); return mlir::LLVM::lookupOrCreateFn( - moduleOp, 
AsyncUtilsConstants::mlirAsyncRuntimeAwaitTokenName, {ptrTy}, voidTy) + b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitTokenName, {ptrTy}, voidTy) .value(); } -LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitValueName(ModuleOp moduleOp) +LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitValueName(OpBuilder &b, ModuleOp moduleOp) { MLIRContext *ctx = moduleOp.getContext(); Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext()); auto voidTy = LLVM::LLVMVoidType::get(ctx); return mlir::LLVM::lookupOrCreateFn( - moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitValueName, {ptrTy}, voidTy) + b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitValueName, {ptrTy}, voidTy) .value(); } -LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateDropRef(ModuleOp moduleOp) +LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateDropRef(OpBuilder &b, ModuleOp moduleOp) { MLIRContext *ctx = moduleOp.getContext(); Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext()); Type llvmInt64Type = IntegerType::get(moduleOp.getContext(), 64); auto voidTy = LLVM::LLVMVoidType::get(ctx); - return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeDropRefName, + return mlir::LLVM::lookupOrCreateFn(b, moduleOp, + AsyncUtilsConstants::mlirAsyncRuntimeDropRefName, {ptrTy, llvmInt64Type}, voidTy) .value(); } -LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(ModuleOp moduleOp) +LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(OpBuilder &b, + ModuleOp moduleOp) { MLIRContext *ctx = moduleOp.getContext(); Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext()); auto voidTy = LLVM::LLVMVoidType::get(ctx); return mlir::LLVM::lookupOrCreateFn( - moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetValueErrorName, {ptrTy}, voidTy) + b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetValueErrorName, {ptrTy}, voidTy) .value(); } -LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(ModuleOp moduleOp) 
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(OpBuilder &b, + ModuleOp moduleOp) { MLIRContext *ctx = moduleOp.getContext(); Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext()); auto voidTy = LLVM::LLVMVoidType::get(ctx); return mlir::LLVM::lookupOrCreateFn( - moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetTokenErrorName, {ptrTy}, voidTy) + b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetTokenErrorName, {ptrTy}, voidTy) .value(); } -LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateUnrecoverableError(ModuleOp moduleOp) +LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateUnrecoverableError(OpBuilder &b, ModuleOp moduleOp) { MLIRContext *ctx = moduleOp.getContext(); auto voidTy = LLVM::LLVMVoidType::get(ctx); - return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::unrecoverableErrorName, {}, - voidTy) + return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::unrecoverableErrorName, + {}, voidTy) .value(); } diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index c5b93d1602..ebd559a188 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -56,11 +56,12 @@ struct PrintOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto printOp = cast(op); if (printOp.getVal()) { - FailureOr source = getBuffer(rewriter, printOp.getVal(), options); + FailureOr source = getBuffer(rewriter, printOp.getVal(), options, state); if (failed(source)) { return failure(); } @@ -116,7 +117,8 @@ struct CustomCallOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const 
bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto customCallOp = cast(op); @@ -124,7 +126,7 @@ struct CustomCallOpInterface SmallVector bufferArgs; ValueRange operands = customCallOp.getOperands(); for (Value operand : operands) { - FailureOr opBuffer = getBuffer(rewriter, operand, options); + FailureOr opBuffer = getBuffer(rewriter, operand, options, state); if (failed(opBuffer)) { return failure(); } @@ -165,11 +167,11 @@ struct CustomCallOpInterface } auto options = bufferization::BufferizationOptions(); FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue( - rewriter, op->getLoc(), result, options, false); + rewriter, op->getLoc(), result, options, state, false); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); auto newBuffer = - rewriter.create(op->getLoc(), memrefType, *tensorAlloc); + rewriter.create(op->getLoc(), memrefType, *tensorAlloc); bufferArgs.push_back(newBuffer); } @@ -207,7 +209,8 @@ struct CallbackOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto callbackOp = cast(op); @@ -279,7 +282,8 @@ struct CallbackCallOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto callOp = cast(op); @@ -292,7 +296,7 @@ struct CallbackCallOpInterface SmallVector newInputs; auto operands = callOp.getOperands(); for (Value operand : operands) { - FailureOr opBuffer = getBuffer(rewriter, operand, options); + FailureOr opBuffer = getBuffer(rewriter, operand, options, state); if (failed(opBuffer)) { return failure(); } @@ -303,8 +307,8 @@ struct CallbackCallOpInterface auto loc 
= callOp->getLoc(); SmallVector outmemrefs; for (auto result : results) { - FailureOr tensorAlloc = - bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false); + FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue( + rewriter, loc, result, options, state, false); if (failed(tensorAlloc)) { return failure(); } @@ -314,8 +318,8 @@ struct CallbackCallOpInterface auto shape = tensorTy.getShape(); auto elementTy = tensorTy.getElementType(); auto memrefType = MemRefType::get(shape, elementTy); - auto toMemrefOp = rewriter.create(loc, memrefType, tensor); - auto memref = toMemrefOp.getResult(); + auto toBufferOp = rewriter.create(loc, memrefType, tensor); + auto memref = toBufferOp.getResult(); outmemrefs.push_back(memref); newInputs.push_back(memref); } diff --git a/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp b/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp index d516605b4c..160da0d118 100644 --- a/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp +++ b/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp @@ -133,7 +133,7 @@ LogicalResult AddExceptionHandlingTransform::matchAndRewrite(LLVM::CallOp callOp auto moduleOp = callOp->getParentOfType(); // Here, we are adding a reference to the personality declaration. // From the documentation: https://llvm.org/docs/ExceptionHandling.html#exception-tables - auto personality = AsyncUtils::lookupOrCreatePersonality(moduleOp); + auto personality = AsyncUtils::lookupOrCreatePersonality(rewriter, moduleOp); // We annotate the body of the function containing the callop to have a reference // to the personality. @@ -294,7 +294,7 @@ RemoveAbortAndPutsInsertCallTransform::matchAndRewrite(LLVM::CallOp callOp, // Here, we are declaring an external function which is available in the Catalyst runtime. 
// llvm.func @__catalyst__host__rt__unrecoverable_error() auto moduleOp = callOp->getParentOfType(); - auto unrecoverableError = AsyncUtils::lookupOrCreateUnrecoverableError(moduleOp); + auto unrecoverableError = AsyncUtils::lookupOrCreateUnrecoverableError(rewriter, moduleOp); auto callee = maybeCallee.value(); rewriter.modifyOpInPlace(callee, [&] { callee.setLinkage(LLVM::Linkage::Internal); }); @@ -516,8 +516,8 @@ LogicalResult LivenessAnalysisDropRef::matchAndRewrite(LLVM::CallOp sink, // llvm.func @mlirAsyncRuntimeAwaitValue(!llvm.ptr) // llvm.func @mlirAsyncRuntimeAwaitToken(!llvm.ptr) // llvm.func @mlirAsyncRuntimeDropRef(!llvm.ptr, i64) - auto awaitFnDecl = AsyncUtils::lookupOrCreateAwaitTokenName(moduleOp); - auto dropRefFnDecl = AsyncUtils::lookupOrCreateDropRef(moduleOp); + auto awaitFnDecl = AsyncUtils::lookupOrCreateAwaitTokenName(rewriter, moduleOp); + auto dropRefFnDecl = AsyncUtils::lookupOrCreateDropRef(rewriter, moduleOp); Type llvmInt64Type = IntegerType::get(sink->getContext(), 64); auto one = rewriter.getIntegerAttr(llvmInt64Type, 1); @@ -871,9 +871,9 @@ void insertErrorCalls(std::vector tokens, std::vector values, Bloc auto moduleOp = landingPad->getParentOfType(); LLVM::LLVMFuncOp setTokenError = - AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(moduleOp); + AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(rewriter, moduleOp); LLVM::LLVMFuncOp setValueError = - AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(moduleOp); + AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(rewriter, moduleOp); for (auto token : tokens) { insertCallToMlirAsyncRuntimeErrorFunction(token, setTokenError, failBlock, rewriter); } @@ -918,11 +918,8 @@ struct AddExceptionHandlingPass : impl::AddExceptionHandlingPassBase(context); GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::ExistingOps; - config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; - // TODO: Update to the following lines the 
next time we update llvm - // config.setStrictness(GreedyRewriteStrictness::ExistingOps); - // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); + config.setStrictness(GreedyRewriteStrictness::ExistingOps); + config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns1), config))) { signalPassFailure(); diff --git a/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp b/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp index a3a6af1139..773621f132 100644 --- a/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp +++ b/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp @@ -29,11 +29,12 @@ struct GEPOpRewritePattern : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { auto defOp = op.getBase().getDefiningOp(); - if (op.getInbounds() || (defOp && isa(defOp))) { + if (op.getNoWrapFlags() == LLVM::GEPNoWrapFlags::inbounds || + (defOp && isa(defOp))) { return failure(); } rewriter.startOpModification(op); - op.setInbounds(true); + op.setNoWrapFlags(LLVM::GEPNoWrapFlags::inbounds); rewriter.finalizeOpModification(op); return success(); } diff --git a/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp b/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp index f540222b77..f115efb171 100644 --- a/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp +++ b/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp @@ -380,11 +380,8 @@ struct AnnotateWithFullyQualifiedNamePass { MLIRContext *context = &getContext(); GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::ExistingOps; - config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; - // TODO: Update to the following lines the next time we update llvm - // config.setStrictness(GreedyRewriteStrictness::ExistingOps); - // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); + 
config.setStrictness(GreedyRewriteStrictness::ExistingOps); + config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); RewritePatternSet annotate(context); auto root = getOperation(); @@ -409,11 +406,8 @@ struct InlineNestedSymbolTablePass : PassWrapper(op); }) != backwardSlice.end(); @@ -132,8 +133,8 @@ struct MemrefLoadTBAARewritePattern : public ConvertOpToLLVMPattern( loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0, false, loadOp.getNontemporal()); @@ -170,8 +171,8 @@ struct MemrefStoreTBAARewritePattern : public ConvertOpToLLVMPattern(storeOp, adaptor.getValue(), dataPtr, 0, false, storeOp.getNontemporal()); diff --git a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp index d427d49386..5d884162fc 100644 --- a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp +++ b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp @@ -51,7 +51,8 @@ Value getGlobalString(Location loc, OpBuilder &rewriter, StringRef key, StringRe } return rewriter.create(loc, LLVM::LLVMPointerType::get(rewriter.getContext()), type, rewriter.create(loc, glb), - ArrayRef{0, 0}, true); + ArrayRef{0, 0}, + LLVM::GEPNoWrapFlags::inbounds); } enum NumericType : int8_t { @@ -309,7 +310,8 @@ Value EncodeDataMemRef(Location loc, PatternRewriter &rewriter, MemRefType memre MemRefDescriptor desc = MemRefDescriptor(memrefLlvm); Value c0 = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value data = rewriter.create(loc, ptr, memrefType.getElementType(), - desc.alignedPtr(rewriter, loc), c0, true); + desc.alignedPtr(rewriter, loc), c0, + LLVM::GEPNoWrapFlags::inbounds); memref = rewriter.create(loc, memref, data, 1); // Dtype @@ -335,7 +337,8 @@ struct CustomCallOpPattern : public OpConversionPattern { rewriter.setInsertionPointToStart(mod.getBody()); LLVM::LLVMFuncOp customCallFnOp = - mlir::LLVM::lookupOrCreateFn(mod, op.getCallTargetName(), {/*args=*/ptr, /*rets=*/ptr}, + 
mlir::LLVM::lookupOrCreateFn(rewriter, mod, op.getCallTargetName(), + {/*args=*/ptr, /*rets=*/ptr}, /*ret_type=*/voidType) .value(); customCallFnOp.setPrivate(); @@ -467,7 +470,7 @@ struct DefineCallbackOpPattern : public OpConversionPattern { ModuleOp mod = op->getParentOfType(); auto typeConverter = getTypeConverter(); LLVM::LLVMFuncOp customCallFnOp = - mlir::LLVM::lookupOrCreateFn(mod, "__catalyst_inactive_callback", + mlir::LLVM::lookupOrCreateFn(rewriter, mod, "__catalyst_inactive_callback", {/*args=*/i64, i64, i64}, /*ret_type=*/voidType, isVarArg) .value(); diff --git a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp index d1980a658e..54cbc85aac 100644 --- a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp @@ -177,7 +177,8 @@ struct AdjointOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto adjointOp = cast(op); Location loc = adjointOp.getLoc(); @@ -207,7 +208,7 @@ struct AdjointOpInterface ValueRange operands = adjointOp.getArgs(); for (Value operand : operands) { if (isa(operand.getType())) { - FailureOr opBuffer = getBuffer(rewriter, operand, options); + FailureOr opBuffer = getBuffer(rewriter, operand, options, state); if (failed(opBuffer)) { return failure(); } @@ -276,7 +277,8 @@ struct BackpropOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto backpropOp = cast(op); Location loc = backpropOp.getLoc(); @@ -293,7 +295,7 @@ struct BackpropOpInterface ValueRange operands = backpropOp.getArgs(); for (Value operand : 
operands) { if (isa(operand.getType())) { - FailureOr opBuffer = getBuffer(rewriter, operand, options); + FailureOr opBuffer = getBuffer(rewriter, operand, options, state); if (failed(opBuffer)) { return failure(); } @@ -333,7 +335,7 @@ struct BackpropOpInterface ValueRange cotangents = backpropOp.getCotangents(); SmallVector bufferCotangents; for (Value operand : cotangents) { - FailureOr opBuffer = getBuffer(rewriter, operand, options); + FailureOr opBuffer = getBuffer(rewriter, operand, options, state); if (failed(opBuffer)) { return failure(); } @@ -402,6 +404,7 @@ struct ForwardOpInterface FailureOr getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options, + const bufferization::BufferizationState &state, SmallVector &invocationStack) const { // The getBufferType() method is called on either BlockArguments or OpResults. @@ -426,11 +429,12 @@ struct ForwardOpInterface return getBufferizedFunctionArgType(forwardOp, bbArg.getArgNumber(), options); } - return bufferization::detail::defaultGetBufferType(value, options, invocationStack); + return bufferization::detail::defaultGetBufferType(value, options, state, invocationStack); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto forwardOp = cast(op); FunctionType funcType = forwardOp.getFunctionType(); @@ -451,7 +455,7 @@ struct ForwardOpInterface // 1. Bufferize every block. for (Block &block : forwardOp.getBody()) { - if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options))) { + if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state))) { return failure(); } } @@ -471,11 +475,12 @@ struct ForwardOpInterface // Note: If `inferFunctionResultLayout = true`, cast are later folded // away. 
- BaseMemRefType resultType = options.unknownTypeConverterFn( - returnVal, *options.defaultMemorySpaceFn(tensorType), options); - Value toMemrefOp = - rewriter.create(loc, resultType, returnVal); - returnValues.push_back(toMemrefOp); + BaseMemRefType resultType = + options.unknownTypeConverterFn(cast(returnVal.getType()), + *options.defaultMemorySpaceFn(tensorType), options); + Value toBufferOp = + rewriter.create(loc, resultType, returnVal); + returnValues.push_back(toBufferOp); } // 3. Rewrite the terminator. @@ -523,6 +528,7 @@ struct ReverseOpInterface FailureOr getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options, + const bufferization::BufferizationState &state, SmallVector &invocationStack) const { // See comment on the getBufferType() method on forward op. @@ -534,11 +540,12 @@ struct ReverseOpInterface return getBufferizedFunctionArgType(reverseOp, bbArg.getArgNumber(), options); } - return bufferization::detail::defaultGetBufferType(value, options, invocationStack); + return bufferization::detail::defaultGetBufferType(value, options, state, invocationStack); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto reverseOp = cast(op); FunctionType funcType = reverseOp.getFunctionType(); @@ -559,7 +566,7 @@ struct ReverseOpInterface // 1. Bufferize every block. for (Block &block : reverseOp.getBody()) - if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options))) + if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state))) return failure(); // 2. For each result, keep track of which inplace argument it reuses. @@ -577,11 +584,12 @@ struct ReverseOpInterface // Note: If `inferFunctionResultLayout = true`, cast are later folded // away. 
- BaseMemRefType resultType = options.unknownTypeConverterFn( - returnVal, *options.defaultMemorySpaceFn(tensorType), options); - Value toMemrefOp = - rewriter.create(loc, resultType, returnVal); - returnValues.push_back(toMemrefOp); + BaseMemRefType resultType = + options.unknownTypeConverterFn(cast(returnVal.getType()), + *options.defaultMemorySpaceFn(tensorType), options); + Value toBufferOp = + rewriter.create(loc, resultType, returnVal); + returnValues.push_back(toBufferOp); } // 3. Rewrite the terminator. diff --git a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp index e3a472945e..4e2c4f188e 100644 --- a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp @@ -49,8 +49,8 @@ using namespace catalyst::gradient; namespace catalyst { namespace gradient { -void wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter, - RewriterBase &rewriter, Location loc, bool volatileArgs = false) +LogicalResult wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter, + RewriterBase &rewriter, Location loc, bool volatileArgs = false) { MLIRContext *ctx = rewriter.getContext(); auto ptrType = LLVM::LLVMPointerType::get(ctx); @@ -59,7 +59,9 @@ void wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter, for (const auto [idx, argType] : llvm::enumerate(func.getArgumentTypes())) { if (auto memrefType = dyn_cast(argType)) { BlockArgument memrefArg = func.getArgument(idx); - func.insertArgument(idx, ptrType, DictionaryAttr::get(ctx), loc); + if (failed(func.insertArgument(idx, ptrType, DictionaryAttr::get(ctx), loc))) { + return failure(); + } Value wrappedMemref = func.getArgument(idx); Type structType = typeConverter->convertType(memrefType); @@ -78,9 +80,12 @@ void wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter, rewriter.create(loc, argType, replacedMemref) .getResult(0); 
memrefArg.replaceAllUsesWith(replacedMemref); - func.eraseArgument(memrefArg.getArgNumber()); + if (failed(func.eraseArgument(memrefArg.getArgNumber()))) { + return failure(); + } } } + return success(); } void wrapMemRefArgsCallsites(func::FuncOp func, const TypeConverter *typeConverter, @@ -171,16 +176,19 @@ LLVM::GlobalOp insertEnzymeCustomGradient(OpBuilder &builder, ModuleOp moduleOp, /// functions where MemRefs are passed via wrapped pointers (!llvm.ptr) /// rather than having their fields unpacked. This function automatically transforms MemRef /// arguments of a function to wrapped pointers. -void wrapMemRefArgs(func::FuncOp func, const TypeConverter *typeConverter, RewriterBase &rewriter, - Location loc, bool volatileArgs = false) +LogicalResult wrapMemRefArgs(func::FuncOp func, const TypeConverter *typeConverter, + RewriterBase &rewriter, Location loc, bool volatileArgs = false) { if (llvm::none_of(func.getArgumentTypes(), [](Type argType) { return isa(argType); })) { // The memref arguments are already wrapped - return; + return success(); + } + if (failed(wrapMemRefArgsFunc(func, typeConverter, rewriter, loc, volatileArgs))) { + return failure(); } - wrapMemRefArgsFunc(func, typeConverter, rewriter, loc, volatileArgs); wrapMemRefArgsCallsites(func, typeConverter, rewriter, loc, volatileArgs); + return success(); } } // namespace gradient } // namespace catalyst @@ -290,7 +298,9 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr()); assert(callee && "Expected a valid callee of type func.func"); - catalyst::convertToDestinationPassingStyle(callee, rewriter); + if (failed(catalyst::convertToDestinationPassingStyle(callee, rewriter))) { + return failure(); + } SymbolTableCollection symbolTable; catalyst::traverseCallGraph(callee, &symbolTable, [&](func::FuncOp func) { // Register custom gradients of quantum functions @@ -304,9 +314,11 @@ struct BackpropOpPattern : public 
ConvertOpToLLVMPattern { if (!func->hasAttr("unwrapped_type")) { func->setAttr("unwrapped_type", TypeAttr::get(func.getFunctionType())); } - catalyst::convertToDestinationPassingStyle(func, rewriter); - - wrapMemRefArgs(func, getTypeConverter(), rewriter, loc, /*volatileArgs=*/true); + LogicalResult dpsr = catalyst::convertToDestinationPassingStyle(func, rewriter); + assert(dpsr.succeeded() && "failed to rewrite backpropOp to destination style"); + LogicalResult wmar = wrapMemRefArgs(func, getTypeConverter(), rewriter, loc, + /*volatileArgs=*/true); + assert(wmar.succeeded() && "failed to wrap backpropOp's memref args"); func::FuncOp augFwd = genAugmentedForward(func, rewriter); func::FuncOp customQGrad = @@ -318,10 +330,10 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern { LowerToLLVMOptions options = getTypeConverter()->getOptions(); if (options.useGenericFunctions) { - LLVM::LLVMFuncOp allocFn = - LLVM::lookupOrCreateGenericAllocFn(moduleOp, getTypeConverter()->getIndexType()) - .value(); - LLVM::LLVMFuncOp freeFn = LLVM::lookupOrCreateGenericFreeFn(moduleOp).value(); + LLVM::LLVMFuncOp allocFn = LLVM::lookupOrCreateGenericAllocFn( + rewriter, moduleOp, getTypeConverter()->getIndexType()) + .value(); + LLVM::LLVMFuncOp freeFn = LLVM::lookupOrCreateGenericFreeFn(rewriter, moduleOp).value(); // Register the previous functions as llvm globals (for Enzyme) // With the following piece of metadata, shadow memory is allocated with @@ -862,7 +874,10 @@ struct ForwardOpPattern : public ConvertOpToLLVMPattern { func->setAttr("passthrough", ArrayAttr::get(ctx, passthrough)); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); - catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, op.getLoc()); + if (failed(catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, + op.getLoc()))) { + return failure(); + } rewriter.eraseOp(op); return success(); } @@ -884,16 +899,12 @@ struct ReverseOpPattern : public 
ConvertOpToLLVMPattern { auto params = op.getArguments(); for (size_t i = 0; i < argc * 2; i++) { - bool isDup = (i % 2) != 0; - Value val = params[i]; - isDup ? differentials.push_back(val) : inputs.push_back(val); + fillValueAndShadowWithDedup(i, params, differentials, inputs); } auto upperLimit = (argc * 2) + (resc * 2); for (size_t i = argc * 2; i < upperLimit; i++) { - bool isDup = (i % 2) != 0; - Value val = params[i]; - isDup ? cotangents.push_back(val) : outputs.push_back(val); + fillValueAndShadowWithDedup(i, params, cotangents, outputs); } auto tapeCount = op.getTape(); @@ -903,16 +914,7 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern { } SmallVector newFuncInputTys; - - for (auto [in, diff] : llvm::zip(inputs, differentials)) { - newFuncInputTys.push_back(in.getType()); - newFuncInputTys.push_back(diff.getType()); - } - - for (auto [out, cotan] : llvm::zip(outputs, cotangents)) { - newFuncInputTys.push_back(out.getType()); - newFuncInputTys.push_back(cotan.getType()); - } + getNewFuncInputTys(inputs, outputs, differentials, cotangents, newFuncInputTys); SmallVector tapeStructs; auto converter = getTypeConverter(); @@ -986,11 +988,40 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern { Block &firstBlock = func.getRegion().getBlocks().front(); Block &lastBlock = func.getRegion().getBlocks().back(); rewriter.mergeBlocks(&lastBlock, &firstBlock); - catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, op.getLoc()); + if (failed(catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, + op.getLoc()))) { + return failure(); + } rewriter.eraseOp(op); return success(); } + + private: + static void getNewFuncInputTys(const SmallVector &inputs, + const SmallVector &outputs, + const SmallVector &differentials, + const SmallVector &cotangents, + SmallVector &newFuncInputTys) + { + for (auto [in, diff] : llvm::zip(inputs, differentials)) { + newFuncInputTys.push_back(in.getType()); + 
newFuncInputTys.push_back(diff.getType()); + } + + for (auto [out, cotan] : llvm::zip(outputs, cotangents)) { + newFuncInputTys.push_back(out.getType()); + newFuncInputTys.push_back(cotan.getType()); + } + } + + static void fillValueAndShadowWithDedup(size_t i, ValueRange params, SmallVector &values, + SmallVector &shadows) + { + bool isDup = (i % 2) != 0; + Value val = params[i]; + isDup ? shadows.push_back(val) : values.push_back(val); + } }; struct ReturnOpPattern : public ConvertOpToLLVMPattern { diff --git a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp index 464ab29089..5c24252a1e 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp @@ -148,7 +148,8 @@ func::FuncOp genSplitPreprocessed(PatternRewriter &rewriter, Location loc, func: PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(&splitFn.getBody().front()); Value paramsBuffer = rewriter.create(loc, paramsBufferType, paramCount); - Value paramsTensor = rewriter.create(loc, paramsBuffer, true); + Value paramsTensor = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(paramsBuffer.getType()), paramsBuffer, true); qnodeQuantumArgs.push_back(paramsTensor); MemRefType paramsProcessedType = MemRefType::get({}, rewriter.getIndexType()); diff --git a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp index 66fbcee25d..853ecd1e82 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include 
"Gradient/Utils/DifferentialQNode.h" @@ -163,13 +164,14 @@ void FiniteDiffLowering::computeFiniteDiff(PatternRewriter &rewriter, Location l auto tensorTy = diffArg.getType(); auto memrefTy = bufferization::getMemRefTypeWithStaticIdentityLayout( cast(tensorTy)); - auto toMemrefOp = - rewriter.create(loc, memrefTy, diffArg); + auto toBufferOp = + rewriter.create(loc, memrefTy, diffArg); - auto cloneOp = rewriter.create(loc, toMemrefOp); + auto cloneOp = rewriter.create(loc, toBufferOp); - auto toTensorOp = - rewriter.create(loc, cloneOp, true); + auto toTensorOp = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(cloneOp.getOutput().getType()), + cloneOp, true); auto diffArgCopy = toTensorOp.getResult(); diff --git a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp index 5db5c4a149..75ad0a61a0 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp @@ -86,9 +86,9 @@ void initializeCotangents(TypeRange primalResultTypes, unsigned activeResult, Va : activeResultType); Value zero = builder.create( - loc, APFloat(elementType.getFloatSemantics(), 0), elementType); - Value one = builder.create( - loc, APFloat(elementType.getFloatSemantics(), 1), elementType); + loc, elementType, APFloat(elementType.getFloatSemantics(), 0)); + Value one = builder.create(loc, elementType, + APFloat(elementType.getFloatSemantics(), 1)); Value zeroTensor; if (auto activeResultTensor = dyn_cast(activeResultType)) { @@ -397,7 +397,7 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc, } else { jacobians.push_back(rewriter.create( - loc, APFloat(0.0), cast(jacobianType))); + loc, cast(jacobianType), APFloat(0.0))); } } diff --git a/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp b/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp index ce91f8ec80..7227a3d35a 100644 --- 
a/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp @@ -35,7 +35,7 @@ static Value genSelectiveShift(PatternRewriter &rewriter, Location loc, Value pa } // Make sure all active iteration variables match the selectors. - Value shiftCondition = rewriter.create(loc, true, 1); + Value shiftCondition = rewriter.create(loc, 1, true); for (auto &[iteration, selector] : selectors) { Value iterationMatch = rewriter.create(loc, arith::CmpIPredicate::eq, iteration, selector); diff --git a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp index ff6d172908..d5cb0c117e 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp @@ -59,7 +59,8 @@ static std::vector computePartialDerivative(PatternRewriter &rewriter, Lo { constexpr double shift = llvm::numbers::pi / 2; ShapedType shiftVectorType = RankedTensorType::get({numShifts}, rewriter.getF64Type()); - Value selectorVector = rewriter.create(loc, selectorBuffer, true); + Value selectorVector = rewriter.create( + loc, memref::getTensorTypeFromMemRefType(selectorBuffer.getType()), selectorBuffer, true); // Define the shift vectors (pos/neg) as sparse tensor constants. 
DenseElementsAttr nonZeroIndices = rewriter.getI64TensorAttr(currentShift); @@ -285,8 +286,9 @@ func::FuncOp ParameterShiftLowering::genQGradFunction(PatternRewriter &rewriter, std::vector gradientTensors; gradientTensors.reserve(gradResTypes.size()); for (Value gradientBuffer : gradientBuffers) { - gradientTensors.push_back( - rewriter.create(loc, gradientBuffer, true)); + gradientTensors.push_back(rewriter.create( + loc, memref::getTensorTypeFromMemRefType(gradientBuffer.getType()), + gradientBuffer, true)); } op->setOperands(gradientTensors); } diff --git a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp index 70f623fc29..d34dd89974 100644 --- a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp @@ -108,7 +108,9 @@ struct PostprocessForwardOp : public OpRewritePattern { // Insert new argIn in an interleaving way. size_t idx = 0; for (auto ty : newArgInTypes) { - op.insertArgument(2 * idx + 1, ty, {}, op.getLoc()); + if (failed(op.insertArgument(2 * idx + 1, ty, {}, op.getLoc()))) { + return failure(); + } idx++; } // Append newArgRes. @@ -117,8 +119,11 @@ struct PostprocessForwardOp : public OpRewritePattern { /*values=*/op.getNumArguments()); SmallVector argAttrs{appendingSize}; SmallVector argLocs{appendingSize, op.getLoc()}; - op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs); + if (failed(op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs))) { + return failure(); + } op.setFunctionType(forwardTy); + return success(); }); op.walk([&](ReturnOp returnOp) { @@ -195,7 +200,9 @@ struct PostprocessReverseOp : public OpRewritePattern { // Insert new argIn in an interleaving way. size_t idx = 0; for (auto ty : newArgInTypes) { - op.insertArgument(2 * idx, ty, {}, op.getLoc()); + if (failed(op.insertArgument(2 * idx, ty, {}, op.getLoc()))) { + return failure(); + } idx++; } // Append newArgRes. 
@@ -204,8 +211,11 @@ struct PostprocessReverseOp : public OpRewritePattern { /*values=*/0); SmallVector argAttrs{appendingSize}; SmallVector argLocs{appendingSize, op.getLoc()}; - op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs); + if (failed(op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs))) { + return failure(); + } op.setFunctionType(reverseTy); + return success(); }); op.walk([&](ReturnOp returnOp) { diff --git a/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp b/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp index ef6b572583..08c082a316 100644 --- a/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp +++ b/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp @@ -19,11 +19,11 @@ using namespace mlir; -void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &builder) +LogicalResult catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &builder) { if (callee.getNumResults() == 0) { // Callee is already in destination-passing style - return; + return success(); } MLIRContext *ctx = callee.getContext(); @@ -48,7 +48,7 @@ void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder & if (callee.isDeclaration()) { // If the function does not have a body, we are done after modifying the function type. callee.setFunctionType(dpsFunctionType); - return; + return success(); } // Insert the new output arguments to the function. @@ -60,7 +60,9 @@ void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder & // insertArguments modifies the function type, so we need to update the function type *after* // inserting the arguments. - callee.insertArguments(argIndices, memRefReturnTypes, argAttrs, argLocs); + if (failed(callee.insertArguments(argIndices, memRefReturnTypes, argAttrs, argLocs))) { + return failure(); + } callee.setFunctionType(dpsFunctionType); // Update return sites to copy over the memref that would have been returned to the output. 
@@ -83,4 +85,6 @@ void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder & } returnOp.getOperandsMutable().assign(nonMemRefReturns); }); + + return success(); } diff --git a/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp b/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp index 5768982de5..4d1164c140 100644 --- a/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp +++ b/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp @@ -44,7 +44,7 @@ Value buildTensorLinalgGeneric(OpBuilder &builder, Location loc, ValueRange oper // Initialize the result tensor FloatType elementType = cast(resultType.getElementType()); Value zero = builder.create( - loc, APFloat::getZero(elementType.getFloatSemantics()), elementType); + loc, elementType, APFloat::getZero(elementType.getFloatSemantics())); Value result = builder.create(loc, resultType.getShape(), resultType.getElementType()); result = builder.create(loc, zero, result).getResult(0); diff --git a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp index 4a0312478e..9498f306e8 100644 --- a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp @@ -58,15 +58,16 @@ struct QubitUnitaryOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto qubitUnitaryOp = cast(op); Location loc = op->getLoc(); auto tensorType = cast(qubitUnitaryOp.getMatrix().getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto toMemrefOp = - rewriter.create(loc, memrefType, qubitUnitaryOp.getMatrix()); - auto memref = toMemrefOp.getResult(); + auto toBufferOp = + rewriter.create(loc, memrefType, qubitUnitaryOp.getMatrix()); + auto memref = toBufferOp.getResult(); 
bufferization::replaceOpWithNewBufferizedOp( rewriter, op, qubitUnitaryOp.getOutQubits().getTypes(), qubitUnitaryOp.getOutCtrlQubits().getTypes(), memref, qubitUnitaryOp.getInQubits(), @@ -101,15 +102,16 @@ struct HermitianOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto hermitianOp = cast(op); Location loc = op->getLoc(); auto tensorType = cast(hermitianOp.getMatrix().getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto toMemrefOp = - rewriter.create(loc, memrefType, hermitianOp.getMatrix()); - auto memref = toMemrefOp.getResult(); + auto toBufferOp = + rewriter.create(loc, memrefType, hermitianOp.getMatrix()); + auto memref = toBufferOp.getResult(); auto newHermitianOp = rewriter.create(loc, hermitianOp.getType(), memref, hermitianOp.getQubits()); bufferization::replaceOpWithBufferizedValues(rewriter, op, newHermitianOp.getObs()); @@ -143,15 +145,16 @@ struct HamiltonianOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto hamiltonianOp = cast(op); Location loc = op->getLoc(); auto tensorType = cast(hamiltonianOp.getCoeffs().getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto toMemrefOp = - rewriter.create(loc, memrefType, hamiltonianOp.getCoeffs()); - auto memref = toMemrefOp.getResult(); + auto toBufferOp = + rewriter.create(loc, memrefType, hamiltonianOp.getCoeffs()); + auto memref = toBufferOp.getResult(); auto newHamiltonianOp = rewriter.create(loc, hamiltonianOp.getType(), memref, hamiltonianOp.getTerms()); 
bufferization::replaceOpWithBufferizedValues(rewriter, op, newHamiltonianOp.getObs()); @@ -187,7 +190,8 @@ struct SampleOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto sampleOp = cast(op); Location loc = op->getLoc(); @@ -237,7 +241,8 @@ struct CountsOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto countsOp = cast(op); Location loc = op->getLoc(); @@ -297,7 +302,8 @@ struct ProbsOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto probsOp = cast(op); Location loc = op->getLoc(); @@ -350,7 +356,8 @@ struct StateOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto stateOp = cast(op); Location loc = op->getLoc(); @@ -401,16 +408,17 @@ struct SetStateOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto setStateOp = cast(op); Location loc = op->getLoc(); auto tensorType = cast(setStateOp.getInState().getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto toMemrefOp = - rewriter.create(loc, memrefType, setStateOp.getInState()); - auto memref = 
toMemrefOp.getResult(); + auto toBufferOp = + rewriter.create(loc, memrefType, setStateOp.getInState()); + auto memref = toBufferOp.getResult(); auto newSetStateOp = rewriter.create(loc, setStateOp.getOutQubits().getTypes(), memref, setStateOp.getInQubits()); bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits()); @@ -443,16 +451,17 @@ struct SetBasisStateOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const bufferization::BufferizationOptions &options) const + const bufferization::BufferizationOptions &options, + bufferization::BufferizationState &state) const { auto setBasisStateOp = cast(op); Location loc = op->getLoc(); auto tensorType = cast(setBasisStateOp.getBasisState().getType()); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto toMemrefOp = rewriter.create( + auto toBufferOp = rewriter.create( loc, memrefType, setBasisStateOp.getBasisState()); - auto memref = toMemrefOp.getResult(); + auto memref = toBufferOp.getResult(); auto newSetStateOp = rewriter.create( loc, setBasisStateOp.getOutQubits().getTypes(), memref, setBasisStateOp.getInQubits()); bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits()); diff --git a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp index eaf30e2829..8dbf401c46 100644 --- a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp @@ -44,7 +44,8 @@ Value getGlobalString(Location loc, OpBuilder &rewriter, StringRef key, StringRe } return rewriter.create(loc, LLVM::LLVMPointerType::get(rewriter.getContext()), type, rewriter.create(loc, glb), - ArrayRef{0, 0}, true); + ArrayRef{0, 0}, + LLVM::GEPNoWrapFlags::inbounds); } /** @@ -80,13 +81,17 @@ Value getModifiersPtr(Location loc, RewriterBase &rewriter, const TypeConverter auto structType = LLVM::LLVMStructType::getLiteral(ctx, 
{boolType, sizeType, ptrType, ptrType}); auto modifiersPtr = catalyst::getStaticAlloca(loc, rewriter, structType, 1).getResult(); auto adjointPtr = rewriter.create(loc, ptrType, structType, modifiersPtr, - llvm::ArrayRef{0, 0}, true); + llvm::ArrayRef{0, 0}, + LLVM::GEPNoWrapFlags::inbounds); auto numControlledPtr = rewriter.create(loc, ptrType, structType, modifiersPtr, - llvm::ArrayRef{0, 1}, true); - auto controlledWiresPtr = rewriter.create( - loc, ptrType, structType, modifiersPtr, llvm::ArrayRef{0, 2}, true); - auto controlledValuesPtr = rewriter.create( - loc, ptrType, structType, modifiersPtr, llvm::ArrayRef{0, 3}, true); + llvm::ArrayRef{0, 1}, + LLVM::GEPNoWrapFlags::inbounds); + auto controlledWiresPtr = rewriter.create(loc, ptrType, structType, modifiersPtr, + llvm::ArrayRef{0, 2}, + LLVM::GEPNoWrapFlags::inbounds); + auto controlledValuesPtr = rewriter.create(loc, ptrType, structType, modifiersPtr, + llvm::ArrayRef{0, 3}, + LLVM::GEPNoWrapFlags::inbounds); Value ctrlPtr = nullPtr; Value valuePtr = nullPtr; @@ -98,13 +103,15 @@ Value getModifiersPtr(Location loc, RewriterBase &rewriter, const TypeConverter for (int i = 0; static_cast(i) < controlledQubits.size(); i++) { { auto itemPtr = rewriter.create(loc, ptrType, ptrType, ctrlPtr, - llvm::ArrayRef{i}, true); + llvm::ArrayRef{i}, + LLVM::GEPNoWrapFlags::inbounds); auto qubit = controlledQubits[i]; rewriter.create(loc, qubit, itemPtr); } { auto itemPtr = rewriter.create(loc, ptrType, boolType, valuePtr, - llvm::ArrayRef{i}, true); + llvm::ArrayRef{i}, + LLVM::GEPNoWrapFlags::inbounds); auto value = controlledValues[i]; rewriter.create(loc, value, itemPtr); } @@ -1012,7 +1019,7 @@ struct SetStateOpPattern : public OpConversionPattern { auto voidTy = LLVM::LLVMVoidType::get(ctx); auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); ModuleOp moduleOp = op->getParentOfType(); - auto func = mlir::LLVM::lookupOrCreateFn(moduleOp, "__catalyst__qis__SetState", + auto func = 
mlir::LLVM::lookupOrCreateFn(rewriter, moduleOp, "__catalyst__qis__SetState", {ptrTy, i64}, voidTy, isVarArg) .value(); @@ -1052,9 +1059,10 @@ struct SetBasisStateOpPattern : public OpConversionPattern { auto voidTy = LLVM::LLVMVoidType::get(ctx); auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); ModuleOp moduleOp = op->getParentOfType(); - auto func = mlir::LLVM::lookupOrCreateFn(moduleOp, "__catalyst__qis__SetBasisState", - {ptrTy, i64}, voidTy, isVarArg) - .value(); + auto func = + mlir::LLVM::lookupOrCreateFn(rewriter, moduleOp, "__catalyst__qis__SetBasisState", + {ptrTy, i64}, voidTy, isVarArg) + .value(); SmallVector args; diff --git a/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp b/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp index 4c91fbf67c..f4b8a2f9ca 100644 --- a/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp +++ b/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp @@ -214,13 +214,9 @@ struct EmitCatalystPyInterfacePass patterns.add(context); GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::ExistingOps; - config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; - config.maxIterations = 1; - // TODO: Update to the following lines the next time we update llvm - // config.setStrictness(GreedyRewriteStrictness::ExistingOps); - // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); - // config.setMaxIterations(1); + config.setStrictness(GreedyRewriteStrictness::ExistingOps); + config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); + config.setMaxIterations(1); auto op = getOperation(); SmallVector targets; diff --git a/mlir/llvm-project b/mlir/llvm-project index 179d30f8c3..f8cb7987c6 160000 --- a/mlir/llvm-project +++ b/mlir/llvm-project @@ -1 +1 @@ -Subproject commit 179d30f8c3fddd3c85056fd2b8e877a4a8513158 +Subproject commit f8cb7987c64dcffb72414a40560055cb717dbf74 diff --git a/mlir/mlir-hlo b/mlir/mlir-hlo index 
617a9361d1..1dd2e71331 160000 --- a/mlir/mlir-hlo +++ b/mlir/mlir-hlo @@ -1 +1 @@ -Subproject commit 617a9361d186199480c080c9e8c474a5e30c22d1 +Subproject commit 1dd2e71331014ae0373f6bf900ce6be393357190 diff --git a/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch b/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch index 9f80c60a75..8746e54c73 100644 --- a/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch +++ b/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch @@ -1,8 +1,8 @@ diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp -index 85050315..414318eb 100644 +index 7c234dd4..846f68b4 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp -@@ -3940,14 +3940,6 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { +@@ -3942,14 +3942,6 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { case Intrinsic::nearbyint: case Intrinsic::round: case Intrinsic::sqrt: diff --git a/mlir/patches/llvm-bufferization-segfault.patch b/mlir/patches/llvm-bufferization-segfault.patch new file mode 100644 index 0000000000..e894516820 --- /dev/null +++ b/mlir/patches/llvm-bufferization-segfault.patch @@ -0,0 +1,27 @@ +diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +index 453ed43bcad..dff994729a4 100644 +--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp ++++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +@@ -89,16 +89,12 @@ static FuncOp getCalledFunction(CallOpInterface callOp, + /// Return the FuncOp called by `callOp`. + static FuncOp getCalledFunction(CallOpInterface callOp, + const AnalysisState &state) { +- auto &oneShotAnalysisState = static_cast(state); +- +- if (auto *funcAnalysisState = +- oneShotAnalysisState.getExtension()) { +- // Use the cached symbol tables. 
+- return getCalledFunction(callOp, funcAnalysisState->symbolTables); +- } +- +- SymbolTableCollection symbolTables; +- return getCalledFunction(callOp, symbolTables); ++ SymbolRefAttr sym = ++ llvm::dyn_cast_if_present(callOp.getCallableForCallee()); ++ if (!sym) ++ return nullptr; ++ return dyn_cast_or_null( ++ SymbolTable::lookupNearestSymbolFrom(callOp, sym)); + } + + /// Get FuncAnalysisState. diff --git a/mlir/patches/mhlo-add-back-necessary-passes.patch b/mlir/patches/mhlo-add-back-necessary-passes.patch index b56ede8dd5..c430adae43 100644 --- a/mlir/patches/mhlo-add-back-necessary-passes.patch +++ b/mlir/patches/mhlo-add-back-necessary-passes.patch @@ -7,12 +7,12 @@ Subject: [PATCH] restore the removed mhlo passes we need: --- mhlo/transforms/CMakeLists.txt | 6 + .../legalize_control_flow.cc | 288 +++++++++ - .../transforms/legalize_sort/legalize_sort.cc | 577 ++++++++++++++++++ + .../transforms/legalize_sort/legalize_sort.cc | 578 ++++++++++++++++++ .../legalize_to_standard.cc | 243 ++++++++ .../legalize_to_standard_patterns.td | 92 +++ mhlo/transforms/mhlo_passes.td | 19 + mhlo/transforms/passes.h | 4 + - 7 files changed, 1229 insertions(+) + 7 files changed, 1230 insertions(+) create mode 100644 mhlo/transforms/legalize_control_flow/legalize_control_flow.cc create mode 100644 mhlo/transforms/legalize_sort/legalize_sort.cc create mode 100644 mhlo/transforms/legalize_to_standard/legalize_to_standard.cc @@ -342,7 +342,7 @@ new file mode 100644 index 00000000..8ba9de9a --- /dev/null +++ b/mhlo/transforms/legalize_sort/legalize_sort.cc -@@ -0,0 +1,577 @@ +@@ -0,0 +1,578 @@ +/* Copyright 2019 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); @@ -883,8 +883,9 @@ index 00000000..8ba9de9a + + SmallVector outputTensors; + for (auto [out0, out1] : llvm::zip(outputMemrefs, scratchMemrefs)) { ++ Value s = b.create(parity, out1, out0).getResult(); + outputTensors.push_back(b.create( -+ b.create(parity, out1, out0), /*restrict=*/true)); ++ memref::getTensorTypeFromMemRefType(s.getType()), s, /*restrict=*/true)); + } + + rewriter.replaceOp(op, outputTensors); diff --git a/mlir/patches/mhlo-remove-shardy.patch b/mlir/patches/mhlo-remove-shardy.patch index f78200bdab..32ce71061f 100644 --- a/mlir/patches/mhlo-remove-shardy.patch +++ b/mlir/patches/mhlo-remove-shardy.patch @@ -84,7 +84,7 @@ index cabd6a9f..2e64b4ed 100644 patterns->add(context); patterns->add(context); patterns->add(context); -- populateSdyShapeRefinementPatterns(patterns, context); +- populateSdyShapeRefinementPatterns(context, patterns); }; if (failed(stablehlo::refineEntryFunction(*context, func, @@ -92,7 +92,7 @@ index cabd6a9f..2e64b4ed 100644 patterns->add(context); patterns->add(context); patterns->add(context); -- populateSdyShapeRefinementPatterns(patterns, context); +- populateSdyShapeRefinementPatterns(context, patterns); } } // namespace stablehlo_ext diff --git a/mlir/patches/mhlo-rename-sort.patch b/mlir/patches/mhlo-rename-sort.patch new file mode 100644 index 0000000000..c356cc35e3 --- /dev/null +++ b/mlir/patches/mhlo-rename-sort.patch @@ -0,0 +1,15 @@ +diff --git a/utils/cycle_detector.cc b/utils/cycle_detector.cc +index e3901ae88..890f39654 100644 +--- a/utils/cycle_detector.cc ++++ b/utils/cycle_detector.cc +@@ -199,8 +199,8 @@ static void backwardDfs(GraphCycles::Rep* r, int32_t n, int32_t lowerBound) { + // Recomputes rank assignments to make them compatible with the edges (producer + // has smaller rank than its consumer) + static void reorder(GraphCycles::Rep* r) { +- sort(r->nodes, &r->deltab); +- sort(r->nodes, &r->deltaf); ++ mlir::sort(r->nodes, 
&r->deltab); ++ mlir::sort(r->nodes, &r->deltaf); + + // Adds contents of delta lists to list (backwards deltas first). + r->list.clear(); diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 6394246da0..e3a8b603eb 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -3,6 +3,8 @@ include(AddMLIRPython) # TODO: Add an upstream cmake param for this vs having a global here. add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=mlir_quantum.") +# Ignore nanobind warnings +add_compile_options(-w) ################################################################################ # Declare Dialect Sources diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir index effc229a64..67f186943d 100644 --- a/mlir/test/Catalyst/BufferizationTest.mlir +++ b/mlir/test/Catalyst/BufferizationTest.mlir @@ -23,7 +23,7 @@ func.func @dbprint_val(%arg0: tensor) { - // CHECK: %0 = bufferization.to_memref %arg0 + // CHECK: %0 = bufferization.to_buffer %arg0 // CHECK: "catalyst.print"(%0) : (memref) -> () "catalyst.print"(%arg0) : (tensor) -> () @@ -34,7 +34,7 @@ func.func @dbprint_val(%arg0: tensor) { func.func @dbprint_memref(%arg0: tensor) { - // CHECK: %0 = bufferization.to_memref %arg0 + // CHECK: %0 = bufferization.to_buffer %arg0 // CHECK: "catalyst.print"(%0) <{print_descriptor}> : (memref) -> () "catalyst.print"(%arg0) {print_descriptor} : (tensor) -> () @@ -54,7 +54,7 @@ func.func @dbprint_str() { // ----- func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { - // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0 + // CHECK: [[sourceAlloc:%.+]] = bufferization.to_buffer %arg0 // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : // CHECK-SAME: (memref<3x3xf64>, memref<3x3xf64>) -> () @@ -72,7 +72,7 @@ func.func @custom_call_copy(%arg0: 
tensor<2x3xf64>) -> tensor<2x2xf64> { // COM: e.g. coming from tensor subviews // COM: a copy needs to be performed because the kernels only allow for contiguous arrays as inputs // - // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0 + // CHECK: [[sourceAlloc:%.+]] = bufferization.to_buffer %arg0 // CHECK: [[subview:%.+]] = memref.subview [[sourceAlloc]] // CHECK-SAME: memref<2x3xf64> to memref<2x2xf64, strided<[3, 1]>> // CHECK: [[copyAlloc:%.+]] = memref.alloc() : memref<2x2xf64> @@ -106,7 +106,7 @@ module @test1 { // CHECK-LABEL: @foo( // CHECK-SAME: [[arg0:%.+]]: tensor) func.func private @foo(%arg0: tensor) -> tensor { - // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : tensor to memref + // CHECK-DAG: [[memref0:%.+]] = bufferization.to_buffer [[arg0]] : tensor to memref // CHECK-DAG: [[resAlloc:%.+]] = memref.alloc() {{.*}}: memref // CHECK: catalyst.callback_call @callback_1([[memref0]], [[resAlloc]]) : (memref, memref) -> () %1 = catalyst.callback_call @callback_1(%arg0) : (tensor) -> (tensor) diff --git a/mlir/test/Catalyst/ConversionTest.mlir b/mlir/test/Catalyst/ConversionTest.mlir index 20ce7c6bc7..9a9fc17102 100644 --- a/mlir/test/Catalyst/ConversionTest.mlir +++ b/mlir/test/Catalyst/ConversionTest.mlir @@ -146,13 +146,13 @@ module @test1 { // CHECK-SAME: [[arg0:%.+]]: tensor // CHECK-SAME:) func.func private @foo(%arg0: tensor) -> tensor { - // CHECK: [[memref0:%.+]] = bufferization.to_memref [[arg0]] + // CHECK: [[memref0:%.+]] = bufferization.to_buffer [[arg0]] // CHECK: [[ptr0:%.+]] = llvm.alloca {{.*}} // CHECK: [[ptr1:%.+]] = llvm.alloca {{.*}} // CHECK: [[struct0:%.+]] = builtin.unrealized_conversion_cast [[memref0]] // CHECK: [[tensor1:%.+]] = bufferization.alloc_tensor() - // CHECK: [[memref1:%.+]] = bufferization.to_memref [[tensor1]] + // CHECK: [[memref1:%.+]] = bufferization.to_buffer [[tensor1]] // CHECK: [[struct1:%.+]] = builtin.unrealized_conversion_cast [[memref1]] // CHECK: llvm.store [[struct0]], 
[[ptr1]] @@ -160,9 +160,9 @@ module @test1 { // call @callback_1([[ptr0]], [[ptr1]]) - %0 = bufferization.to_memref %arg0 : tensor to memref + %0 = bufferization.to_buffer %arg0 : tensor to memref %1 = bufferization.alloc_tensor() {memory_space = 0 : i64} : tensor - %2 = bufferization.to_memref %1 : tensor to memref + %2 = bufferization.to_buffer %1 : tensor to memref catalyst.callback_call @callback_1(%0, %2) : (memref, memref) -> () diff --git a/mlir/test/Gradient/BufferizationTest.mlir b/mlir/test/Gradient/BufferizationTest.mlir index 4a8f9a246e..8e84995888 100644 --- a/mlir/test/Gradient/BufferizationTest.mlir +++ b/mlir/test/Gradient/BufferizationTest.mlir @@ -63,7 +63,7 @@ func.func private @circuit(%arg0: tensor<2xf64>) // CHECK-LABEL: @adjoint_with_tensor_arg func.func @adjoint_with_tensor_arg(%arg0: tensor<2xf64>, %arg1: index) { - // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64> + // CHECK: [[argBuffer:%.+]] = bufferization.to_buffer %arg0 : tensor<2xf64> to memref<2xf64> // CHECK: [[alloc:%.+]] = memref.alloc(%arg1) : memref // CHECK: gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc]] : memref) : (memref<2xf64>) -> () %grad = gradient.adjoint @circuit(%arg0) size(%arg1) : (tensor<2xf64>) -> tensor @@ -77,7 +77,7 @@ func.func private @circuit(%arg0: tensor<2xf64>) // CHECK-LABEL: @adjoint_with_multiple_results func.func @adjoint_with_multiple_results(%arg0: tensor<2xf64>, %arg1: index) { - // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64> + // CHECK: [[argBuffer:%.+]] = bufferization.to_buffer %arg0 : tensor<2xf64> to memref<2xf64> // CHECK: [[alloc0:%.+]] = memref.alloc(%arg1) : memref // CHECK: [[alloc1:%.+]] = memref.alloc(%arg1) : memref // CHECK: gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc0]], [[alloc1]] @@ -93,7 +93,7 @@ func.func private @circuit(%arg0: f64) // CHECK-LABEL: @backprop_scalar_in func.func 
@backprop_scalar_in(%arg0: f64, %arg1: tensor) { - // CHECK: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor to memref + // CHECK: [[cotangentSource:%.+]] = bufferization.to_buffer %arg1 : tensor to memref // CHECK: [[dim1:%.+]] = memref.dim [[cotangentSource]] // CHECK: [[cotangentRes:%.+]] = memref.alloc([[dim1]]) {alignment = 64 : i64} : memref // CHECK: memref.copy [[cotangentSource]], [[cotangentRes]] @@ -115,8 +115,8 @@ func.func private @circuit(%arg0: tensor) // CHECK-LABEL: @backprop_tensor_in func.func @backprop_tensor_in(%arg0: tensor, %arg1: tensor) { - // CHECK-DAG: [[argSource:%.+]] = bufferization.to_memref %arg0 : tensor to memref - // CHECK-DAG: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor to memref + // CHECK-DAG: [[argSource:%.+]] = bufferization.to_buffer %arg0 : tensor to memref + // CHECK-DAG: [[cotangentSource:%.+]] = bufferization.to_buffer %arg1 : tensor to memref // CHECK: [[dim2:%.+]] = memref.dim [[cotangentSource]] // CHECK: [[cotangentRes:%.+]] = memref.alloc([[dim2]]) {alignment = 64 : i64} : memref // CHECK: memref.copy [[cotangentSource]], [[cotangentRes]] @@ -141,8 +141,8 @@ func.func private @circuit(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>) // CHECK-LABEL: @backprop_multiple_tensors_in func.func @backprop_multiple_tensors_in(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>, %arg2: tensor) { - // CHECK-DAG: [[argSource0:%.+]] = bufferization.to_memref %arg0 : tensor<10xf64> to memref<10xf64> - // CHECK-DAG: [[argSource1:%.+]] = bufferization.to_memref %arg1 : tensor<2xf64> to memref<2xf64> + // CHECK-DAG: [[argSource0:%.+]] = bufferization.to_buffer %arg0 : tensor<10xf64> to memref<10xf64> + // CHECK-DAG: [[argSource1:%.+]] = bufferization.to_buffer %arg1 : tensor<2xf64> to memref<2xf64> // CHECK: memref.alloc // CHECK: memref.copy // CHECK: [[argShadow1:%.+]] = memref.alloc() : memref<10xf64> @@ -171,8 +171,8 @@ gradient.forward @callback_fn_fwd.fwd(%arg0: tensor<2xf64>) -> (tensor, ten // 
CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64> // CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor, tensor<2xf64>) - // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor to memref - // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64> + // CHECK: [[res0:%.+]] = bufferization.to_buffer [[callOut]]#0 : tensor to memref + // CHECK: [[res1:%.+]] = bufferization.to_buffer [[callOut]]#1 : tensor<2xf64> to memref<2xf64> // CHECK: gradient.return {empty = false} [[res0]], [[res1]] : memref, memref<2xf64> %0:2 = func.call @callback_fn_fwd(%arg0) : (tensor<2xf64>) -> (tensor, tensor<2xf64>) @@ -192,7 +192,7 @@ gradient.reverse @callback_fn_vjp.rev(%arg0: tensor, %arg1: tensor<2xf64>) // CHECK: [[in1:%.+]] = bufferization.to_tensor %arg1 : memref<2xf64> // CHECK: [[in0:%.+]] = bufferization.to_tensor %arg0 : memref // CHECK: [[callOut:%.+]] = func.call @callback_fn_vjp([[in1]], [[in0]]) : (tensor<2xf64>, tensor) -> tensor<2xf64> - // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : tensor<2xf64> to memref<2xf64> + // CHECK: [[res:%.+]] = bufferization.to_buffer [[callOut]] : tensor<2xf64> to memref<2xf64> // CHECK: gradient.return {empty = true} [[res]] : memref<2xf64> %0 = func.call @callback_fn_vjp(%arg1, %arg0) : (tensor<2xf64>, tensor) -> tensor<2xf64> diff --git a/mlir/test/Gradient/FiniteDifferenceTest.mlir b/mlir/test/Gradient/FiniteDifferenceTest.mlir index 13af9e7956..37d70471d6 100644 --- a/mlir/test/Gradient/FiniteDifferenceTest.mlir +++ b/mlir/test/Gradient/FiniteDifferenceTest.mlir @@ -161,7 +161,7 @@ func.func private @funcMultiArg(%arg0: tensor<7xf64>, %arg1: f64) -> tensor<2xf6 // CHECK: [[BASE:%.+]] = call @funcMultiArg(%arg0, %arg1) // CHECK: [[DIFF:%.+]] = tensor.generate // CHECK-NEXT: ^bb0(%arg2: index, %arg3: index): - // CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0 + // CHECK: [[MEMREF:%.+]] = 
bufferization.to_buffer %arg0 // CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]] // CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]] // CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg3] @@ -188,7 +188,7 @@ func.func private @funcMultiArg(%arg0: tensor<7xf64>, %arg1: f64) -> tensor<2xf6 // CHECK: [[BASE:%.+]] = call @funcMultiArg(%arg0, %arg1) // CHECK: [[DIFF:%.+]] = tensor.generate // CHECK-NEXT: ^bb0(%arg2: index, %arg3: index): - // CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0 + // CHECK: [[MEMREF:%.+]] = bufferization.to_buffer %arg0 // CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]] // CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]] // CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg3] @@ -227,7 +227,7 @@ func.func private @funcMultiRes(%arg0: tensor<7xf64>) -> (f64, tensor<2xf64>) at // CHECK: [[BASE:%.+]]:2 = call @funcMultiRes(%arg0) // CHECK: [[DIFF:%.+]] = tensor.generate // CHECK-NEXT: ^bb0(%arg1: index): - // CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0 + // CHECK: [[MEMREF:%.+]] = bufferization.to_buffer %arg0 // CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]] // CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]] // CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg1] @@ -239,7 +239,7 @@ func.func private @funcMultiRes(%arg0: tensor<7xf64>) -> (f64, tensor<2xf64>) at // CHECK: [[R0:%.+]] = arith.divf [[DIFF]] // CHECK: [[DIFF:%.+]] = tensor.generate // CHECK-NEXT: ^bb0(%arg1: index, %arg2: index): - // CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0 + // CHECK: [[MEMREF:%.+]] = bufferization.to_buffer %arg0 // CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]] // CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]] // CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg2] @@ -279,7 +279,7 @@ func.func private @funcDynamicTensor(%arg0: tensor) -> tensor<2x?xf64> // CHECK: [[DIFF:%.+]] = tensor.generate [[DDIM0]], [[DDIM1]] // CHECK-NEXT: ^bb0([[i0:%.+]]: index, [[i1:%.+]]: index, 
[[i2:%.+]]: index, [[i3:%.+]]: index): - // CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0 + // CHECK: [[MEMREF:%.+]] = bufferization.to_buffer %arg0 // CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]] // CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]] // CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][[[i2]], [[i3]]] diff --git a/mlir/test/Gradient/PostProcessingTest.mlir b/mlir/test/Gradient/PostProcessingTest.mlir index 2403372410..9ae25800f1 100644 --- a/mlir/test/Gradient/PostProcessingTest.mlir +++ b/mlir/test/Gradient/PostProcessingTest.mlir @@ -25,15 +25,15 @@ gradient.forward @callback_fn_fwd.fwd(%arg0: memref<2xf64>) -> (memref, mem // CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64> // CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor, tensor<2xf64>) - // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor to memref - // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64> + // CHECK: [[res0:%.+]] = bufferization.to_buffer [[callOut]]#0 : tensor to memref + // CHECK: [[res1:%.+]] = bufferization.to_buffer [[callOut]]#1 : tensor<2xf64> to memref<2xf64> // CHECK: memref.copy [[res0]], %arg2 : memref to memref // CHECK: gradient.return {empty = false} [[res1]] : memref<2xf64> %0 = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64> %1:2 = func.call @callback_fn_fwd(%0) : (tensor<2xf64>) -> (tensor, tensor<2xf64>) - %2 = bufferization.to_memref %1#0 : tensor to memref - %3 = bufferization.to_memref %1#1 : tensor<2xf64> to memref<2xf64> + %2 = bufferization.to_buffer %1#0 : tensor to memref + %3 = bufferization.to_buffer %1#1 : tensor<2xf64> to memref<2xf64> gradient.return {empty = false} %2, %3 : memref, memref<2xf64> } @@ -50,13 +50,13 @@ gradient.reverse @callback_fn_vjp.rev(%arg0: memref, %arg1: memref<2xf64>) // CHECK: [[tape:%.+]] = bufferization.to_tensor %arg4 : memref<2xf64> to 
tensor<2xf64> // CHECK: [[cotan:%.+]] = bufferization.to_tensor %arg3 : memref to tensor // CHECK: [[callOut:%.+]] = func.call @callback_fn_vjp([[tape]], [[cotan]]) : (tensor<2xf64>, tensor) -> tensor<2xf64> - // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : tensor<2xf64> to memref<2xf64> + // CHECK: [[res:%.+]] = bufferization.to_buffer [[callOut]] : tensor<2xf64> to memref<2xf64> // CHECK: memref.copy [[res]], %arg1 : memref<2xf64> to memref<2xf64> // CHECK: gradient.return {empty = true} %0 = bufferization.to_tensor %arg1 : memref<2xf64> to tensor<2xf64> %1 = bufferization.to_tensor %arg0 : memref to tensor %2 = func.call @callback_fn_vjp(%0, %1) : (tensor<2xf64>, tensor) -> tensor<2xf64> - %3 = bufferization.to_memref %2 : tensor<2xf64> to memref<2xf64> + %3 = bufferization.to_buffer %2 : tensor<2xf64> to memref<2xf64> gradient.return {empty = true} %3 : memref<2xf64> } diff --git a/mlir/test/Quantum/BufferizationTest.mlir b/mlir/test/Quantum/BufferizationTest.mlir index 7b94860c44..ab619464a6 100644 --- a/mlir/test/Quantum/BufferizationTest.mlir +++ b/mlir/test/Quantum/BufferizationTest.mlir @@ -15,7 +15,7 @@ // RUN: quantum-opt --one-shot-bufferize --split-input-file %s | FileCheck %s func.func @qubit_unitary(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) { - // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : tensor<2x2xcomplex> to memref<2x2xcomplex> + // CHECK: [[memref:%.+]] = bufferization.to_buffer %arg1 : tensor<2x2xcomplex> to memref<2x2xcomplex> // CHECK: {{%.+}} = quantum.unitary([[memref]] : memref<2x2xcomplex>) %arg0 : !quantum.bit %out_qubits = quantum.unitary(%matrix : tensor<2x2xcomplex>) %q0 : !quantum.bit @@ -25,7 +25,7 @@ func.func @qubit_unitary(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) { // ----- func.func @hermitian(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) { - // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : tensor<2x2xcomplex> to memref<2x2xcomplex> + // CHECK: [[memref:%.+]] = 
bufferization.to_buffer %arg1 : tensor<2x2xcomplex> to memref<2x2xcomplex> // CHECK: {{%.+}} = quantum.hermitian([[memref]] : memref<2x2xcomplex>) %arg0 : !quantum.obs %obs = quantum.hermitian(%matrix : tensor<2x2xcomplex>) %q0 : !quantum.obs @@ -35,7 +35,7 @@ func.func @hermitian(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex>) { // ----- func.func @hamiltonian(%obs: !quantum.obs, %coeffs: tensor<1xf64>){ - // CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : tensor<1xf64> to memref<1xf64> + // CHECK: [[memref:%.+]] = bufferization.to_buffer %arg1 : tensor<1xf64> to memref<1xf64> // CHECK: {{%.+}} = quantum.hamiltonian([[memref]] : memref<1xf64>) %arg0 : !quantum.obs %hamil = quantum.hamiltonian(%coeffs: tensor<1xf64>) %obs : !quantum.obs diff --git a/mlir/test/cli/DumpPipeline.mlir b/mlir/test/cli/DumpPipeline.mlir index 0995df2040..ce031c18d3 100644 --- a/mlir/test/cli/DumpPipeline.mlir +++ b/mlir/test/cli/DumpPipeline.mlir @@ -25,12 +25,19 @@ func.func @foo() { // CHECK: builtin.module // CHECK-CUSTOM: Pass Manager with 2 passes -// CHECK-CUSTOM: builtin.module(split-multiple-tapes,apply-transform-sequence) +// CHECK-CUSTOM: builtin.module( +// CHECK-CUSTOM: split-multiple-tapes, +// CHECK-CUSTOM: apply-transform-sequence +// CHECK-CUSTOM: ) // CHECK-CUSTOM: Pass Manager with 1 passes -// CHECK-CUSTOM: builtin.module(inline-nested-module{stop-after-step=0}) +// CHECK-CUSTOM: builtin.module( +// CHECK-CUSTOM: inline-nested-module{stop-after-step=0} +// CHECK-CUSTOM: ) // CHECK-ONE-PASS: Pass Manager with 1 passes -// CHECK-ONE-PASS: builtin.module(cse) +// CHECK-ONE-PASS: builtin.module( +// CHECK-ONE-PASS: cse +// CHECK-ONE-PASS: ) // CHECK-FAIL: --catalyst-pipeline option can't be used with individual pass options or -pass-pipeline. // CHECK-FAIL: Compilation failed