diff --git a/.dep-versions b/.dep-versions
index d755906e6d..849b989020 100644
--- a/.dep-versions
+++ b/.dep-versions
@@ -2,9 +2,9 @@
 # To update JAX version alongside compatible dependency tags, run the following script:
 # python3 .github/workflows/set_dep_versions.py {JAX_version}
 jax=0.6.2
-mhlo=617a9361d186199480c080c9e8c474a5e30c22d1
-llvm=179d30f8c3fddd3c85056fd2b8e877a4a8513158
-enzyme=v0.0.180
+mhlo=1dd2e71331014ae0373f6bf900ce6be393357190
+llvm=f8cb7987c64dcffb72414a40560055cb717dbf74
+enzyme=v0.0.186
 
 # Always remove custom PL/LQ versions before release.
 
diff --git a/.github/workflows/build-wheel-linux-arm64.yaml b/.github/workflows/build-wheel-linux-arm64.yaml
index df53cd5880..ef4b3d04ce 100644
--- a/.github/workflows/build-wheel-linux-arm64.yaml
+++ b/.github/workflows/build-wheel-linux-arm64.yaml
@@ -108,6 +108,12 @@ jobs:
         ref: ${{ needs.constants.outputs.llvm_version }}
         path: ${{ github.workspace }}/mlir/llvm-project
 
+    - name: Patch LLVM Source
+      if: steps.cache-llvm-source.outputs.cache-hit != 'true'
+      run: |
+        cd $GITHUB_WORKSPACE/mlir/llvm-project
+        git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch
+
     - name: Clone MHLO Submodule
       if: steps.cache-mhlo-source.outputs.cache-hit != 'true'
       uses: actions/checkout@v4
@@ -122,6 +128,7 @@ jobs:
         cd $GITHUB_WORKSPACE/mlir/mlir-hlo
         git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
         git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
+        git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch
 
     - name: Clone Enzyme Submodule
       if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
@@ -134,9 +141,8 @@ jobs:
     - name: Patch Enzyme Source
       if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
       run: |
-        export TARGET_FILE=$GITHUB_WORKSPACE/mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
-        export PATCH_FILE=$GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
-        patch -p1 $TARGET_FILE $PATCH_FILE
+        cd $GITHUB_WORKSPACE/mlir/Enzyme
+        git apply $GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
 
     # Cache external project builds
     - name: Restore LLVM Build
diff --git a/.github/workflows/build-wheel-linux-x86_64.yaml b/.github/workflows/build-wheel-linux-x86_64.yaml
index 35be761c99..bde5f9cbbc 100644
--- a/.github/workflows/build-wheel-linux-x86_64.yaml
+++ b/.github/workflows/build-wheel-linux-x86_64.yaml
@@ -127,6 +127,12 @@ jobs:
         ref: ${{ needs.constants.outputs.llvm_version }}
         path: ${{ github.workspace }}/mlir/llvm-project
 
+    - name: Patch LLVM Source
+      if: steps.cache-llvm-source.outputs.cache-hit != 'true'
+      run: |
+        cd $GITHUB_WORKSPACE/mlir/llvm-project
+        git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch
+
     - name: Clone MHLO Submodule
       if: steps.cache-mhlo-source.outputs.cache-hit != 'true'
       uses: actions/checkout@v4
@@ -141,6 +147,7 @@ jobs:
         cd $GITHUB_WORKSPACE/mlir/mlir-hlo
         git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
         git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
+        git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch
 
     - name: Clone Enzyme Submodule
       if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
@@ -153,9 +160,8 @@ jobs:
     - name: Patch Enzyme Source
       if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
       run: |
-        export TARGET_FILE=$GITHUB_WORKSPACE/mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
-        export PATCH_FILE=$GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
-        patch -p1 $TARGET_FILE $PATCH_FILE
+        cd $GITHUB_WORKSPACE/mlir/Enzyme
+        git apply $GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
 
     # Cache external project builds
     - name: Restore LLVM Build
diff --git a/.github/workflows/build-wheel-macos-arm64.yaml b/.github/workflows/build-wheel-macos-arm64.yaml
index d6d2e9d215..fdc101e3a8 100644
--- a/.github/workflows/build-wheel-macos-arm64.yaml
+++ b/.github/workflows/build-wheel-macos-arm64.yaml
@@ -113,6 +113,12 @@ jobs:
         ref: ${{ needs.constants.outputs.llvm_version }}
         path: ${{ github.workspace }}/mlir/llvm-project
 
+    - name: Patch LLVM Source
+      if: steps.cache-llvm-source.outputs.cache-hit != 'true'
+      run: |
+        cd $GITHUB_WORKSPACE/mlir/llvm-project
+        git apply $GITHUB_WORKSPACE/mlir/patches/llvm-bufferization-segfault.patch
+
     - name: Clone MHLO Submodule
       if: steps.cache-mhlo-source.outputs.cache-hit != 'true'
       uses: actions/checkout@v4
@@ -127,6 +133,7 @@ jobs:
         cd $GITHUB_WORKSPACE/mlir/mlir-hlo
         git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
         git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
+        git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-rename-sort.patch
 
     - name: Clone Enzyme Submodule
       if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
@@ -139,9 +146,8 @@ jobs:
     - name: Patch Enzyme Source
       if: steps.cache-enzyme-source.outputs.cache-hit != 'true'
       run: |
-        export TARGET_FILE=$GITHUB_WORKSPACE/mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
-        export PATCH_FILE=$GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
-        patch -p1 $TARGET_FILE $PATCH_FILE
+        cd $GITHUB_WORKSPACE/mlir/Enzyme
+        git apply $GITHUB_WORKSPACE/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
 
     # Cache external project builds
     - name: Restore LLVM Build
diff --git a/Makefile b/Makefile
index e3309167f0..68d81ae519 100644
--- a/Makefile
+++ b/Makefile
@@ -282,6 +282,9 @@ clean-plugin:
 clean-llvm:
 	$(MAKE) -C mlir clean-llvm
 
+reset-llvm:
+	$(MAKE) -C mlir reset-llvm
+
 clean-mhlo:
 	$(MAKE) -C mlir clean-mhlo
 
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 32688e2dab..2c03b31a80 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -12,7 +12,17 @@
 * The JAX version used by Catalyst is updated to 0.6.2.
   [(#1897)](https://github.com/PennyLaneAI/catalyst/pull/1897)
 
+* The versions of LLVM, mlir-hlo, and Enzyme used by Catalyst have been updated.
+  [(#1916)](https://github.com/PennyLaneAI/catalyst/pull/1916)
+
+  The LLVM version has been updated to
+  [commit f8cb798](https://github.com/llvm/llvm-project/tree/f8cb7987c64dcffb72414a40560055cb717dbf74).
+  The mlir-hlo version has been updated to
+  [commit 1dd2e71](https://github.com/tensorflow/mlir-hlo/tree/1dd2e71331014ae0373f6bf900ce6be393357190).
+  The Enzyme version has been updated to
+  [v0.0.186](https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.186).
+
 
 Deprecations 👋
 
 Bug fixes 🐛
diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py
index d2bfbc74bb..7dc1382593 100644
--- a/frontend/catalyst/jax_extras/lowering.py
+++ b/frontend/catalyst/jax_extras/lowering.py
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import logging
+import textwrap
 
 import jax
 from jax._src.dispatch import jaxpr_replicas
@@ -38,6 +39,7 @@
 
 import catalyst
 from catalyst.logging import debug_logger
+from catalyst.utils.exceptions import CompileError
 from catalyst.utils.patching import Patcher
 
 # pylint: disable=protected-access
@@ -165,3 +167,57 @@ def custom_lower_jaxpr_to_module(
                 worklist += [*op.body.operations]
 
     return ctx.module, ctx.context
+
+
+def get_mlir_attribute_from_pyval(value):
+    """
+    Given a value of any type, construct an mlir attribute of corresponding type.
+
+    We set up the context and location outside because recursive calls to this function
+    will segfault if multiple `Context()`s are instantiated.
+    """
+
+    attr = None
+    match value:
+        case bool():
+            attr = ir.BoolAttr.get(value)
+
+        case int():
+            if -9223372036854775808 <= value < 0:  # -(2**63)
+                attr = ir.IntegerAttr.get(ir.IntegerType.get_signed(64), value)
+            elif 0 <= value < 18446744073709551616:  # = 2**64
+                attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value)
+            else:
+                raise CompileError(
+                    textwrap.dedent(
+                        """
+                    Large integer attributes currently not supported in MLIR,
+                    see https://github.com/llvm/llvm-project/issues/128072
+                    """
+                    )
+                )
+
+        case float():
+            attr = ir.FloatAttr.get(ir.F64Type.get(), value)
+
+        case str():
+            attr = ir.StringAttr.get(value)
+
+        case list() | tuple():
+            element_attrs = [get_mlir_attribute_from_pyval(elem) for elem in value]
+            attr = ir.ArrayAttr.get(element_attrs)
+
+        case dict():
+            named_attrs = {}
+            for k, v in value.items():
+                if not isinstance(k, str):
+                    raise CompileError(
+                        f"Dictionary keys for MLIR DictionaryAttr must be strings, got: {type(k)}"
+                    )
+                named_attrs[k] = get_mlir_attribute_from_pyval(v)
+            attr = ir.DictAttr.get(named_attrs)
+
+        case _:
+            raise CompileError(f"Cannot convert Python type {type(value)} to an MLIR attribute.")
+
+    return attr
diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py
index cd2d27eb53..effa498b45 100644
--- a/frontend/catalyst/jax_primitives_utils.py
+++ b/frontend/catalyst/jax_primitives_utils.py
@@ -280,7 +280,11 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
             for _pass in pipeline:
                 options = _pass.get_options()
                 apply_registered_pass_op = ApplyRegisteredPassOp(
-                    result=transform_mod_type, target=target, pass_name=_pass.name, options=options
+                    result=transform_mod_type,
+                    target=target,
+                    pass_name=_pass.name,
+                    options=options,
+                    dynamic_options={},
                 )
                 target = apply_registered_pass_op.result
             transform_yield_op = YieldOp(operands_=[])  # pylint: disable=unused-variable
diff --git a/frontend/catalyst/passes/pass_api.py b/frontend/catalyst/passes/pass_api.py
index 423f8c5269..b0ed5db380 100644
--- a/frontend/catalyst/passes/pass_api.py
+++ b/frontend/catalyst/passes/pass_api.py
@@ -19,6 +19,7 @@
 
 import pennylane as qml
 
+from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval
 from catalyst.tracing.contexts import EvaluationContext
 
 PipelineDict: TypeAlias = dict[str, dict[str, str]]
@@ -286,23 +287,18 @@ def __init__(self, name: str, *options: list[str], **valued_options: dict[str, s
 
     def get_options(self):
         """
-        Stringify options according to what mlir-opt expects.
-
-          ApplyRegisteredPassOp expects options to be a single StringAttr
-          which follows the same format as the one used with mlir-opt.
-
-        https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop
-
-          Options passed to a pass are specified via the syntax {option1=value1 option2=value2 ...},
-          i.e., use space-separated key=value pairs for each option.
+        Build a dictionary mapping option names to MLIR attributes.
+        ApplyRegisteredPassOp expects options to be a dictionary from strings to attributes.
+        See https://github.com/llvm/llvm-project/pull/143159
+        """
+        options_dict = {}
+        for option in self.options:
+            options_dict[str(option)] = get_mlir_attribute_from_pyval(True)
 
-        https://mlir.llvm.org/docs/Tutorials/MlirOpt/#running-a-pass-with-options
+        for option, value in self.valued_options.items():
+            options_dict[str(option)] = get_mlir_attribute_from_pyval(value)
 
-        Experimentally we found that single-options also work without values.
-        """
-        retval = " ".join(f"{str(option)}" for option in self.options)
-        retval2 = " ".join(f"{str(key)}={str(value)}" for key, value in self.valued_options.items())
-        return " ".join([retval, retval2]).strip()
+        return options_dict
 
     def __repr__(self):
         return (
diff --git a/frontend/test/lit/test_mlir_plugin.py b/frontend/test/lit/test_mlir_plugin.py
index 0be8d5cbb1..09d519c7ef 100644
--- a/frontend/test/lit/test_mlir_plugin.py
+++ b/frontend/test/lit/test_mlir_plugin.py
@@ -106,7 +106,7 @@ def test_pass_options():
     """Is the option in the generated MLIR?"""
 
     @qjit(target="mlir")
-    # CHECK: options = "an-option maxValue=1"
+    # CHECK: options = {"an-option" = true, "maxValue" = 1 : i64}
     @catalyst.passes.apply_pass("some-pass", "an-option", maxValue=1)
     @qml.qnode(qml.device("null.qubit", wires=1))
     def example():
diff --git a/frontend/test/pytest/test_jax_integration.py b/frontend/test/pytest/test_jax_integration.py
index 99b78f29b6..37e2d0762a 100644
--- a/frontend/test/pytest/test_jax_integration.py
+++ b/frontend/test/pytest/test_jax_integration.py
@@ -14,15 +14,19 @@
 
 """Test QJIT compatibility with JAX transformations such as jax.jit and jax.grad."""
 
+import textwrap
 from functools import partial
 
 import jax
 import jax.numpy as jnp
 import pennylane as qml
 import pytest
+from jax.interpreters.mlir import ir
 
 from catalyst import for_loop, measure, qjit
+from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval
 from catalyst.jit import JAX_QJIT
+from catalyst.utils.exceptions import CompileError
 
 
 class TestJAXJIT:
@@ -490,5 +494,122 @@ def ansatz(i, x):
         jax.grad(circuit, argnums=0)(params, 3)
 
 
+ctx = ir.Context()
+loc = ir.Location.unknown(ctx)
+
+
+class TestJAXMLIRAttributeGetter:
+    """
+    Test catalyst.jax_extras.lowering.get_mlir_attribute_from_pyval
+    """
+
+    def test_bool_attr(self):
+        """
+        Test bool attribute.
+        """
+        with ctx, loc:
+            attr = get_mlir_attribute_from_pyval(True)
+            assert isinstance(attr, ir.BoolAttr)
+            assert attr.value == True
+
+    def test_str_attr(self):
+        """
+        Test string attribute.
+        """
+        with ctx, loc:
+            attr = get_mlir_attribute_from_pyval("hello catalyst!")
+            assert isinstance(attr, ir.StringAttr)
+            assert attr.value == "hello catalyst!"
+
+    @pytest.mark.parametrize("number", (37, -37))
+    def test_int_attr(self, number):
+        """
+        Test integer attribute.
+        """
+        with ctx, loc:
+            attr = get_mlir_attribute_from_pyval(number)
+            assert isinstance(attr, ir.IntegerAttr)
+            assert attr.value == number
+
+    @pytest.mark.parametrize("number", (3.7, -3.7))
+    def test_float_attr(self, number):
+        """
+        Test float attribute.
+        """
+        with ctx, loc:
+            attr = get_mlir_attribute_from_pyval(number)
+            assert isinstance(attr, ir.FloatAttr)
+            assert attr.value == number
+
+    @pytest.mark.parametrize("array", ([1, 2, 3], (4, 5, 6)))
+    def test_array_attr(self, array):
+        """
+        Test array attribute.
+        """
+        with ctx, loc:
+            attr = get_mlir_attribute_from_pyval(array)
+            assert isinstance(attr, ir.ArrayAttr)
+            assert len(attr) == len(array)
+
+            for attr_val, py_val in zip(attr, array):
+                assert isinstance(attr_val, ir.IntegerAttr)
+                assert attr_val.value == py_val
+
+    def test_dict_attr(self):
+        """
+        Test dictionary attribute.
+        """
+        with ctx, loc:
+            attr = get_mlir_attribute_from_pyval(
+                {"device": "lightning.qubit", "wire_capacity": 100}
+            )
+            assert isinstance(attr, ir.DictAttr)
+
+            assert isinstance(attr["device"], ir.StringAttr)
+            assert attr["device"].value == "lightning.qubit"
+
+            assert isinstance(attr["wire_capacity"], ir.IntegerAttr)
+            assert attr["wire_capacity"].value == 100
+
+    def test_dict_attr_with_bad_keys(self):
+        """
+        Test dictionary attribute with non-string keys.
+        """
+        with pytest.raises(
+            CompileError, match="Dictionary keys for MLIR DictionaryAttr must be strings"
+        ):
+            with ctx, loc:
+                _ = get_mlir_attribute_from_pyval({37: 42})
+
+    def test_bad_type(self):
+        """
+        Test an error is correctly raised on a python type not convertible to mlir attribute.
+        """
+
+        # pylint: disable=missing-class-docstring
+        class Foo:
+            pass
+
+        with pytest.raises(CompileError, match="Cannot convert Python type"):
+            with ctx, loc:
+                _ = get_mlir_attribute_from_pyval(Foo())
+
+    def test_int_attr_overflow(self):
+        """
+        Test int attribute with overflow correctly raises error.
+        """
+        with pytest.raises(
+            CompileError,
+            match=textwrap.dedent(
+                """
+            Large integer attributes currently not supported in MLIR,
+            see https://github.com/llvm/llvm-project/issues/128072
+            """
+            ),
+        ):
+            with ctx, loc:
+                _ = get_mlir_attribute_from_pyval(2**100)
+
+
 if __name__ == "__main__":
     pytest.main(["-x", __file__])
diff --git a/frontend/test/pytest/test_mlir_plugin_interface.py b/frontend/test/pytest/test_mlir_plugin_interface.py
index 30f7c040b9..afde4c920f 100644
--- a/frontend/test/pytest/test_mlir_plugin_interface.py
+++ b/frontend/test/pytest/test_mlir_plugin_interface.py
@@ -19,6 +19,7 @@
 
 import pennylane as qml
 import pytest
+from jax.interpreters.mlir import ir
 
 import catalyst
 from catalyst import qjit
@@ -73,24 +74,26 @@ def test_get_options():
     """
     Test get_options from Pass
 
-      ApplyRegisteredPassOp expects options to be a single StringAttr
-      which follows the same format as the one used with mlir-opt.
-
-    https://mlir.llvm.org/docs/Dialects/Transform/#transformapply_registered_pass-transformapplyregisteredpassop
-
-      Options passed to a pass are specified via the syntax {option1=value1 option2=value2 ...},
-      i.e., use space-separated key=value pairs for each option.
-
-    https://mlir.llvm.org/docs/Tutorials/MlirOpt/#running-a-pass-with-options
-
-    However, experimentally we found that single-options also work without values.
+    ApplyRegisteredPassOp expects options to be a dictionary from strings to attributes.
+    See https://github.com/llvm/llvm-project/pull/143159
     """
-    assert catalyst.passes.Pass("example-pass", "single-option").get_options() == "single-option"
-    assert (
-        catalyst.passes.Pass("example-pass", "an-option", "bn-option").get_options()
-        == "an-option bn-option"
-    )
-    assert catalyst.passes.Pass("example-pass", option=True).get_options() == "option=True"
+    with ir.Context(), ir.Location.unknown():
+        options = catalyst.passes.Pass("example-pass", "single-option").get_options()
+        assert isinstance(options, dict)
+        assert isinstance(options["single-option"], ir.BoolAttr)
+        assert options["single-option"].value == True
+
+        options = catalyst.passes.Pass("example-pass", "an-option", "bn-option").get_options()
+        assert isinstance(options, dict)
+        assert isinstance(options["an-option"], ir.BoolAttr)
+        assert options["an-option"].value == True
+        assert isinstance(options["bn-option"], ir.BoolAttr)
+        assert options["bn-option"].value == True
+
+        options = catalyst.passes.Pass("example-pass", option=True).get_options()
+        assert isinstance(options, dict)
+        assert isinstance(options["option"], ir.BoolAttr)
+        assert options["option"].value == True
 
 
 @pytest.mark.skip(reason="xdsl not installed in ci cd yet")
diff --git a/mlir/Enzyme b/mlir/Enzyme
index db0181320d..8c1a596158 160000
--- a/mlir/Enzyme
+++ b/mlir/Enzyme
@@ -1 +1 @@
-Subproject commit db0181320d6e425ee963bd496ed0d8dbb615be18
+Subproject commit 8c1a596158f6194f10e8ffd56a1660a61c54337e
diff --git a/mlir/Makefile b/mlir/Makefile
index 5b3dc53f53..0a41527370 100644
--- a/mlir/Makefile
+++ b/mlir/Makefile
@@ -57,6 +57,14 @@ all: llvm mhlo enzyme dialects plugin
 .PHONY: llvm
 llvm:
 	@echo "build LLVM and MLIR enabling Python bindings"
+
+	# Patch mlir one shot bufferization segfault
+	# Remove patch after bug is resolved upstream
+	# https://github.com/llvm/llvm-project/issues/150441
+	@if cd llvm-project; git apply --check $(MK_DIR)/patches/llvm-bufferization-segfault.patch; then \
+		git apply $(MK_DIR)/patches/llvm-bufferization-segfault.patch; \
+	fi
+
 	cmake -G Ninja -S llvm-project/llvm -B $(LLVM_BUILD_DIR) \
 		-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
 		-DLLVM_BUILD_EXAMPLES=OFF \
@@ -97,6 +105,11 @@ mhlo:
 	@if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; then \
 		git apply $(MK_DIR)/patches/mhlo-add-back-necessary-passes.patch; \
 	fi
+
+	# Patch a MHLO bug with std::sort
+	@if cd mlir-hlo; git apply --check $(MK_DIR)/patches/mhlo-rename-sort.patch; then \
+		git apply $(MK_DIR)/patches/mhlo-rename-sort.patch; \
+	fi
 	cmake -G Ninja -S mlir-hlo -B $(MHLO_BUILD_DIR) \
 		-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
 		-DLLVM_ENABLE_ASSERTIONS=ON \
@@ -121,8 +134,8 @@ enzyme: PATCH_FILE := $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch
 enzyme:
 	@echo "build enzyme"
 	# Patch enzyme's dependency on nvidia fabs llvm intrinsics
-	@if patch --dry-run -p1 -N $(TARGET_FILE) $(PATCH_FILE) > /dev/null 2>&1; then \
-		patch -p1 $(TARGET_FILE) $(PATCH_FILE); \
+	@if cd Enzyme; git apply --check $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch; then \
+		git apply $(MK_DIR)/patches/enzyme-nvvm-fabs-intrinsics.patch; \
 	fi
 	cmake -G Ninja -S Enzyme/enzyme -B $(ENZYME_BUILD_DIR) \
 		-DENZYME_STATIC_LIB=ON \
@@ -204,6 +217,11 @@ clean-dialects:
 clean-llvm:
 	@echo "clean llvm/mlir build files"
 	rm -rf $(LLVM_BUILD_DIR)
+	cd llvm-project; git clean -fd; git checkout .
+
+reset-llvm:
+	@echo "reset llvm git state to the commit tracked in .dep-versions without deleting llvm builds"
+	cd llvm-project; git clean -fd; git checkout .
 
 clean-mhlo:
 	@echo "clean HLO dialect build files"
@@ -213,6 +231,7 @@ clean-mhlo:
 clean-enzyme:
 	@echo "clean enzyme build files"
 	rm -rf $(ENZYME_BUILD_DIR)
+	cd Enzyme; git clean -fd; git checkout .
 
 clean-plugin:
 	@echo "clean plugin"
diff --git a/mlir/include/Catalyst/Transforms/AsyncUtils.h b/mlir/include/Catalyst/Transforms/AsyncUtils.h
index 7be6815235..98df1f3972 100644
--- a/mlir/include/Catalyst/Transforms/AsyncUtils.h
+++ b/mlir/include/Catalyst/Transforms/AsyncUtils.h
@@ -63,13 +63,13 @@ bool hasAbortInBlock(Block *block);
 bool hasPutsInBlock(Block *block);
 
 // Helper function for creating function declarations
-LLVM::LLVMFuncOp lookupOrCreatePersonality(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateAbort(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetValueError(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetTokenError(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateUnrecoverableError(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateAwaitTokenName(ModuleOp);
-LLVM::LLVMFuncOp lookupOrCreateAwaitValueName(ModuleOp);
-LLVM::LLVMFuncOp lookupOrCreateDropRef(ModuleOp);
+LLVM::LLVMFuncOp lookupOrCreatePersonality(OpBuilder &b, ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateAbort(OpBuilder &b, ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetValueError(OpBuilder &b, ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMlirAsyncRuntimeSetTokenError(OpBuilder &b, ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateUnrecoverableError(OpBuilder &b, ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateAwaitTokenName(OpBuilder &b, ModuleOp);
+LLVM::LLVMFuncOp lookupOrCreateAwaitValueName(OpBuilder &b, ModuleOp);
+LLVM::LLVMFuncOp lookupOrCreateDropRef(OpBuilder &b, ModuleOp);
 
 }; // namespace AsyncUtils
diff --git a/mlir/include/Gradient/Utils/DestinationPassingStyle.h b/mlir/include/Gradient/Utils/DestinationPassingStyle.h
index 920cfc45f3..7a6ae53415 100644
--- a/mlir/include/Gradient/Utils/DestinationPassingStyle.h
+++ b/mlir/include/Gradient/Utils/DestinationPassingStyle.h
@@ -17,5 +17,6 @@
 namespace catalyst {
 /// Convert every MemRef-typed return value in `callee` to writing to a new argument in
 /// destination-passing style.
-void convertToDestinationPassingStyle(mlir::func::FuncOp callee, mlir::OpBuilder &builder);
+llvm::LogicalResult convertToDestinationPassingStyle(mlir::func::FuncOp callee,
+                                                     mlir::OpBuilder &builder);
 } // namespace catalyst
diff --git a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
index 76cee871d4..ad3ec61a72 100644
--- a/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
+++ b/mlir/lib/Catalyst/Transforms/AsyncUtils.cpp
@@ -128,81 +128,84 @@ LLVM::LLVMFuncOp AsyncUtils::getCaller(LLVM::CallOp callOp)
     return callOp->getParentOfType();
 }
 
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreatePersonality(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreatePersonality(OpBuilder &b, ModuleOp moduleOp)
 {
     MLIRContext *ctx = moduleOp.getContext();
     auto i32Ty = IntegerType::get(ctx, 32);
     bool isVarArg = true;
-    return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::personalityName, {}, i32Ty,
-                                        isVarArg)
+    return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::personalityName, {},
+                                        i32Ty, isVarArg)
         .value();
 }
 
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAbort(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAbort(OpBuilder &b, ModuleOp moduleOp)
 {
     MLIRContext *ctx = moduleOp.getContext();
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
-    return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::abortName, {}, voidTy)
+    return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::abortName, {}, voidTy)
         .value();
 }
 
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitTokenName(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitTokenName(OpBuilder &b, ModuleOp moduleOp)
 {
     MLIRContext *ctx = moduleOp.getContext();
     Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
     return mlir::LLVM::lookupOrCreateFn(
-               moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitTokenName, {ptrTy}, voidTy)
+               b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitTokenName, {ptrTy}, voidTy)
         .value();
 }
 
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitValueName(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateAwaitValueName(OpBuilder &b, ModuleOp moduleOp)
 {
     MLIRContext *ctx = moduleOp.getContext();
     Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
     return mlir::LLVM::lookupOrCreateFn(
-               moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitValueName, {ptrTy}, voidTy)
+               b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeAwaitValueName, {ptrTy}, voidTy)
         .value();
 }
 
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateDropRef(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateDropRef(OpBuilder &b, ModuleOp moduleOp)
 {
     MLIRContext *ctx = moduleOp.getContext();
     Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
     Type llvmInt64Type = IntegerType::get(moduleOp.getContext(), 64);
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
-    return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeDropRefName,
+    return mlir::LLVM::lookupOrCreateFn(b, moduleOp,
+                                        AsyncUtilsConstants::mlirAsyncRuntimeDropRefName,
                                         {ptrTy, llvmInt64Type}, voidTy)
         .value();
 }
 
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(OpBuilder &b,
+                                                                         ModuleOp moduleOp)
 {
     MLIRContext *ctx = moduleOp.getContext();
     Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
     return mlir::LLVM::lookupOrCreateFn(
-               moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetValueErrorName, {ptrTy}, voidTy)
+               b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetValueErrorName, {ptrTy}, voidTy)
         .value();
 }
 
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(OpBuilder &b,
+                                                                         ModuleOp moduleOp)
 {
     MLIRContext *ctx = moduleOp.getContext();
     Type ptrTy = LLVM::LLVMPointerType::get(moduleOp.getContext());
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
     return mlir::LLVM::lookupOrCreateFn(
-               moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetTokenErrorName, {ptrTy}, voidTy)
+               b, moduleOp, AsyncUtilsConstants::mlirAsyncRuntimeSetTokenErrorName, {ptrTy}, voidTy)
         .value();
 }
 
-LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateUnrecoverableError(ModuleOp moduleOp)
+LLVM::LLVMFuncOp AsyncUtils::lookupOrCreateUnrecoverableError(OpBuilder &b, ModuleOp moduleOp)
 {
     MLIRContext *ctx = moduleOp.getContext();
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
-    return mlir::LLVM::lookupOrCreateFn(moduleOp, AsyncUtilsConstants::unrecoverableErrorName, {},
-                                        voidTy)
+    return mlir::LLVM::lookupOrCreateFn(b, moduleOp, AsyncUtilsConstants::unrecoverableErrorName,
+                                        {}, voidTy)
         .value();
 }
 
diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
index c5b93d1602..ebd559a188 100644
--- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -56,11 +56,12 @@ struct PrintOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto printOp = cast(op);
         if (printOp.getVal()) {
-            FailureOr source = getBuffer(rewriter, printOp.getVal(), options);
+            FailureOr source = getBuffer(rewriter, printOp.getVal(), options, state);
             if (failed(source)) {
                 return failure();
             }
@@ -116,7 +117,8 @@ struct CustomCallOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto customCallOp = cast(op);
 
@@ -124,7 +126,7 @@ struct CustomCallOpInterface
         SmallVector bufferArgs;
         ValueRange operands = customCallOp.getOperands();
         for (Value operand : operands) {
-            FailureOr opBuffer = getBuffer(rewriter, operand, options);
+            FailureOr opBuffer = getBuffer(rewriter, operand, options, state);
             if (failed(opBuffer)) {
                 return failure();
             }
@@ -165,11 +167,11 @@ struct CustomCallOpInterface
             }
             auto options = bufferization::BufferizationOptions();
             FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue(
-                rewriter, op->getLoc(), result, options, false);
+                rewriter, op->getLoc(), result, options, state, false);
             MemRefType memrefType =
                 MemRefType::get(tensorType.getShape(), tensorType.getElementType());
             auto newBuffer =
-                rewriter.create(op->getLoc(), memrefType, *tensorAlloc);
+                rewriter.create(op->getLoc(), memrefType, *tensorAlloc);
             bufferArgs.push_back(newBuffer);
         }
 
@@ -207,7 +209,8 @@ struct CallbackOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto callbackOp = cast(op);
 
@@ -279,7 +282,8 @@ struct CallbackCallOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto callOp = cast(op);
 
@@ -292,7 +296,7 @@ struct CallbackCallOpInterface
         SmallVector newInputs;
         auto operands = callOp.getOperands();
         for (Value operand : operands) {
-            FailureOr opBuffer = getBuffer(rewriter, operand, options);
+            FailureOr opBuffer = getBuffer(rewriter, operand, options, state);
             if (failed(opBuffer)) {
                 return failure();
             }
@@ -303,8 +307,8 @@ struct CallbackCallOpInterface
         auto loc = callOp->getLoc();
         SmallVector outmemrefs;
         for (auto result : results) {
-            FailureOr tensorAlloc =
-                bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false);
+            FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue(
+                rewriter, loc, result, options, state, false);
             if (failed(tensorAlloc)) {
                 return failure();
             }
@@ -314,8 +318,8 @@ struct CallbackCallOpInterface
             auto shape = tensorTy.getShape();
             auto elementTy = tensorTy.getElementType();
             auto memrefType = MemRefType::get(shape, elementTy);
-            auto toMemrefOp = rewriter.create(loc, memrefType, tensor);
-            auto memref = toMemrefOp.getResult();
+            auto toBufferOp = rewriter.create(loc, memrefType, tensor);
+            auto memref = toBufferOp.getResult();
             outmemrefs.push_back(memref);
             newInputs.push_back(memref);
         }
diff --git a/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp b/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp
index d516605b4c..160da0d118 100644
--- a/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp
+++ b/mlir/lib/Catalyst/Transforms/DetectQNodes.cpp
@@ -133,7 +133,7 @@ LogicalResult AddExceptionHandlingTransform::matchAndRewrite(LLVM::CallOp callOp
     auto moduleOp = callOp->getParentOfType();
     // Here, we are adding a reference to the personality declaration.
     // From the documentation: https://llvm.org/docs/ExceptionHandling.html#exception-tables
-    auto personality = AsyncUtils::lookupOrCreatePersonality(moduleOp);
+    auto personality = AsyncUtils::lookupOrCreatePersonality(rewriter, moduleOp);
 
     // We annotate the body of the function containing the callop to have a reference
     // to the personality.
@@ -294,7 +294,7 @@ RemoveAbortAndPutsInsertCallTransform::matchAndRewrite(LLVM::CallOp callOp,
     // Here, we are declaring an external function which is available in the Catalyst runtime.
     //     llvm.func @__catalyst__host__rt__unrecoverable_error()
     auto moduleOp = callOp->getParentOfType();
-    auto unrecoverableError = AsyncUtils::lookupOrCreateUnrecoverableError(moduleOp);
+    auto unrecoverableError = AsyncUtils::lookupOrCreateUnrecoverableError(rewriter, moduleOp);
 
     auto callee = maybeCallee.value();
     rewriter.modifyOpInPlace(callee, [&] { callee.setLinkage(LLVM::Linkage::Internal); });
@@ -516,8 +516,8 @@ LogicalResult LivenessAnalysisDropRef::matchAndRewrite(LLVM::CallOp sink,
     //     llvm.func @mlirAsyncRuntimeAwaitValue(!llvm.ptr)
     //     llvm.func @mlirAsyncRuntimeAwaitToken(!llvm.ptr)
     //     llvm.func @mlirAsyncRuntimeDropRef(!llvm.ptr, i64)
-    auto awaitFnDecl = AsyncUtils::lookupOrCreateAwaitTokenName(moduleOp);
-    auto dropRefFnDecl = AsyncUtils::lookupOrCreateDropRef(moduleOp);
+    auto awaitFnDecl = AsyncUtils::lookupOrCreateAwaitTokenName(rewriter, moduleOp);
+    auto dropRefFnDecl = AsyncUtils::lookupOrCreateDropRef(rewriter, moduleOp);
 
     Type llvmInt64Type = IntegerType::get(sink->getContext(), 64);
     auto one = rewriter.getIntegerAttr(llvmInt64Type, 1);
@@ -871,9 +871,9 @@ void insertErrorCalls(std::vector tokens, std::vector values, Bloc
     auto moduleOp = landingPad->getParentOfType();
 
     LLVM::LLVMFuncOp setTokenError =
-        AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(moduleOp);
+        AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetTokenError(rewriter, moduleOp);
     LLVM::LLVMFuncOp setValueError =
-        AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(moduleOp);
+        AsyncUtils::lookupOrCreateMlirAsyncRuntimeSetValueError(rewriter, moduleOp);
     for (auto token : tokens) {
         insertCallToMlirAsyncRuntimeErrorFunction(token, setTokenError, failBlock, rewriter);
     }
@@ -918,11 +918,8 @@ struct AddExceptionHandlingPass : impl::AddExceptionHandlingPassBase(context);
 
         GreedyRewriteConfig config;
-        config.strictMode = GreedyRewriteStrictness::ExistingOps;
-        config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
-        // TODO: Update to the following lines the next time we update llvm
-        // config.setStrictness(GreedyRewriteStrictness::ExistingOps);
-        // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
+        config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+        config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
 
         if (failed(applyPatternsGreedily(getOperation(), std::move(patterns1), config))) {
             signalPassFailure();
diff --git a/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp b/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp
index a3a6af1139..773621f132 100644
--- a/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp
+++ b/mlir/lib/Catalyst/Transforms/GEPInboundsPatterns.cpp
@@ -29,11 +29,12 @@ struct GEPOpRewritePattern : public mlir::OpRewritePattern {
                                         mlir::PatternRewriter &rewriter) const override
     {
         auto defOp = op.getBase().getDefiningOp();
-        if (op.getInbounds() || (defOp && isa(defOp))) {
+        if (op.getNoWrapFlags() == LLVM::GEPNoWrapFlags::inbounds ||
+            (defOp && isa(defOp))) {
             return failure();
         }
         rewriter.startOpModification(op);
-        op.setInbounds(true);
+        op.setNoWrapFlags(LLVM::GEPNoWrapFlags::inbounds);
         rewriter.finalizeOpModification(op);
         return success();
     }
diff --git a/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp b/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp
index f540222b77..f115efb171 100644
--- a/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp
+++ b/mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp
@@ -380,11 +380,8 @@ struct AnnotateWithFullyQualifiedNamePass
     {
         MLIRContext *context = &getContext();
         GreedyRewriteConfig config;
-        config.strictMode = GreedyRewriteStrictness::ExistingOps;
-        config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
-        // TODO: Update to the following lines the next time we update llvm
-        // config.setStrictness(GreedyRewriteStrictness::ExistingOps);
-        // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
+        config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+        config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
 
         RewritePatternSet annotate(context);
         auto root = getOperation();
@@ -409,11 +406,8 @@ struct InlineNestedSymbolTablePass : PassWrapper(op);
                  }) != backwardSlice.end();
@@ -132,8 +133,8 @@ struct MemrefLoadTBAARewritePattern : public ConvertOpToLLVMPattern(
             loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0, false,
             loadOp.getNontemporal());
@@ -170,8 +171,8 @@ struct MemrefStoreTBAARewritePattern : public ConvertOpToLLVMPattern(storeOp, adaptor.getValue(), dataPtr,
                                                              0, false, storeOp.getNontemporal());
 
diff --git a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
index d427d49386..5d884162fc 100644
--- a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
+++ b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
@@ -51,7 +51,8 @@ Value getGlobalString(Location loc, OpBuilder &rewriter, StringRef key, StringRe
     }
     return rewriter.create(loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
                                         type, rewriter.create(loc, glb),
-                                        ArrayRef{0, 0}, true);
+                                        ArrayRef{0, 0},
+                                        LLVM::GEPNoWrapFlags::inbounds);
 }
 
 enum NumericType : int8_t {
@@ -309,7 +310,8 @@ Value EncodeDataMemRef(Location loc, PatternRewriter &rewriter, MemRefType memre
     MemRefDescriptor desc = MemRefDescriptor(memrefLlvm);
     Value c0 = rewriter.create(loc, rewriter.getI64IntegerAttr(0));
     Value data = rewriter.create(loc, ptr, memrefType.getElementType(),
-                                              desc.alignedPtr(rewriter, loc), c0, true);
+                                              desc.alignedPtr(rewriter, loc), c0,
+                                              LLVM::GEPNoWrapFlags::inbounds);
     memref = rewriter.create(loc, memref, data, 1);
 
     // Dtype
@@ -335,7 +337,8 @@ struct CustomCallOpPattern : public OpConversionPattern {
         rewriter.setInsertionPointToStart(mod.getBody());
 
         LLVM::LLVMFuncOp customCallFnOp =
-            mlir::LLVM::lookupOrCreateFn(mod, op.getCallTargetName(), {/*args=*/ptr, /*rets=*/ptr},
+            mlir::LLVM::lookupOrCreateFn(rewriter, mod, op.getCallTargetName(),
+                                         {/*args=*/ptr, /*rets=*/ptr},
                                          /*ret_type=*/voidType)
                 .value();
         customCallFnOp.setPrivate();
@@ -467,7 +470,7 @@ struct DefineCallbackOpPattern : public OpConversionPattern {
         ModuleOp mod = op->getParentOfType();
         auto typeConverter = getTypeConverter();
         LLVM::LLVMFuncOp customCallFnOp =
-            mlir::LLVM::lookupOrCreateFn(mod, "__catalyst_inactive_callback",
+            mlir::LLVM::lookupOrCreateFn(rewriter, mod, "__catalyst_inactive_callback",
                                          {/*args=*/i64, i64, i64},
                                          /*ret_type=*/voidType, isVarArg)
                 .value();
diff --git a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
index d1980a658e..54cbc85aac 100644
--- a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -177,7 +177,8 @@ struct AdjointOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto adjointOp = cast(op);
         Location loc = adjointOp.getLoc();
@@ -207,7 +208,7 @@ struct AdjointOpInterface
         ValueRange operands = adjointOp.getArgs();
         for (Value operand : operands) {
             if (isa(operand.getType())) {
-                FailureOr opBuffer = getBuffer(rewriter, operand, options);
+                FailureOr opBuffer = getBuffer(rewriter, operand, options, state);
                 if (failed(opBuffer)) {
                     return failure();
                 }
@@ -276,7 +277,8 @@ struct BackpropOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto backpropOp = cast(op);
         Location loc = backpropOp.getLoc();
@@ -293,7 +295,7 @@ struct BackpropOpInterface
         ValueRange operands = backpropOp.getArgs();
         for (Value operand : operands) {
             if (isa(operand.getType())) {
-                FailureOr opBuffer = getBuffer(rewriter, operand, options);
+                FailureOr opBuffer = getBuffer(rewriter, operand, options, state);
                 if (failed(opBuffer)) {
                     return failure();
                 }
@@ -333,7 +335,7 @@ struct BackpropOpInterface
         ValueRange cotangents = backpropOp.getCotangents();
         SmallVector bufferCotangents;
         for (Value operand : cotangents) {
-            FailureOr opBuffer = getBuffer(rewriter, operand, options);
+            FailureOr opBuffer = getBuffer(rewriter, operand, options, state);
             if (failed(opBuffer)) {
                 return failure();
             }
@@ -402,6 +404,7 @@ struct ForwardOpInterface
 
     FailureOr getBufferType(Operation *op, Value value,
                                             const bufferization::BufferizationOptions &options,
+                                            const bufferization::BufferizationState &state,
                                             SmallVector &invocationStack) const
     {
         // The getBufferType() method is called on either BlockArguments or OpResults.
@@ -426,11 +429,12 @@ struct ForwardOpInterface
             return getBufferizedFunctionArgType(forwardOp, bbArg.getArgNumber(), options);
         }
 
-        return bufferization::detail::defaultGetBufferType(value, options, invocationStack);
+        return bufferization::detail::defaultGetBufferType(value, options, state, invocationStack);
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto forwardOp = cast(op);
         FunctionType funcType = forwardOp.getFunctionType();
@@ -451,7 +455,7 @@ struct ForwardOpInterface
 
         // 1. Bufferize every block.
         for (Block &block : forwardOp.getBody()) {
-            if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options))) {
+            if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state))) {
                 return failure();
             }
         }
@@ -471,11 +475,12 @@ struct ForwardOpInterface
 
             // Note: If `inferFunctionResultLayout = true`, cast are later folded
             // away.
-            BaseMemRefType resultType = options.unknownTypeConverterFn(
-                returnVal, *options.defaultMemorySpaceFn(tensorType), options);
-            Value toMemrefOp =
-                rewriter.create(loc, resultType, returnVal);
-            returnValues.push_back(toMemrefOp);
+            BaseMemRefType resultType =
+                options.unknownTypeConverterFn(cast(returnVal.getType()),
+                                               *options.defaultMemorySpaceFn(tensorType), options);
+            Value toBufferOp =
+                rewriter.create(loc, resultType, returnVal);
+            returnValues.push_back(toBufferOp);
         }
 
         // 3. Rewrite the terminator.
@@ -523,6 +528,7 @@ struct ReverseOpInterface
 
     FailureOr getBufferType(Operation *op, Value value,
                                             const bufferization::BufferizationOptions &options,
+                                            const bufferization::BufferizationState &state,
                                             SmallVector &invocationStack) const
     {
         // See comment on the getBufferType() method on forward op.
@@ -534,11 +540,12 @@ struct ReverseOpInterface
             return getBufferizedFunctionArgType(reverseOp, bbArg.getArgNumber(), options);
         }
 
-        return bufferization::detail::defaultGetBufferType(value, options, invocationStack);
+        return bufferization::detail::defaultGetBufferType(value, options, state, invocationStack);
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto reverseOp = cast(op);
         FunctionType funcType = reverseOp.getFunctionType();
@@ -559,7 +566,7 @@ struct ReverseOpInterface
 
         // 1. Bufferize every block.
         for (Block &block : reverseOp.getBody())
-            if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options)))
+            if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state)))
                 return failure();
 
         // 2. For each result, keep track of which inplace argument it reuses.
@@ -577,11 +584,12 @@ struct ReverseOpInterface
 
             // Note: If `inferFunctionResultLayout = true`, cast are later folded
             // away.
-            BaseMemRefType resultType = options.unknownTypeConverterFn(
-                returnVal, *options.defaultMemorySpaceFn(tensorType), options);
-            Value toMemrefOp =
-                rewriter.create(loc, resultType, returnVal);
-            returnValues.push_back(toMemrefOp);
+            BaseMemRefType resultType =
+                options.unknownTypeConverterFn(cast(returnVal.getType()),
+                                               *options.defaultMemorySpaceFn(tensorType), options);
+            Value toBufferOp =
+                rewriter.create(loc, resultType, returnVal);
+            returnValues.push_back(toBufferOp);
         }
 
         // 3. Rewrite the terminator.
diff --git a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
index e3a472945e..4e2c4f188e 100644
--- a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
+++ b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
@@ -49,8 +49,8 @@ using namespace catalyst::gradient;
 
 namespace catalyst {
 namespace gradient {
-void wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter,
-                        RewriterBase &rewriter, Location loc, bool volatileArgs = false)
+LogicalResult wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter,
+                                 RewriterBase &rewriter, Location loc, bool volatileArgs = false)
 {
     MLIRContext *ctx = rewriter.getContext();
     auto ptrType = LLVM::LLVMPointerType::get(ctx);
@@ -59,7 +59,9 @@ void wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter,
     for (const auto [idx, argType] : llvm::enumerate(func.getArgumentTypes())) {
         if (auto memrefType = dyn_cast(argType)) {
             BlockArgument memrefArg = func.getArgument(idx);
-            func.insertArgument(idx, ptrType, DictionaryAttr::get(ctx), loc);
+            if (failed(func.insertArgument(idx, ptrType, DictionaryAttr::get(ctx), loc))) {
+                return failure();
+            }
             Value wrappedMemref = func.getArgument(idx);
             Type structType = typeConverter->convertType(memrefType);
 
@@ -78,9 +80,12 @@ void wrapMemRefArgsFunc(func::FuncOp func, const TypeConverter *typeConverter,
                 rewriter.create(loc, argType, replacedMemref)
                     .getResult(0);
             memrefArg.replaceAllUsesWith(replacedMemref);
-            func.eraseArgument(memrefArg.getArgNumber());
+            if (failed(func.eraseArgument(memrefArg.getArgNumber()))) {
+                return failure();
+            }
         }
     }
+    return success();
 }
 
 void wrapMemRefArgsCallsites(func::FuncOp func, const TypeConverter *typeConverter,
@@ -171,16 +176,19 @@ LLVM::GlobalOp insertEnzymeCustomGradient(OpBuilder &builder, ModuleOp moduleOp,
 /// functions where MemRefs are passed via wrapped pointers (!llvm.ptr)
 /// rather than having their fields unpacked. This function automatically transforms MemRef
 /// arguments of a function to wrapped pointers.
-void wrapMemRefArgs(func::FuncOp func, const TypeConverter *typeConverter, RewriterBase &rewriter,
-                    Location loc, bool volatileArgs = false)
+LogicalResult wrapMemRefArgs(func::FuncOp func, const TypeConverter *typeConverter,
+                             RewriterBase &rewriter, Location loc, bool volatileArgs = false)
 {
     if (llvm::none_of(func.getArgumentTypes(),
                       [](Type argType) { return isa(argType); })) {
         // The memref arguments are already wrapped
-        return;
+        return success();
+    }
+    if (failed(wrapMemRefArgsFunc(func, typeConverter, rewriter, loc, volatileArgs))) {
+        return failure();
     }
-    wrapMemRefArgsFunc(func, typeConverter, rewriter, loc, volatileArgs);
     wrapMemRefArgsCallsites(func, typeConverter, rewriter, loc, volatileArgs);
+    return success();
 }
 } // namespace gradient
 } // namespace catalyst
@@ -290,7 +298,9 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern {
             SymbolTable::lookupNearestSymbolFrom(op, op.getCalleeAttr());
         assert(callee && "Expected a valid callee of type func.func");
 
-        catalyst::convertToDestinationPassingStyle(callee, rewriter);
+        if (failed(catalyst::convertToDestinationPassingStyle(callee, rewriter))) {
+            return failure();
+        }
         SymbolTableCollection symbolTable;
         catalyst::traverseCallGraph(callee, &symbolTable, [&](func::FuncOp func) {
             // Register custom gradients of quantum functions
@@ -304,9 +314,11 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern {
                 if (!func->hasAttr("unwrapped_type")) {
                     func->setAttr("unwrapped_type", TypeAttr::get(func.getFunctionType()));
                 }
-                catalyst::convertToDestinationPassingStyle(func, rewriter);
-
-                wrapMemRefArgs(func, getTypeConverter(), rewriter, loc, /*volatileArgs=*/true);
+                LogicalResult dpsr = catalyst::convertToDestinationPassingStyle(func, rewriter);
+                assert(dpsr.succeeded() && "failed to rewrite backpropOp to destination style");
+                LogicalResult wmar = wrapMemRefArgs(func, getTypeConverter(), rewriter, loc,
+                                                    /*volatileArgs=*/true);
+                assert(wmar.succeeded() && "failed to wrap backpropOp's memref args");
 
                 func::FuncOp augFwd = genAugmentedForward(func, rewriter);
                 func::FuncOp customQGrad =
@@ -318,10 +330,10 @@ struct BackpropOpPattern : public ConvertOpToLLVMPattern {
 
         LowerToLLVMOptions options = getTypeConverter()->getOptions();
         if (options.useGenericFunctions) {
-            LLVM::LLVMFuncOp allocFn =
-                LLVM::lookupOrCreateGenericAllocFn(moduleOp, getTypeConverter()->getIndexType())
-                    .value();
-            LLVM::LLVMFuncOp freeFn = LLVM::lookupOrCreateGenericFreeFn(moduleOp).value();
+            LLVM::LLVMFuncOp allocFn = LLVM::lookupOrCreateGenericAllocFn(
+                                           rewriter, moduleOp, getTypeConverter()->getIndexType())
+                                           .value();
+            LLVM::LLVMFuncOp freeFn = LLVM::lookupOrCreateGenericFreeFn(rewriter, moduleOp).value();
 
             // Register the previous functions as llvm globals (for Enzyme)
             // With the following piece of metadata, shadow memory is allocated with
@@ -862,7 +874,10 @@ struct ForwardOpPattern : public ConvertOpToLLVMPattern {
         func->setAttr("passthrough", ArrayAttr::get(ctx, passthrough));
 
         rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
-        catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, op.getLoc());
+        if (failed(catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter,
+                                                          op.getLoc()))) {
+            return failure();
+        }
         rewriter.eraseOp(op);
         return success();
     }
@@ -884,16 +899,12 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern {
         auto params = op.getArguments();
 
         for (size_t i = 0; i < argc * 2; i++) {
-            bool isDup = (i % 2) != 0;
-            Value val = params[i];
-            isDup ? differentials.push_back(val) : inputs.push_back(val);
+            fillValueAndShadowWithDedup(i, params, differentials, inputs);
         }
 
         auto upperLimit = (argc * 2) + (resc * 2);
         for (size_t i = argc * 2; i < upperLimit; i++) {
-            bool isDup = (i % 2) != 0;
-            Value val = params[i];
-            isDup ? cotangents.push_back(val) : outputs.push_back(val);
+            fillValueAndShadowWithDedup(i, params, cotangents, outputs);
         }
 
         auto tapeCount = op.getTape();
@@ -903,16 +914,7 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern {
         }
 
         SmallVector newFuncInputTys;
-
-        for (auto [in, diff] : llvm::zip(inputs, differentials)) {
-            newFuncInputTys.push_back(in.getType());
-            newFuncInputTys.push_back(diff.getType());
-        }
-
-        for (auto [out, cotan] : llvm::zip(outputs, cotangents)) {
-            newFuncInputTys.push_back(out.getType());
-            newFuncInputTys.push_back(cotan.getType());
-        }
+        getNewFuncInputTys(inputs, outputs, differentials, cotangents, newFuncInputTys);
 
         SmallVector tapeStructs;
         auto converter = getTypeConverter();
@@ -986,11 +988,40 @@ struct ReverseOpPattern : public ConvertOpToLLVMPattern {
         Block &firstBlock = func.getRegion().getBlocks().front();
         Block &lastBlock = func.getRegion().getBlocks().back();
         rewriter.mergeBlocks(&lastBlock, &firstBlock);
-        catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter, op.getLoc());
+        if (failed(catalyst::gradient::wrapMemRefArgsFunc(func, typeConverter, rewriter,
+                                                          op.getLoc()))) {
+            return failure();
+        }
 
         rewriter.eraseOp(op);
         return success();
     }
+
+  private:
+    static void getNewFuncInputTys(const SmallVector &inputs,
+                                   const SmallVector &outputs,
+                                   const SmallVector &differentials,
+                                   const SmallVector &cotangents,
+                                   SmallVector &newFuncInputTys)
+    {
+        for (auto [in, diff] : llvm::zip(inputs, differentials)) {
+            newFuncInputTys.push_back(in.getType());
+            newFuncInputTys.push_back(diff.getType());
+        }
+
+        for (auto [out, cotan] : llvm::zip(outputs, cotangents)) {
+            newFuncInputTys.push_back(out.getType());
+            newFuncInputTys.push_back(cotan.getType());
+        }
+    }
+
+    static void fillValueAndShadowWithDedup(size_t i, ValueRange params, SmallVector &values,
+                                            SmallVector &shadows)
+    {
+        bool isDup = (i % 2) != 0;
+        Value val = params[i];
+        isDup ? shadows.push_back(val) : values.push_back(val);
+    }
 };
 
 struct ReturnOpPattern : public ConvertOpToLLVMPattern {
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp
index 464ab29089..5c24252a1e 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/ClassicalJacobian.cpp
@@ -148,7 +148,8 @@ func::FuncOp genSplitPreprocessed(PatternRewriter &rewriter, Location loc, func:
         PatternRewriter::InsertionGuard insertGuard(rewriter);
         rewriter.setInsertionPointToStart(&splitFn.getBody().front());
         Value paramsBuffer = rewriter.create(loc, paramsBufferType, paramCount);
-        Value paramsTensor = rewriter.create(loc, paramsBuffer, true);
+        Value paramsTensor = rewriter.create(
+            loc, memref::getTensorTypeFromMemRefType(paramsBuffer.getType()), paramsBuffer, true);
 
         qnodeQuantumArgs.push_back(paramsTensor);
         MemRefType paramsProcessedType = MemRefType::get({}, rewriter.getIndexType());
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp
index 66fbcee25d..853ecd1e82 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 #include "Gradient/Utils/DifferentialQNode.h"
@@ -163,13 +164,14 @@ void FiniteDiffLowering::computeFiniteDiff(PatternRewriter &rewriter, Location l
                     auto tensorTy = diffArg.getType();
                     auto memrefTy = bufferization::getMemRefTypeWithStaticIdentityLayout(
                         cast(tensorTy));
-                    auto toMemrefOp =
-                        rewriter.create(loc, memrefTy, diffArg);
+                    auto toBufferOp =
+                        rewriter.create(loc, memrefTy, diffArg);
 
-                    auto cloneOp = rewriter.create(loc, toMemrefOp);
+                    auto cloneOp = rewriter.create(loc, toBufferOp);
 
-                    auto toTensorOp =
-                        rewriter.create(loc, cloneOp, true);
+                    auto toTensorOp = rewriter.create(
+                        loc, memref::getTensorTypeFromMemRefType(cloneOp.getOutput().getType()),
+                        cloneOp, true);
 
                     auto diffArgCopy = toTensorOp.getResult();
 
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp
index 5db5c4a149..75ad0a61a0 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/HybridGradient.cpp
@@ -86,9 +86,9 @@ void initializeCotangents(TypeRange primalResultTypes, unsigned activeResult, Va
                             : activeResultType);
 
     Value zero = builder.create(
-        loc, APFloat(elementType.getFloatSemantics(), 0), elementType);
-    Value one = builder.create(
-        loc, APFloat(elementType.getFloatSemantics(), 1), elementType);
+        loc, elementType, APFloat(elementType.getFloatSemantics(), 0));
+    Value one = builder.create(loc, elementType,
+                                                       APFloat(elementType.getFloatSemantics(), 1));
 
     Value zeroTensor;
     if (auto activeResultTensor = dyn_cast(activeResultType)) {
@@ -397,7 +397,7 @@ static func::FuncOp genFullGradFunction(PatternRewriter &rewriter, Location loc,
                 }
                 else {
                     jacobians.push_back(rewriter.create(
-                        loc, APFloat(0.0), cast(jacobianType)));
+                        loc, cast(jacobianType), APFloat(0.0)));
                 }
             }
 
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp b/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp
index ce91f8ec80..7227a3d35a 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/PS_FunctionShifting.cpp
@@ -35,7 +35,7 @@ static Value genSelectiveShift(PatternRewriter &rewriter, Location loc, Value pa
     }
 
     // Make sure all active iteration variables match the selectors.
-    Value shiftCondition = rewriter.create(loc, true, 1);
+    Value shiftCondition = rewriter.create(loc, 1, true);
     for (auto &[iteration, selector] : selectors) {
         Value iterationMatch =
             rewriter.create(loc, arith::CmpIPredicate::eq, iteration, selector);
diff --git a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp
index ff6d172908..d5cb0c117e 100644
--- a/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp
+++ b/mlir/lib/Gradient/Transforms/GradMethods/PS_QuantumGradient.cpp
@@ -59,7 +59,8 @@ static std::vector computePartialDerivative(PatternRewriter &rewriter, Lo
 {
     constexpr double shift = llvm::numbers::pi / 2;
     ShapedType shiftVectorType = RankedTensorType::get({numShifts}, rewriter.getF64Type());
-    Value selectorVector = rewriter.create(loc, selectorBuffer, true);
+    Value selectorVector = rewriter.create(
+        loc, memref::getTensorTypeFromMemRefType(selectorBuffer.getType()), selectorBuffer, true);
 
     // Define the shift vectors (pos/neg) as sparse tensor constants.
     DenseElementsAttr nonZeroIndices = rewriter.getI64TensorAttr(currentShift);
@@ -285,8 +286,9 @@ func::FuncOp ParameterShiftLowering::genQGradFunction(PatternRewriter &rewriter,
                 std::vector gradientTensors;
                 gradientTensors.reserve(gradResTypes.size());
                 for (Value gradientBuffer : gradientBuffers) {
-                    gradientTensors.push_back(
-                        rewriter.create(loc, gradientBuffer, true));
+                    gradientTensors.push_back(rewriter.create(
+                        loc, memref::getTensorTypeFromMemRefType(gradientBuffer.getType()),
+                        gradientBuffer, true));
                 }
                 op->setOperands(gradientTensors);
             }
diff --git a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
index 70f623fc29..d34dd89974 100644
--- a/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
+++ b/mlir/lib/Gradient/Transforms/PostprocessingPatterns.cpp
@@ -108,7 +108,9 @@ struct PostprocessForwardOp : public OpRewritePattern {
             // Insert new argIn in an interleaving way.
             size_t idx = 0;
             for (auto ty : newArgInTypes) {
-                op.insertArgument(2 * idx + 1, ty, {}, op.getLoc());
+                if (failed(op.insertArgument(2 * idx + 1, ty, {}, op.getLoc()))) {
+                    return failure();
+                }
                 idx++;
             }
             // Append newArgRes.
@@ -117,8 +119,11 @@ struct PostprocessForwardOp : public OpRewritePattern {
                                              /*values=*/op.getNumArguments());
             SmallVector argAttrs{appendingSize};
             SmallVector argLocs{appendingSize, op.getLoc()};
-            op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs);
+            if (failed(op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs))) {
+                return failure();
+            }
             op.setFunctionType(forwardTy);
+            return success();
         });
 
         op.walk([&](ReturnOp returnOp) {
@@ -195,7 +200,9 @@ struct PostprocessReverseOp : public OpRewritePattern {
             // Insert new argIn in an interleaving way.
             size_t idx = 0;
             for (auto ty : newArgInTypes) {
-                op.insertArgument(2 * idx, ty, {}, op.getLoc());
+                if (failed(op.insertArgument(2 * idx, ty, {}, op.getLoc()))) {
+                    return failure();
+                }
                 idx++;
             }
             // Append newArgRes.
@@ -204,8 +211,11 @@ struct PostprocessReverseOp : public OpRewritePattern {
                                              /*values=*/0);
             SmallVector argAttrs{appendingSize};
             SmallVector argLocs{appendingSize, op.getLoc()};
-            op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs);
+            if (failed(op.insertArguments(argIndices, newArgResTypes, argAttrs, argLocs))) {
+                return failure();
+            }
             op.setFunctionType(reverseTy);
+            return success();
         });
 
         op.walk([&](ReturnOp returnOp) {
diff --git a/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp b/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp
index ef6b572583..08c082a316 100644
--- a/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp
+++ b/mlir/lib/Gradient/Utils/DestinationPassingStyle.cpp
@@ -19,11 +19,11 @@
 
 using namespace mlir;
 
-void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &builder)
+LogicalResult catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &builder)
 {
     if (callee.getNumResults() == 0) {
         // Callee is already in destination-passing style
-        return;
+        return success();
     }
 
     MLIRContext *ctx = callee.getContext();
@@ -48,7 +48,7 @@ void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &
     if (callee.isDeclaration()) {
         // If the function does not have a body, we are done after modifying the function type.
         callee.setFunctionType(dpsFunctionType);
-        return;
+        return success();
     }
 
     // Insert the new output arguments to the function.
@@ -60,7 +60,9 @@ void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &
 
     // insertArguments modifies the function type, so we need to update the function type *after*
     // inserting the arguments.
-    callee.insertArguments(argIndices, memRefReturnTypes, argAttrs, argLocs);
+    if (failed(callee.insertArguments(argIndices, memRefReturnTypes, argAttrs, argLocs))) {
+        return failure();
+    }
     callee.setFunctionType(dpsFunctionType);
 
     // Update return sites to copy over the memref that would have been returned to the output.
@@ -83,4 +85,6 @@ void catalyst::convertToDestinationPassingStyle(func::FuncOp callee, OpBuilder &
         }
         returnOp.getOperandsMutable().assign(nonMemRefReturns);
     });
+
+    return success();
 }
diff --git a/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp b/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp
index 5768982de5..4d1164c140 100644
--- a/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp
+++ b/mlir/lib/Gradient/Utils/EinsumLinalgGeneric.cpp
@@ -44,7 +44,7 @@ Value buildTensorLinalgGeneric(OpBuilder &builder, Location loc, ValueRange oper
     // Initialize the result tensor
     FloatType elementType = cast(resultType.getElementType());
     Value zero = builder.create(
-        loc, APFloat::getZero(elementType.getFloatSemantics()), elementType);
+        loc, elementType, APFloat::getZero(elementType.getFloatSemantics()));
     Value result =
         builder.create(loc, resultType.getShape(), resultType.getElementType());
     result = builder.create(loc, zero, result).getResult(0);
diff --git a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp
index 4a0312478e..9498f306e8 100644
--- a/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Quantum/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -58,15 +58,16 @@ struct QubitUnitaryOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto qubitUnitaryOp = cast(op);
         Location loc = op->getLoc();
         auto tensorType = cast(qubitUnitaryOp.getMatrix().getType());
         MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-        auto toMemrefOp =
-            rewriter.create(loc, memrefType, qubitUnitaryOp.getMatrix());
-        auto memref = toMemrefOp.getResult();
+        auto toBufferOp =
+            rewriter.create(loc, memrefType, qubitUnitaryOp.getMatrix());
+        auto memref = toBufferOp.getResult();
         bufferization::replaceOpWithNewBufferizedOp(
             rewriter, op, qubitUnitaryOp.getOutQubits().getTypes(),
             qubitUnitaryOp.getOutCtrlQubits().getTypes(), memref, qubitUnitaryOp.getInQubits(),
@@ -101,15 +102,16 @@ struct HermitianOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto hermitianOp = cast(op);
         Location loc = op->getLoc();
         auto tensorType = cast(hermitianOp.getMatrix().getType());
         MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-        auto toMemrefOp =
-            rewriter.create(loc, memrefType, hermitianOp.getMatrix());
-        auto memref = toMemrefOp.getResult();
+        auto toBufferOp =
+            rewriter.create(loc, memrefType, hermitianOp.getMatrix());
+        auto memref = toBufferOp.getResult();
         auto newHermitianOp = rewriter.create(loc, hermitianOp.getType(), memref,
                                                            hermitianOp.getQubits());
         bufferization::replaceOpWithBufferizedValues(rewriter, op, newHermitianOp.getObs());
@@ -143,15 +145,16 @@ struct HamiltonianOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto hamiltonianOp = cast(op);
         Location loc = op->getLoc();
         auto tensorType = cast(hamiltonianOp.getCoeffs().getType());
         MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-        auto toMemrefOp =
-            rewriter.create(loc, memrefType, hamiltonianOp.getCoeffs());
-        auto memref = toMemrefOp.getResult();
+        auto toBufferOp =
+            rewriter.create(loc, memrefType, hamiltonianOp.getCoeffs());
+        auto memref = toBufferOp.getResult();
         auto newHamiltonianOp = rewriter.create(loc, hamiltonianOp.getType(), memref,
                                                                hamiltonianOp.getTerms());
         bufferization::replaceOpWithBufferizedValues(rewriter, op, newHamiltonianOp.getObs());
@@ -187,7 +190,8 @@ struct SampleOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto sampleOp = cast(op);
         Location loc = op->getLoc();
@@ -237,7 +241,8 @@ struct CountsOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto countsOp = cast(op);
         Location loc = op->getLoc();
@@ -297,7 +302,8 @@ struct ProbsOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto probsOp = cast(op);
         Location loc = op->getLoc();
@@ -350,7 +356,8 @@ struct StateOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto stateOp = cast(op);
         Location loc = op->getLoc();
@@ -401,16 +408,17 @@ struct SetStateOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto setStateOp = cast(op);
         Location loc = op->getLoc();
         auto tensorType = cast(setStateOp.getInState().getType());
         MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
 
-        auto toMemrefOp =
-            rewriter.create(loc, memrefType, setStateOp.getInState());
-        auto memref = toMemrefOp.getResult();
+        auto toBufferOp =
+            rewriter.create(loc, memrefType, setStateOp.getInState());
+        auto memref = toBufferOp.getResult();
         auto newSetStateOp = rewriter.create(loc, setStateOp.getOutQubits().getTypes(),
                                                          memref, setStateOp.getInQubits());
         bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits());
@@ -443,16 +451,17 @@ struct SetBasisStateOpInterface
     }
 
     LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                            const bufferization::BufferizationOptions &options) const
+                            const bufferization::BufferizationOptions &options,
+                            bufferization::BufferizationState &state) const
     {
         auto setBasisStateOp = cast(op);
         Location loc = op->getLoc();
         auto tensorType = cast(setBasisStateOp.getBasisState().getType());
         MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
 
-        auto toMemrefOp = rewriter.create(
+        auto toBufferOp = rewriter.create(
             loc, memrefType, setBasisStateOp.getBasisState());
-        auto memref = toMemrefOp.getResult();
+        auto memref = toBufferOp.getResult();
         auto newSetStateOp = rewriter.create(
             loc, setBasisStateOp.getOutQubits().getTypes(), memref, setBasisStateOp.getInQubits());
         bufferization::replaceOpWithBufferizedValues(rewriter, op, newSetStateOp.getOutQubits());
diff --git a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
index eaf30e2829..8dbf401c46 100644
--- a/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
+++ b/mlir/lib/Quantum/Transforms/ConversionPatterns.cpp
@@ -44,7 +44,8 @@ Value getGlobalString(Location loc, OpBuilder &rewriter, StringRef key, StringRe
     }
     return rewriter.create(loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
                                         type, rewriter.create(loc, glb),
-                                        ArrayRef{0, 0}, true);
+                                        ArrayRef{0, 0},
+                                        LLVM::GEPNoWrapFlags::inbounds);
 }
 
 /**
@@ -80,13 +81,17 @@ Value getModifiersPtr(Location loc, RewriterBase &rewriter, const TypeConverter
     auto structType = LLVM::LLVMStructType::getLiteral(ctx, {boolType, sizeType, ptrType, ptrType});
     auto modifiersPtr = catalyst::getStaticAlloca(loc, rewriter, structType, 1).getResult();
     auto adjointPtr = rewriter.create(loc, ptrType, structType, modifiersPtr,
-                                                   llvm::ArrayRef{0, 0}, true);
+                                                   llvm::ArrayRef{0, 0},
+                                                   LLVM::GEPNoWrapFlags::inbounds);
     auto numControlledPtr = rewriter.create(loc, ptrType, structType, modifiersPtr,
-                                                         llvm::ArrayRef{0, 1}, true);
-    auto controlledWiresPtr = rewriter.create(
-        loc, ptrType, structType, modifiersPtr, llvm::ArrayRef{0, 2}, true);
-    auto controlledValuesPtr = rewriter.create(
-        loc, ptrType, structType, modifiersPtr, llvm::ArrayRef{0, 3}, true);
+                                                         llvm::ArrayRef{0, 1},
+                                                         LLVM::GEPNoWrapFlags::inbounds);
+    auto controlledWiresPtr = rewriter.create(loc, ptrType, structType, modifiersPtr,
+                                                           llvm::ArrayRef{0, 2},
+                                                           LLVM::GEPNoWrapFlags::inbounds);
+    auto controlledValuesPtr = rewriter.create(loc, ptrType, structType, modifiersPtr,
+                                                            llvm::ArrayRef{0, 3},
+                                                            LLVM::GEPNoWrapFlags::inbounds);
 
     Value ctrlPtr = nullPtr;
     Value valuePtr = nullPtr;
@@ -98,13 +103,15 @@ Value getModifiersPtr(Location loc, RewriterBase &rewriter, const TypeConverter
         for (int i = 0; static_cast(i) < controlledQubits.size(); i++) {
             {
                 auto itemPtr = rewriter.create(loc, ptrType, ptrType, ctrlPtr,
-                                                            llvm::ArrayRef{i}, true);
+                                                            llvm::ArrayRef{i},
+                                                            LLVM::GEPNoWrapFlags::inbounds);
                 auto qubit = controlledQubits[i];
                 rewriter.create(loc, qubit, itemPtr);
             }
             {
                 auto itemPtr = rewriter.create(loc, ptrType, boolType, valuePtr,
-                                                            llvm::ArrayRef{i}, true);
+                                                            llvm::ArrayRef{i},
+                                                            LLVM::GEPNoWrapFlags::inbounds);
                 auto value = controlledValues[i];
                 rewriter.create(loc, value, itemPtr);
             }
@@ -1012,7 +1019,7 @@ struct SetStateOpPattern : public OpConversionPattern {
         auto voidTy = LLVM::LLVMVoidType::get(ctx);
         auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
         ModuleOp moduleOp = op->getParentOfType();
-        auto func = mlir::LLVM::lookupOrCreateFn(moduleOp, "__catalyst__qis__SetState",
+        auto func = mlir::LLVM::lookupOrCreateFn(rewriter, moduleOp, "__catalyst__qis__SetState",
                                                  {ptrTy, i64}, voidTy, isVarArg)
                         .value();
 
@@ -1052,9 +1059,10 @@ struct SetBasisStateOpPattern : public OpConversionPattern {
         auto voidTy = LLVM::LLVMVoidType::get(ctx);
         auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
         ModuleOp moduleOp = op->getParentOfType();
-        auto func = mlir::LLVM::lookupOrCreateFn(moduleOp, "__catalyst__qis__SetBasisState",
-                                                 {ptrTy, i64}, voidTy, isVarArg)
-                        .value();
+        auto func =
+            mlir::LLVM::lookupOrCreateFn(rewriter, moduleOp, "__catalyst__qis__SetBasisState",
+                                         {ptrTy, i64}, voidTy, isVarArg)
+                .value();
 
         SmallVector args;
 
diff --git a/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp b/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp
index 4c91fbf67c..f4b8a2f9ca 100644
--- a/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp
+++ b/mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp
@@ -214,13 +214,9 @@ struct EmitCatalystPyInterfacePass
         patterns.add(context);
 
         GreedyRewriteConfig config;
-        config.strictMode = GreedyRewriteStrictness::ExistingOps;
-        config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
-        config.maxIterations = 1;
-        // TODO: Update to the following lines the next time we update llvm
-        // config.setStrictness(GreedyRewriteStrictness::ExistingOps);
-        // config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
-        // config.setMaxIterations(1);
+        config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+        config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled);
+        config.setMaxIterations(1);
 
         auto op = getOperation();
         SmallVector targets;
diff --git a/mlir/llvm-project b/mlir/llvm-project
index 179d30f8c3..f8cb7987c6 160000
--- a/mlir/llvm-project
+++ b/mlir/llvm-project
@@ -1 +1 @@
-Subproject commit 179d30f8c3fddd3c85056fd2b8e877a4a8513158
+Subproject commit f8cb7987c64dcffb72414a40560055cb717dbf74
diff --git a/mlir/mlir-hlo b/mlir/mlir-hlo
index 617a9361d1..1dd2e71331 160000
--- a/mlir/mlir-hlo
+++ b/mlir/mlir-hlo
@@ -1 +1 @@
-Subproject commit 617a9361d186199480c080c9e8c474a5e30c22d1
+Subproject commit 1dd2e71331014ae0373f6bf900ce6be393357190
diff --git a/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch b/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
index 9f80c60a75..8746e54c73 100644
--- a/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
+++ b/mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
@@ -1,8 +1,8 @@
 diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
-index 85050315..414318eb 100644
+index 7c234dd4..846f68b4 100644
 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
 +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
-@@ -3940,14 +3940,6 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
+@@ -3942,14 +3942,6 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
    case Intrinsic::nearbyint:
    case Intrinsic::round:
    case Intrinsic::sqrt:
diff --git a/mlir/patches/llvm-bufferization-segfault.patch b/mlir/patches/llvm-bufferization-segfault.patch
new file mode 100644
index 0000000000..e894516820
--- /dev/null
+++ b/mlir/patches/llvm-bufferization-segfault.patch
@@ -0,0 +1,27 @@
+diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+index 453ed43bcad..dff994729a4 100644
+--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
++++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+@@ -89,16 +89,12 @@ static FuncOp getCalledFunction(CallOpInterface callOp,
+ /// Return the FuncOp called by `callOp`.
+ static FuncOp getCalledFunction(CallOpInterface callOp,
+                                 const AnalysisState &state) {
+-  auto &oneShotAnalysisState = static_cast(state);
+-
+-  if (auto *funcAnalysisState =
+-          oneShotAnalysisState.getExtension()) {
+-    // Use the cached symbol tables.
+-    return getCalledFunction(callOp, funcAnalysisState->symbolTables);
+-  }
+-
+-  SymbolTableCollection symbolTables;
+-  return getCalledFunction(callOp, symbolTables);
++  SymbolRefAttr sym =
++      llvm::dyn_cast_if_present(callOp.getCallableForCallee());
++  if (!sym)
++    return nullptr;
++  return dyn_cast_or_null(
++      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+ }
+ 
+ /// Get FuncAnalysisState.
diff --git a/mlir/patches/mhlo-add-back-necessary-passes.patch b/mlir/patches/mhlo-add-back-necessary-passes.patch
index b56ede8dd5..c430adae43 100644
--- a/mlir/patches/mhlo-add-back-necessary-passes.patch
+++ b/mlir/patches/mhlo-add-back-necessary-passes.patch
@@ -7,12 +7,12 @@ Subject: [PATCH] restore the removed mhlo passes we need:
 ---
  mhlo/transforms/CMakeLists.txt                |   6 +
  .../legalize_control_flow.cc                  | 288 +++++++++
- .../transforms/legalize_sort/legalize_sort.cc | 577 ++++++++++++++++++
+ .../transforms/legalize_sort/legalize_sort.cc | 578 ++++++++++++++++++
  .../legalize_to_standard.cc                   | 243 ++++++++
  .../legalize_to_standard_patterns.td          |  92 +++
  mhlo/transforms/mhlo_passes.td                |  19 +
  mhlo/transforms/passes.h                      |   4 +
- 7 files changed, 1229 insertions(+)
+ 7 files changed, 1230 insertions(+)
  create mode 100644 mhlo/transforms/legalize_control_flow/legalize_control_flow.cc
  create mode 100644 mhlo/transforms/legalize_sort/legalize_sort.cc
  create mode 100644 mhlo/transforms/legalize_to_standard/legalize_to_standard.cc
@@ -342,7 +342,7 @@ new file mode 100644
 index 00000000..8ba9de9a
 --- /dev/null
 +++ b/mhlo/transforms/legalize_sort/legalize_sort.cc
-@@ -0,0 +1,577 @@
+@@ -0,0 +1,578 @@
 +/* Copyright 2019 The OpenXLA Authors.
 +
 +Licensed under the Apache License, Version 2.0 (the "License");
@@ -883,8 +883,9 @@ index 00000000..8ba9de9a
 +
 +    SmallVector outputTensors;
 +    for (auto [out0, out1] : llvm::zip(outputMemrefs, scratchMemrefs)) {
++      Value s = b.create(parity, out1, out0).getResult();
 +      outputTensors.push_back(b.create(
-+          b.create(parity, out1, out0), /*restrict=*/true));
++          memref::getTensorTypeFromMemRefType(s.getType()), s, /*restrict=*/true));
 +    }
 +
 +    rewriter.replaceOp(op, outputTensors);
diff --git a/mlir/patches/mhlo-remove-shardy.patch b/mlir/patches/mhlo-remove-shardy.patch
index f78200bdab..32ce71061f 100644
--- a/mlir/patches/mhlo-remove-shardy.patch
+++ b/mlir/patches/mhlo-remove-shardy.patch
@@ -84,7 +84,7 @@ index cabd6a9f..2e64b4ed 100644
            patterns->add(context);
            patterns->add(context);
            patterns->add(context);
--          populateSdyShapeRefinementPatterns(patterns, context);
+-          populateSdyShapeRefinementPatterns(context, patterns);
          };
  
      if (failed(stablehlo::refineEntryFunction(*context, func,
@@ -92,7 +92,7 @@ index cabd6a9f..2e64b4ed 100644
    patterns->add(context);
    patterns->add(context);
    patterns->add(context);
--  populateSdyShapeRefinementPatterns(patterns, context);
+-  populateSdyShapeRefinementPatterns(context, patterns);
  }
  
  }  // namespace stablehlo_ext
diff --git a/mlir/patches/mhlo-rename-sort.patch b/mlir/patches/mhlo-rename-sort.patch
new file mode 100644
index 0000000000..c356cc35e3
--- /dev/null
+++ b/mlir/patches/mhlo-rename-sort.patch
@@ -0,0 +1,15 @@
+diff --git a/utils/cycle_detector.cc b/utils/cycle_detector.cc
+index e3901ae88..890f39654 100644
+--- a/utils/cycle_detector.cc
++++ b/utils/cycle_detector.cc
+@@ -199,8 +199,8 @@ static void backwardDfs(GraphCycles::Rep* r, int32_t n, int32_t lowerBound) {
+ // Recomputes rank assignments to make them compatible with the edges (producer
+ // has smaller rank than its consumer)
+ static void reorder(GraphCycles::Rep* r) {
+-  sort(r->nodes, &r->deltab);
+-  sort(r->nodes, &r->deltaf);
++  mlir::sort(r->nodes, &r->deltab);
++  mlir::sort(r->nodes, &r->deltaf);
+ 
+   // Adds contents of delta lists to list (backwards deltas first).
+   r->list.clear();
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 6394246da0..e3a8b603eb 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -3,6 +3,8 @@ include(AddMLIRPython)
 # TODO: Add an upstream cmake param for this vs having a global here.
 add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=mlir_quantum.")
 
+# Ignore nanobind warnings
+add_compile_options(-w)
 
 ################################################################################
 # Declare Dialect Sources
diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir
index effc229a64..67f186943d 100644
--- a/mlir/test/Catalyst/BufferizationTest.mlir
+++ b/mlir/test/Catalyst/BufferizationTest.mlir
@@ -23,7 +23,7 @@
 
 func.func @dbprint_val(%arg0: tensor) {
 
-    // CHECK: %0 = bufferization.to_memref %arg0
+    // CHECK: %0 = bufferization.to_buffer %arg0
     // CHECK: "catalyst.print"(%0) : (memref) -> ()
     "catalyst.print"(%arg0) : (tensor) -> ()
 
@@ -34,7 +34,7 @@ func.func @dbprint_val(%arg0: tensor) {
 
 func.func @dbprint_memref(%arg0: tensor) {
 
-    // CHECK: %0 = bufferization.to_memref %arg0
+    // CHECK: %0 = bufferization.to_buffer %arg0
     // CHECK: "catalyst.print"(%0) <{print_descriptor}> : (memref) -> ()
     "catalyst.print"(%arg0) {print_descriptor} : (tensor) -> ()
 
@@ -54,7 +54,7 @@ func.func @dbprint_str() {
 // -----
 
 func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> {
-    // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0
+    // CHECK: [[sourceAlloc:%.+]] = bufferization.to_buffer %arg0
     // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64>
     // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} :
     // CHECK-SAME: (memref<3x3xf64>, memref<3x3xf64>) -> ()
@@ -72,7 +72,7 @@ func.func @custom_call_copy(%arg0: tensor<2x3xf64>) -> tensor<2x2xf64> {
     // COM: e.g. coming from tensor subviews
     // COM: a copy needs to be performed because the kernels only allow for contiguous arrays as inputs
     //
-    // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0
+    // CHECK: [[sourceAlloc:%.+]] = bufferization.to_buffer %arg0
     // CHECK: [[subview:%.+]] = memref.subview [[sourceAlloc]]
     // CHECK-SAME: memref<2x3xf64> to memref<2x2xf64, strided<[3, 1]>>
     // CHECK: [[copyAlloc:%.+]] = memref.alloc() : memref<2x2xf64>
@@ -106,7 +106,7 @@ module @test1 {
   // CHECK-LABEL: @foo(
   // CHECK-SAME: [[arg0:%.+]]: tensor)
   func.func private @foo(%arg0: tensor) -> tensor {
-    // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : tensor to memref
+    // CHECK-DAG: [[memref0:%.+]] = bufferization.to_buffer [[arg0]] : tensor to memref
     // CHECK-DAG: [[resAlloc:%.+]] = memref.alloc() {{.*}}: memref
     // CHECK:     catalyst.callback_call @callback_1([[memref0]], [[resAlloc]]) : (memref, memref) -> ()
     %1 = catalyst.callback_call @callback_1(%arg0) : (tensor) -> (tensor)
diff --git a/mlir/test/Catalyst/ConversionTest.mlir b/mlir/test/Catalyst/ConversionTest.mlir
index 20ce7c6bc7..9a9fc17102 100644
--- a/mlir/test/Catalyst/ConversionTest.mlir
+++ b/mlir/test/Catalyst/ConversionTest.mlir
@@ -146,13 +146,13 @@ module @test1 {
   // CHECK-SAME: [[arg0:%.+]]: tensor
   // CHECK-SAME:)
   func.func private @foo(%arg0: tensor) -> tensor {
-    // CHECK: [[memref0:%.+]] = bufferization.to_memref [[arg0]]
+    // CHECK: [[memref0:%.+]] = bufferization.to_buffer [[arg0]]
     // CHECK: [[ptr0:%.+]] = llvm.alloca {{.*}}
     // CHECK: [[ptr1:%.+]] = llvm.alloca {{.*}}
     // CHECK: [[struct0:%.+]] = builtin.unrealized_conversion_cast [[memref0]]
 
     // CHECK: [[tensor1:%.+]] = bufferization.alloc_tensor()
-    // CHECK: [[memref1:%.+]] = bufferization.to_memref [[tensor1]]
+    // CHECK: [[memref1:%.+]] = bufferization.to_buffer [[tensor1]]
     // CHECK: [[struct1:%.+]] = builtin.unrealized_conversion_cast [[memref1]]
 
     // CHECK: llvm.store [[struct0]], [[ptr1]]
@@ -160,9 +160,9 @@ module @test1 {
 
     // call @callback_1([[ptr0]], [[ptr1]])
 
-    %0 = bufferization.to_memref %arg0 : tensor to memref
+    %0 = bufferization.to_buffer %arg0 : tensor to memref
     %1 = bufferization.alloc_tensor() {memory_space = 0 : i64} : tensor
-    %2 = bufferization.to_memref %1 : tensor to memref
+    %2 = bufferization.to_buffer %1 : tensor to memref
 
 
     catalyst.callback_call @callback_1(%0, %2) : (memref, memref) -> ()
diff --git a/mlir/test/Gradient/BufferizationTest.mlir b/mlir/test/Gradient/BufferizationTest.mlir
index 4a8f9a246e..8e84995888 100644
--- a/mlir/test/Gradient/BufferizationTest.mlir
+++ b/mlir/test/Gradient/BufferizationTest.mlir
@@ -63,7 +63,7 @@ func.func private @circuit(%arg0: tensor<2xf64>)
 // CHECK-LABEL: @adjoint_with_tensor_arg
 func.func @adjoint_with_tensor_arg(%arg0: tensor<2xf64>, %arg1: index) {
 
-    // CHECK:   [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64>
+    // CHECK:   [[argBuffer:%.+]] = bufferization.to_buffer %arg0 : tensor<2xf64> to memref<2xf64>
     // CHECK:   [[alloc:%.+]] = memref.alloc(%arg1) : memref
     // CHECK:   gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc]] : memref) : (memref<2xf64>) -> ()
     %grad = gradient.adjoint @circuit(%arg0) size(%arg1) : (tensor<2xf64>) -> tensor
@@ -77,7 +77,7 @@ func.func private @circuit(%arg0: tensor<2xf64>)
 // CHECK-LABEL: @adjoint_with_multiple_results
 func.func @adjoint_with_multiple_results(%arg0: tensor<2xf64>, %arg1: index) {
 
-    // CHECK:   [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64>
+    // CHECK:   [[argBuffer:%.+]] = bufferization.to_buffer %arg0 : tensor<2xf64> to memref<2xf64>
     // CHECK:   [[alloc0:%.+]] = memref.alloc(%arg1) : memref
     // CHECK:   [[alloc1:%.+]] = memref.alloc(%arg1) : memref
     // CHECK:   gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc0]], [[alloc1]]
@@ -93,7 +93,7 @@ func.func private @circuit(%arg0: f64)
 // CHECK-LABEL: @backprop_scalar_in
 func.func @backprop_scalar_in(%arg0: f64, %arg1: tensor) {
 
-    // CHECK:   [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor to memref
+    // CHECK:   [[cotangentSource:%.+]] = bufferization.to_buffer %arg1 : tensor to memref
     // CHECK:   [[dim1:%.+]] = memref.dim [[cotangentSource]]
     // CHECK:   [[cotangentRes:%.+]] = memref.alloc([[dim1]]) {alignment = 64 : i64} : memref
     // CHECK:   memref.copy [[cotangentSource]], [[cotangentRes]]
@@ -115,8 +115,8 @@ func.func private @circuit(%arg0: tensor)
 // CHECK-LABEL: @backprop_tensor_in
 func.func @backprop_tensor_in(%arg0: tensor, %arg1: tensor) {
 
-    // CHECK-DAG:   [[argSource:%.+]] = bufferization.to_memref %arg0 : tensor to memref
-    // CHECK-DAG:   [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor to memref
+    // CHECK-DAG:   [[argSource:%.+]] = bufferization.to_buffer %arg0 : tensor to memref
+    // CHECK-DAG:   [[cotangentSource:%.+]] = bufferization.to_buffer %arg1 : tensor to memref
     // CHECK:   [[dim2:%.+]] = memref.dim [[cotangentSource]]
     // CHECK:   [[cotangentRes:%.+]] = memref.alloc([[dim2]]) {alignment = 64 : i64} : memref
     // CHECK:   memref.copy [[cotangentSource]], [[cotangentRes]]
@@ -141,8 +141,8 @@ func.func private @circuit(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>)
 // CHECK-LABEL: @backprop_multiple_tensors_in
 func.func @backprop_multiple_tensors_in(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>, %arg2: tensor) {
 
-    // CHECK-DAG:   [[argSource0:%.+]] = bufferization.to_memref %arg0 : tensor<10xf64> to memref<10xf64>
-    // CHECK-DAG:   [[argSource1:%.+]] = bufferization.to_memref %arg1 : tensor<2xf64> to memref<2xf64>
+    // CHECK-DAG:   [[argSource0:%.+]] = bufferization.to_buffer %arg0 : tensor<10xf64> to memref<10xf64>
+    // CHECK-DAG:   [[argSource1:%.+]] = bufferization.to_buffer %arg1 : tensor<2xf64> to memref<2xf64>
     // CHECK:   memref.alloc
     // CHECK:   memref.copy
     // CHECK:   [[argShadow1:%.+]] = memref.alloc() : memref<10xf64>
@@ -171,8 +171,8 @@ gradient.forward @callback_fn_fwd.fwd(%arg0: tensor<2xf64>) -> (tensor, ten
 
     // CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64>
     // CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor, tensor<2xf64>)
-    // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor to memref
-    // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
+    // CHECK: [[res0:%.+]] = bufferization.to_buffer [[callOut]]#0 : tensor to memref
+    // CHECK: [[res1:%.+]] = bufferization.to_buffer [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
     // CHECK: gradient.return {empty = false} [[res0]], [[res1]] : memref, memref<2xf64>
 
     %0:2 = func.call @callback_fn_fwd(%arg0) : (tensor<2xf64>) -> (tensor, tensor<2xf64>)
@@ -192,7 +192,7 @@ gradient.reverse @callback_fn_vjp.rev(%arg0: tensor, %arg1: tensor<2xf64>)
     // CHECK: [[in1:%.+]] = bufferization.to_tensor %arg1 : memref<2xf64>
     // CHECK: [[in0:%.+]] = bufferization.to_tensor %arg0 : memref
     // CHECK: [[callOut:%.+]] = func.call @callback_fn_vjp([[in1]], [[in0]]) : (tensor<2xf64>, tensor) -> tensor<2xf64>
-    // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : tensor<2xf64> to memref<2xf64>
+    // CHECK: [[res:%.+]] = bufferization.to_buffer [[callOut]] : tensor<2xf64> to memref<2xf64>
     // CHECK: gradient.return {empty = true} [[res]] : memref<2xf64>
 
     %0 = func.call @callback_fn_vjp(%arg1, %arg0) : (tensor<2xf64>, tensor) -> tensor<2xf64>
diff --git a/mlir/test/Gradient/FiniteDifferenceTest.mlir b/mlir/test/Gradient/FiniteDifferenceTest.mlir
index 13af9e7956..37d70471d6 100644
--- a/mlir/test/Gradient/FiniteDifferenceTest.mlir
+++ b/mlir/test/Gradient/FiniteDifferenceTest.mlir
@@ -161,7 +161,7 @@ func.func private @funcMultiArg(%arg0: tensor<7xf64>, %arg1: f64) -> tensor<2xf6
     // CHECK:        [[BASE:%.+]] = call @funcMultiArg(%arg0, %arg1)
     // CHECK:        [[DIFF:%.+]] = tensor.generate
     // CHECK-NEXT:   ^bb0(%arg2: index, %arg3: index):
-    // CHECK:            [[MEMREF:%.+]] = bufferization.to_memref %arg0
+    // CHECK:            [[MEMREF:%.+]] = bufferization.to_buffer %arg0
     // CHECK:            [[COPY:%.+]] = bufferization.clone [[MEMREF]]
     // CHECK:            [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
     // CHECK:            [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg3]
@@ -188,7 +188,7 @@ func.func private @funcMultiArg(%arg0: tensor<7xf64>, %arg1: f64) -> tensor<2xf6
     // CHECK:        [[BASE:%.+]] = call @funcMultiArg(%arg0, %arg1)
     // CHECK:        [[DIFF:%.+]] = tensor.generate
     // CHECK-NEXT:   ^bb0(%arg2: index, %arg3: index):
-    // CHECK:            [[MEMREF:%.+]] = bufferization.to_memref %arg0
+    // CHECK:            [[MEMREF:%.+]] = bufferization.to_buffer %arg0
     // CHECK:            [[COPY:%.+]] = bufferization.clone [[MEMREF]]
     // CHECK:            [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
     // CHECK:            [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg3]
@@ -227,7 +227,7 @@ func.func private @funcMultiRes(%arg0: tensor<7xf64>) -> (f64, tensor<2xf64>) at
     // CHECK:        [[BASE:%.+]]:2 = call @funcMultiRes(%arg0)
     // CHECK:        [[DIFF:%.+]] = tensor.generate
     // CHECK-NEXT:   ^bb0(%arg1: index):
-    // CHECK:            [[MEMREF:%.+]] = bufferization.to_memref %arg0
+    // CHECK:            [[MEMREF:%.+]] = bufferization.to_buffer %arg0
     // CHECK:            [[COPY:%.+]] = bufferization.clone [[MEMREF]]
     // CHECK:            [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
     // CHECK:            [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg1]
@@ -239,7 +239,7 @@ func.func private @funcMultiRes(%arg0: tensor<7xf64>) -> (f64, tensor<2xf64>) at
     // CHECK:        [[R0:%.+]] = arith.divf [[DIFF]]
     // CHECK:        [[DIFF:%.+]] = tensor.generate
     // CHECK-NEXT:   ^bb0(%arg1: index, %arg2: index):
-    // CHECK:            [[MEMREF:%.+]] = bufferization.to_memref %arg0
+    // CHECK:            [[MEMREF:%.+]] = bufferization.to_buffer %arg0
     // CHECK:            [[COPY:%.+]] = bufferization.clone [[MEMREF]]
     // CHECK:            [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
     // CHECK:            [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg2]
@@ -279,7 +279,7 @@ func.func private @funcDynamicTensor(%arg0: tensor) -> tensor<2x?xf64>
 
     // CHECK:        [[DIFF:%.+]] = tensor.generate [[DDIM0]], [[DDIM1]]
     // CHECK-NEXT:   ^bb0([[i0:%.+]]: index, [[i1:%.+]]: index, [[i2:%.+]]: index, [[i3:%.+]]: index):
-    // CHECK:            [[MEMREF:%.+]] = bufferization.to_memref %arg0
+    // CHECK:            [[MEMREF:%.+]] = bufferization.to_buffer %arg0
     // CHECK:            [[COPY:%.+]] = bufferization.clone [[MEMREF]]
     // CHECK:            [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
     // CHECK:            [[VAL:%.+]] = tensor.extract [[TENSOR]][[[i2]], [[i3]]]
diff --git a/mlir/test/Gradient/PostProcessingTest.mlir b/mlir/test/Gradient/PostProcessingTest.mlir
index 2403372410..9ae25800f1 100644
--- a/mlir/test/Gradient/PostProcessingTest.mlir
+++ b/mlir/test/Gradient/PostProcessingTest.mlir
@@ -25,15 +25,15 @@ gradient.forward @callback_fn_fwd.fwd(%arg0: memref<2xf64>) -> (memref, mem
 
     // CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64>
     // CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor, tensor<2xf64>)
-    // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor to memref
-    // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
+    // CHECK: [[res0:%.+]] = bufferization.to_buffer [[callOut]]#0 : tensor to memref
+    // CHECK: [[res1:%.+]] = bufferization.to_buffer [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
     // CHECK: memref.copy [[res0]], %arg2 : memref to memref
     // CHECK: gradient.return {empty = false} [[res1]] : memref<2xf64>
 
 	%0 = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64>
 	%1:2 = func.call @callback_fn_fwd(%0) : (tensor<2xf64>) -> (tensor, tensor<2xf64>)
-	%2 = bufferization.to_memref %1#0 : tensor to memref
-	%3 = bufferization.to_memref %1#1 : tensor<2xf64> to memref<2xf64>
+	%2 = bufferization.to_buffer %1#0 : tensor