From 64207122dbd838fa9d84a6b3f54768d45bf41fc0 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 13:29:55 -0400 Subject: [PATCH 01/16] Update [ghstack-poisoned] --- backends/aoti/aoti_model_container.cpp | 6 ++++++ backends/aoti/aoti_model_container.h | 16 ++++++++++++++++ backends/aoti/common_shims.cpp | 5 +++++ backends/aoti/common_shims.h | 3 +++ 4 files changed, 30 insertions(+) diff --git a/backends/aoti/aoti_model_container.cpp b/backends/aoti/aoti_model_container.cpp index 03be835a0c3..d1764451ab6 100644 --- a/backends/aoti/aoti_model_container.cpp +++ b/backends/aoti/aoti_model_container.cpp @@ -25,6 +25,12 @@ AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs = nullptr; AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr; +// Global function pointers needed by Metal backend +AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName = nullptr; +AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants = nullptr; + } // extern "C" } // namespace aoti diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h index 9b185327172..88d936d21ba 100644 --- a/backends/aoti/aoti_model_container.h +++ b/backends/aoti/aoti_model_container.h @@ -70,6 +70,22 @@ extern AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs; extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun; +// Function pointer types needed by Metal backend +using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** input_name); + +using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants); + +// Global function pointers needed by Metal backend +extern AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName; +extern AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants; + } // extern "C" // AOTI Delegate Handle structure diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index abc83779443..7802444e97e 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -145,6 +145,11 @@ void cleanup_tensor_metadata() { internal::tensor_to_strides.clear(); } +// Needed by Metal backend +size_t aoti_torch_dtype_element_size(int32_t dtype) { + return dtype_to_element_size(dtype); +} + } // extern "C" } // namespace aoti diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index 5f54cd1c878..97fcea1085c 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -68,6 +68,9 @@ void aoti_torch_grad_mode_set_enabled(bool enabled); // Cleanup functions for clearing global state void cleanup_tensor_metadata(); +// Needed by Metal backend +size_t aoti_torch_dtype_element_size(int32_t dtype); + } // extern "C" } // namespace aoti From d036c0713348e2482aea4d21405d30a51b629f76 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 16:08:15 -0400 Subject: [PATCH 02/16] Update [ghstack-poisoned] --- backends/apple/metal/metal_backend.py | 173 ++++++++++++++++++ backends/apple/metal/metal_partitioner.py | 77 ++++++++ .../metal/replace_slice_copy_with_slice.py | 118 ++++++++++++ backends/apple/metal/tests/__init__.py | 6 + .../apple/metal/tests/test_metal_backend.py | 80 ++++++++ 
.../metal/tests/test_metal_partitioner.py | 172 +++++++++++++++++ 6 files changed, 626 insertions(+) create mode 100644 backends/apple/metal/metal_backend.py create mode 100644 backends/apple/metal/metal_partitioner.py create mode 100644 backends/apple/metal/replace_slice_copy_with_slice.py create mode 100644 backends/apple/metal/tests/__init__.py create mode 100644 backends/apple/metal/tests/test_metal_backend.py create mode 100644 backends/apple/metal/tests/test_metal_partitioner.py diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py new file mode 100644 index 00000000000..782aa522084 --- /dev/null +++ b/backends/apple/metal/metal_backend.py @@ -0,0 +1,173 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import os +import typing +from enum import Enum + +from typing import Any, Dict, final, List, Optional, Set + +import torch +from executorch.backends.apple.metal.replace_slice_copy_with_slice import ( + ReplaceSliceCopyWithSlicePass, +) +from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_details import ( + BackendDetails, + ExportedProgram, + PreprocessResult, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch.export.passes import move_to_device_pass + + +# exist fallback operators in et namespace; +supported_fallback_kernels: Dict[str, Any] = { + "aoti_torch_mps_addmm_out": None, + "aoti_torch_mps_convolution": None, + "aoti_torch_mps_mm_out": None, + "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, +} + +# required fallback kernels but not supported +missing_fallback_kernels: Set[str] = set() + + +class COMPILE_SPEC_KEYS(Enum): + METHOD_NAME = "method_name" + + +# context manager for non-fallback guarantee +# it will raise exception when generating fallback kernels during aoti compile +@contextlib.contextmanager +def collect_unsupported_fallback_kernels(): + original_generate_c_shim_extern_kernel_call = ( + CppWrapperCpu.generate_c_shim_extern_kernel_call + ) + + def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + debug_handle: Optional[int] = None, + ): + if kernel not in supported_fallback_kernels: + missing_fallback_kernels.add(kernel) + + original_generate_c_shim_extern_kernel_call( + self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle + ) + + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels + ) + try: + yield + finally: + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + original_generate_c_shim_extern_kernel_call + ) + + +@final +@experimental( + "This API and all of Metal backend related functionality are experimental." 
+)
+class MetalBackend(BackendDetails):
+    @staticmethod
+    def preprocess(
+        edge_program: ExportedProgram,
+        compile_specs: List[CompileSpec],
+    ) -> PreprocessResult:
+        print("entering the lowerable parts in MetalBackend.preprocess....")
+        # Move the edge_program from CPU to MPS for aoti compile
+        mps_edge_program = move_to_device_pass(edge_program, "mps")
+
+        # replace slice_copy with slice
+        ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module)
+
+        edge_program_module = mps_edge_program.module()
+
+        # Grab all input placeholders from the graph
+        user_input_names = mps_edge_program.graph_signature.user_inputs
+        user_input_placeholders = []
+        for node in mps_edge_program.graph.nodes:
+            if node.op == "placeholder" and node.name in user_input_names:
+                user_input_placeholders.append(node.meta["val"])
+
+        # Base options for all devices
+        options: dict[str, typing.Any] = {
+            # Do not link against the full PyTorch/libtorch library
+            "aot_inductor.link_libtorch": False,
+            # Package model constants and other generated files directly in the shared object (.so) file
+            "aot_inductor.package_constants_in_so": True,
+            # Enable maximum automatic tuning for optimal performance
+            "max_autotune": True,
+            # "aot_inductor.debug_compile": True,
+            # "aot_inductor.force_mmap_weights": False,
+        }
+
+        with collect_unsupported_fallback_kernels():
+            so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
+            if len(missing_fallback_kernels) > 0:
+                formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
+                raise RuntimeError(
+                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
+                    "Please add them to the AOTI backend."
+                )
+
+        # pyre-ignorep[6]: Incompatible parameter type
+        with open(so_path, "rb") as f:
+            so_data = f.read()
+
+        named_data_store = NamedDataStore()
+        method_name = MetalBackend.method_name_from_compile_specs(compile_specs)
+        named_data_store.add_named_data(
+            method_name + "_so_blob", so_data, 1, "aoti_metal_blob"
+        )
+
+        # Clean up the generated so file; it has been packaged into the NamedDataStore
+        # pyre-ignorep[6]: Incompatible parameter type
+        os.remove(so_path)
+
+        return PreprocessResult(
+            processed_bytes=b"",
+            debug_handle_map={},
+            data_store_output=named_data_store.get_named_data_store_output(),
+        )
+
+    @staticmethod
+    def generate_method_name_compile_spec(
+        method_name: str,
+    ) -> CompileSpec:
+        """
+        Returns the compile spec that carries the name of the method being lowered,
+        so the backend can associate the compiled blob with that method.
+        """
+        return CompileSpec(
+            COMPILE_SPEC_KEYS.METHOD_NAME.value,
+            method_name.encode("utf-8"),
+        )
+
+    @staticmethod
+    def method_name_from_compile_specs(
+        compile_specs: List[CompileSpec],
+    ) -> str:
+        """
+        Returns the method name from the compile specs.
+        """
+        for spec in compile_specs:
+            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
+                return spec.value.decode("utf-8")
+        raise RuntimeError(
+            f"Could not find method name in compile specs: {compile_specs}"
+        )
diff --git a/backends/apple/metal/metal_partitioner.py b/backends/apple/metal/metal_partitioner.py
new file mode 100644
index 00000000000..b103ac0f455
--- /dev/null
+++ b/backends/apple/metal/metal_partitioner.py
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
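# --- Editorial example (not part of the patch) --------------------------------
# Minimal sketch of the method-name compile spec roundtrip implemented above.
# The method name "forward" is an assumed example.
specs = [MetalBackend.generate_method_name_compile_spec("forward")]
assert MetalBackend.method_name_from_compile_specs(specs) == "forward"
# -------------------------------------------------------------------------------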
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Dict, final, List, Optional, Tuple + +import torch +from executorch.backends.apple.metal.metal_backend import MetalBackend # usort: skip +from executorch.exir._warnings import experimental +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer +from torch.export.exported_program import ExportedProgram + + +@final +@experimental( + "This API and all of Metal backend related functionality are experimental." +) +class MetalPartitioner(Partitioner): + """ + Metal partitioner for AOTInductor backend integration. + + This partitioner creates a single partition containing all operators from the input graph. + It skips core ATen decomposition, allowing the Metal backend to handle decomposition using + AOTInductor's MPS-specific decomposition table. + + Only operators that cannot be handled by the aoti-mps library will be excluded from + the partition and fall back to ExecuTorch's default or custom handling. + """ + + def __init__(self, compile_spec: List[CompileSpec]) -> None: + self.delegation_spec = DelegationSpec(MetalBackend.__name__, compile_spec) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + """ + Fully delegate the graph to AOTInductor by tagging all nodes as a single partition. + """ + + partition_tags: Dict[str, DelegationSpec] = {} + tag = "tag0" + + for node in exported_program.graph.nodes: + if node.op != "call_function": + continue + node.meta["delegation_tag"] = tag + + partition_tags[tag] = self.delegation_spec + + tag_constant_data(exported_program) + tag_mutated_buffer(exported_program) + + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + """ + Return a list of operations that should not be decomposed and let the AOT compiler handle them. + Currently we skip ATen decompositon for all ops, and let the Metal backend handle them. + """ + do_not_decompose = set() + + for node in ep.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + do_not_decompose.add(node.target) + return list(do_not_decompose), None diff --git a/backends/apple/metal/replace_slice_copy_with_slice.py b/backends/apple/metal/replace_slice_copy_with_slice.py new file mode 100644 index 00000000000..4f16759af35 --- /dev/null +++ b/backends/apple/metal/replace_slice_copy_with_slice.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
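# --- Editorial example (not part of the patch) --------------------------------
# Rough sketch of how the MetalPartitioner and MetalBackend introduced above
# might be wired into an export-and-lower flow. The toy model, inputs, and the
# use of to_edge_transform_and_lower are illustrative assumptions, not part of
# this diff.
import torch
from executorch.backends.apple.metal.metal_backend import MetalBackend
from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
from executorch.exir import to_edge_transform_and_lower


class SmallModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x + 1.0)


exported = torch.export.export(SmallModel(), (torch.randn(2, 3),), strict=True)

# Each lowered method carries a method-name compile spec so MetalBackend can tag
# the compiled blob it stores in the NamedDataStore.
compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")]
lowered = to_edge_transform_and_lower(
    exported, partitioner=[MetalPartitioner(compile_specs)]
)
et_program = lowered.to_executorch()
# -------------------------------------------------------------------------------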
+ +# pyre-strict + +from typing import Dict, Iterable, Tuple + +import torch +from executorch.exir.dialects._ops import ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from torch import fx + + +_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = ( + torch.ops.aten.slice_copy.Tensor, + ops.edge.aten.slice_copy.Tensor, +) + +_SLICE_TARGETS: Dict[ + torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload +] = { + torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, + ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor, +} + + +class ReplaceSliceCopyWithSlicePass(ExportPass): + """Replace non-mutated ``slice_copy`` results with ``slice`` views.""" + + def call(self, graph_module: fx.GraphModule) -> PassResult: + graph_changed = False + + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS: + continue + + if self._has_blocking_user(node, node.users.keys()): + continue + + node.target = _SLICE_TARGETS[node.target] + graph_changed = True + + if graph_changed: + graph_module.graph.lint() + graph_module.recompile() + + return PassResult(graph_module, graph_changed) + + def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool: + for user in users: + if self._is_mutating_user(node, user) or self._is_view_user(node, user): + return True + return False + + def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool: + if user.op == "call_method": + # Treat in-place tensor methods conservatively as mutations only when the + # method name ends with ``_`` which is the PyTorch convention for mutation. + return isinstance(user.target, str) and user.target.endswith("_") + + if user.op != "call_function": + return False + + target = user.target + if not hasattr(target, "_schema"): + return False + + schema = target._schema # pyre-ignore[16] + # Positional arguments + for index, arg in enumerate(user.args): + if arg is node and self._argument_mutates(schema, index): + return True + + # Keyword arguments + for name, arg in user.kwargs.items(): + if arg is node and self._argument_mutates(schema, name): + return True + + return False + + def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool: + if user.op == "call_method": + # Treat tensor methods conservatively and assume they may be view-producing. + return True + + if user.op != "call_function": + return False + + target = user.target + if getattr(target, "is_view", False): + for arg in user.args: + if arg is node: + return True + for arg in user.kwargs.values(): + if arg is node: + return True + + return False + + def _argument_mutates( + self, schema: torch._C.FunctionSchema, key: int | str + ) -> bool: + arguments = schema.arguments + if isinstance(key, int): + if key >= len(arguments): + return False + argument = arguments[key] + else: + argument = next((arg for arg in arguments if arg.name == key), None) + if argument is None: + return False + + alias_info = argument.alias_info + return bool(alias_info and alias_info.is_write) diff --git a/backends/apple/metal/tests/__init__.py b/backends/apple/metal/tests/__init__.py new file mode 100644 index 00000000000..fd6404c7f7b --- /dev/null +++ b/backends/apple/metal/tests/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
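# --- Editorial example (not part of the patch) --------------------------------
# Minimal sketch of running the ReplaceSliceCopyWithSlicePass defined above,
# mirroring how MetalBackend.preprocess applies it to an edge program's
# graph_module. The toy model and the use of to_edge are illustrative
# assumptions.
import torch
from executorch.backends.apple.metal.replace_slice_copy_with_slice import (
    ReplaceSliceCopyWithSlicePass,
)
from executorch.exir import to_edge


class Slicer(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, 1:3] * 2.0


edge = to_edge(torch.export.export(Slicer(), (torch.randn(4, 8),), strict=True))
result = ReplaceSliceCopyWithSlicePass()(edge.exported_program().graph_module)
# result.modified should report True here: the slice_copy feeding a non-mutating,
# non-view consumer (mul) is rewritten to slice.
# -------------------------------------------------------------------------------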
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + diff --git a/backends/apple/metal/tests/test_metal_backend.py b/backends/apple/metal/tests/test_metal_backend.py new file mode 100644 index 00000000000..26d2281c458 --- /dev/null +++ b/backends/apple/metal/tests/test_metal_backend.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from executorch.backends.apple.metal.metal_backend import ( + COMPILE_SPEC_KEYS, + MetalBackend, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec + + +class TestMetalBackend(unittest.TestCase): + """Test Metal backend utility functions.""" + + def test_generate_method_name_compile_spec(self): + """Test that compile spec is generated correctly with method name.""" + method_name = "forward" + compile_spec = MetalBackend.generate_method_name_compile_spec(method_name) + + # Verify compile spec structure + self.assertIsInstance(compile_spec, CompileSpec) + self.assertEqual(compile_spec.key, COMPILE_SPEC_KEYS.METHOD_NAME.value) + self.assertEqual(compile_spec.value, method_name.encode("utf-8")) + + def test_method_name_from_compile_specs(self): + """Test extracting method name from compile specs.""" + method_name = "forward" + compile_specs = [MetalBackend.generate_method_name_compile_spec(method_name)] + + # Extract method name + extracted_name = MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertEqual(extracted_name, method_name) + + def test_method_name_from_compile_specs_with_multiple_specs(self): + """Test extracting method name when there are multiple compile specs.""" + method_name = "forward" + compile_specs = [ + CompileSpec("other_key", b"other_value"), + MetalBackend.generate_method_name_compile_spec(method_name), + CompileSpec("another_key", b"another_value"), + ] + + # Extract method name + extracted_name = MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertEqual(extracted_name, method_name) + + def test_method_name_from_compile_specs_missing(self): + """Test that RuntimeError is raised when method name is missing.""" + compile_specs = [ + CompileSpec("other_key", b"other_value"), + ] + + # Should raise RuntimeError when method name is not found + with self.assertRaises(RuntimeError) as context: + MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertIn("Could not find method name", str(context.exception)) + + def test_compile_spec_roundtrip(self): + """Test that method name survives encode/decode roundtrip.""" + original_name = "my_custom_method" + + # Generate compile spec + compile_spec = MetalBackend.generate_method_name_compile_spec(original_name) + + # Extract from compile specs list + extracted_name = MetalBackend.method_name_from_compile_specs([compile_spec]) + + self.assertEqual(original_name, extracted_name) + + +if __name__ == "__main__": + unittest.main() + diff --git a/backends/apple/metal/tests/test_metal_partitioner.py b/backends/apple/metal/tests/test_metal_partitioner.py new file mode 100644 index 00000000000..97a073152f5 --- /dev/null +++ b/backends/apple/metal/tests/test_metal_partitioner.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Tuple + +import torch +from executorch.backends.apple.metal.metal_backend import MetalBackend +from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner +from executorch.exir.backend.partitioner import PartitionResult +from torch.export import export + + +class TestMetalPartitioner(unittest.TestCase): + """ + Test Metal partitioner functionality. + + After Metal partitioning, there should be exactly one partitioned graph that contains + all operators from the input graph. This means all operators should be tagged with + the same delegation tag, indicating they will all be executed by the Metal backend. + """ + + def _get_partition_result( + self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...] + ) -> PartitionResult: + """Helper method to get partition result for a given module.""" + # Export the model + exported_program = export(module, inputs, strict=True) + + # Create partitioner with compile specs + compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")] + partitioner = MetalPartitioner(compile_specs) + + # Get partition result + partition_result = partitioner.partition(exported_program) + + # Verify partition result structure + self.assertIsNotNone(partition_result) + self.assertTrue(hasattr(partition_result, "tagged_exported_program")) + self.assertTrue(hasattr(partition_result, "partition_tags")) + + return partition_result + + def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool: + """Check if the graph is fully partitioned (all operators have the same tag).""" + tagged_nodes = [] + untagged_ops = [] + + for node in partition_result.tagged_exported_program.graph.nodes: + if node.op == "call_function": + if hasattr(node, "meta") and "delegation_tag" in node.meta: + tagged_nodes.append(node) + else: + untagged_ops.append(node) + + # Check if we have any tagged nodes + if not tagged_nodes: + return False + + # Check if all tagged nodes have the same tag + first_tag = tagged_nodes[0].meta["delegation_tag"] + all_same_tag = all( + node.meta.get("delegation_tag") == first_tag for node in tagged_nodes + ) + + # Should have no untagged operations for full partitioning + fully_partitioned = len(untagged_ops) == 0 and all_same_tag + + return fully_partitioned + + def test_simple_add_partition(self): + """ + Test that Metal partitioner creates exactly one partition containing all operators. + Simple element-wise addition should result in a single graph with all ops tagged identically. + """ + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + # Create test inputs + x = torch.randn(2, 3) + y = torch.randn(2, 3) + + # Get partition result + partition_result = self._get_partition_result(AddModule(), (x, y)) + + # Verify it's fully partitioned + self.assertTrue( + self._check_fully_partitioned(partition_result), + "Expected all operations to be in a single partition", + ) + + # Verify exactly one partition tag exists + self.assertEqual( + len(partition_result.partition_tags), + 1, + "Expected exactly one partition tag for fully delegated graph", + ) + + def test_linear_partition(self): + """ + Test Metal partitioner with a linear layer. + All matrix operations should be in a single partition. 
+ """ + + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + # Create test input + x = torch.randn(2, 10) + + # Get partition result + partition_result = self._get_partition_result(LinearModule(), (x,)) + + # Verify it's fully partitioned + self.assertTrue( + self._check_fully_partitioned(partition_result), + "Expected all operations to be in a single partition", + ) + + def test_ops_to_not_decompose(self): + """ + Test that ops_to_not_decompose returns all call_function ops. + Metal backend should handle decomposition via AOTInductor. + """ + + class SimpleModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.relu(x + 1.0) + + # Create test input + x = torch.randn(2, 3) + + # Export the model + exported_program = export(SimpleModule(), (x,), strict=True) + + # Create partitioner + compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")] + partitioner = MetalPartitioner(compile_specs) + + # Get ops to not decompose + ops_to_not_decompose, _ = partitioner.ops_to_not_decompose(exported_program) + + # Verify it returns a list + self.assertIsInstance(ops_to_not_decompose, list) + + # All call_function ops should be in the list + call_function_ops = [ + node.target + for node in exported_program.graph.nodes + if node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload) + ] + + self.assertEqual( + set(ops_to_not_decompose), + set(call_function_ops), + "ops_to_not_decompose should contain all call_function ops", + ) + + +if __name__ == "__main__": + unittest.main() + From 1a22c5e02f3a4d57bb3abf6b68c1279ec4bf58e4 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 16:21:06 -0400 Subject: [PATCH 03/16] Update [ghstack-poisoned] --- backends/apple/metal/tests/__init__.py | 1 - backends/apple/metal/tests/test_metal_backend.py | 1 - backends/apple/metal/tests/test_metal_partitioner.py | 4 ++-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/backends/apple/metal/tests/__init__.py b/backends/apple/metal/tests/__init__.py index fd6404c7f7b..2e41cd717f6 100644 --- a/backends/apple/metal/tests/__init__.py +++ b/backends/apple/metal/tests/__init__.py @@ -3,4 +3,3 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
- diff --git a/backends/apple/metal/tests/test_metal_backend.py b/backends/apple/metal/tests/test_metal_backend.py index 26d2281c458..5caf7a3adc6 100644 --- a/backends/apple/metal/tests/test_metal_backend.py +++ b/backends/apple/metal/tests/test_metal_backend.py @@ -77,4 +77,3 @@ def test_compile_spec_roundtrip(self): if __name__ == "__main__": unittest.main() - diff --git a/backends/apple/metal/tests/test_metal_partitioner.py b/backends/apple/metal/tests/test_metal_partitioner.py index 97a073152f5..1b29410ab6c 100644 --- a/backends/apple/metal/tests/test_metal_partitioner.py +++ b/backends/apple/metal/tests/test_metal_partitioner.py @@ -157,7 +157,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: call_function_ops = [ node.target for node in exported_program.graph.nodes - if node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload) + if node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) ] self.assertEqual( @@ -169,4 +170,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": unittest.main() - From d6f0bc952a57a6ec14c5118fd1fe4da91d2d3194 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:30 -0400 Subject: [PATCH 04/16] Update [ghstack-poisoned] --- .../metal/runtime/shims/tensor_attribute.cpp | 38 ++++++++ .../metal/runtime/shims/tensor_attribute.h | 32 +++++++ backends/apple/metal/runtime/shims/types.h | 35 +++++++ backends/apple/metal/runtime/shims/utils.cpp | 93 +++++++++++++++++++ backends/apple/metal/runtime/shims/utils.h | 74 +++++++++++++++ 5 files changed, 272 insertions(+) create mode 100644 backends/apple/metal/runtime/shims/tensor_attribute.cpp create mode 100644 backends/apple/metal/runtime/shims/tensor_attribute.h create mode 100644 backends/apple/metal/runtime/shims/types.h create mode 100644 backends/apple/metal/runtime/shims/utils.cpp create mode 100644 backends/apple/metal/runtime/shims/utils.h diff --git a/backends/apple/metal/runtime/shims/tensor_attribute.cpp b/backends/apple/metal/runtime/shims/tensor_attribute.cpp new file mode 100644 index 00000000000..684e00ffe32 --- /dev/null +++ b/backends/apple/metal/runtime/shims/tensor_attribute.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Metal-specific device type constant +__attribute__((__visibility__("default"))) int32_t +aoti_torch_device_type_mps() { + // Let's use 2 for MPS + return 2; +} + +// Override aoti_torch_get_device_type to return MPS device type +AOTITorchError aoti_torch_get_device_type( + AOTITensorHandle tensor, + int32_t* ret_device_type) { + *ret_device_type = aoti_torch_device_type_mps(); + return Error::Ok; +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/tensor_attribute.h b/backends/apple/metal/runtime/shims/tensor_attribute.h new file mode 100644 index 00000000000..8d2a3dde361 --- /dev/null +++ b/backends/apple/metal/runtime/shims/tensor_attribute.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Metal-specific device type function +int32_t aoti_torch_device_type_mps(); + +// Override aoti_torch_get_device_type to return MPS device type +AOTITorchError aoti_torch_get_device_type( + AOTITensorHandle tensor, + int32_t* ret_device_type); + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/types.h b/backends/apple/metal/runtime/shims/types.h new file mode 100644 index 00000000000..07d377d7499 --- /dev/null +++ b/backends/apple/metal/runtime/shims/types.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Common using declarations for ExecutorTorch types +using executorch::runtime::Error; +using executorch::runtime::etensor::Tensor; + +extern "C" { + +// Common AOTI type aliases +// Note: AOTITensorHandle is aliased to Tensor* for ExecutorTorch compatibility +using AOTITensorHandle = Tensor*; +using AOTIRuntimeError = Error; +using AOTITorchError = Error; + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp new file mode 100644 index 00000000000..484158e9027 --- /dev/null +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Helper function to check if a dtype is supported in Metal backend +bool is_dtype_supported_in_et_metal(int32_t dtype) { + switch (dtype) { + case static_cast(SupportedDTypes::INT64): + case static_cast(SupportedDTypes::FLOAT32): + case static_cast(SupportedDTypes::BFLOAT16): + return true; + default: + return false; + } +} + +// Metal-specific dtype validation utility function +AOTITorchError validate_dtype(int32_t dtype) { + if (is_dtype_supported_in_et_metal(dtype)) { + return Error::Ok; + } + + ET_LOG( + Error, + "Unsupported dtype: %d. 
Supported dtypes: %d (int64), %d (float32), %d (bfloat16)", + dtype, + static_cast(SupportedDTypes::INT64), + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDTypes::BFLOAT16)); + return Error::InvalidArgument; +} + +} // extern "C" + +// Utility function to convert sizes pointer to vector +std::vector convert_sizes_to_vector( + int64_t ndim, + const int64_t* sizes_ptr) { + std::vector sizes(ndim); + for (int i = 0; i < ndim; i++) { + sizes[i] = static_cast(sizes_ptr[i]); + } + return sizes; +} + +// Utility function to convert strides pointer to vector or calculate from sizes +std::vector convert_strides_to_vector( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr) { + std::vector strides(ndim); + + if (strides_ptr != nullptr) { + // Use provided strides. it is ok if provided strides here is not contiguous + // strides since it will be used internally in CUDA delegate. + for (int64_t i = 0; i < ndim; i++) { + strides[i] = static_cast(strides_ptr[i]); + } + } else { + // Calculate strides from sizes using ExecutorTorch's algorithm + if (ndim > 0) { + strides[ndim - 1] = static_cast( + 1); // Last dimension has stride 1 + for (int64_t i = ndim - 2; i >= 0; i--) { + if (sizes_ptr[i + 1] == 0) { + strides[i] = strides[i + 1]; // Copy stride when size is 0 + } else { + strides[i] = static_cast( + static_cast(strides[i + 1]) * sizes_ptr[i + 1]); + } + } + } + } + return strides; +} + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/utils.h b/backends/apple/metal/runtime/shims/utils.h new file mode 100644 index 00000000000..5b9f9c5b3bb --- /dev/null +++ b/backends/apple/metal/runtime/shims/utils.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
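# --- Editorial example (not part of the patch) --------------------------------
# Python rendering of the stride derivation used by convert_strides_to_vector
# above when no strides are supplied: contiguous strides are computed right to
# left, and a zero-sized dimension copies the stride of the dimension to its
# right. Illustrative sketch only.
def contiguous_strides(sizes):
    ndim = len(sizes)
    strides = [0] * ndim
    if ndim == 0:
        return strides
    strides[-1] = 1  # last dimension has stride 1
    for i in range(ndim - 2, -1, -1):
        if sizes[i + 1] == 0:
            strides[i] = strides[i + 1]  # copy stride when size is 0
        else:
            strides[i] = strides[i + 1] * sizes[i + 1]
    return strides


assert contiguous_strides([2, 3, 4]) == [12, 4, 1]  # NCHW-style contiguous layout
assert contiguous_strides([2, 0, 4]) == [4, 4, 1]  # zero-sized dim keeps right stride
# -------------------------------------------------------------------------------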
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Enum for supported data types in et-metal backend +enum class SupportedDTypes : int32_t { + // UINT8 = 0, // PyTorch's uint8 dtype code + // INT8 = 1, // PyTorch's int8 dtype code + // INT16 = 2, // PyTorch's int16 dtype code + // INT32 = 3, // PyTorch's int32 dtype code + INT64 = 4, // PyTorch's int64 dtype code + // FLOAT16 = 5, // PyTorch's float16 dtype code + FLOAT32 = 6, // PyTorch's float32 dtype code + // FLOAT64 = 7, // PyTorch's float64 dtype code + // BOOL = 11, // PyTorch's bool dtype code + BFLOAT16 = 15 // PyTorch's bfloat16 dtype code +}; + +extern "C" { + +// Helper function to check if a dtype is supported in Metal backend +bool is_dtype_supported_in_et_metal(int32_t dtype); + +// Metal-specific dtype validation utility function +AOTITorchError validate_dtype(int32_t dtype); + +} // extern "C" + +// Utility function to convert sizes pointer to vector +std::vector convert_sizes_to_vector( + int64_t ndim, + const int64_t* sizes_ptr); + +// Utility function to convert strides pointer to vector or calculate from sizes +std::vector convert_strides_to_vector( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr); + +// Check if tensor is in contiguous memory format (NCHW for 4D tensors) +// Contiguous format means strides decrease from left to right: +// For NCHW: strides = [C*H*W, H*W, W, 1] +inline bool is_contiguous_tensor( + std::vector sizes, + std::vector strides) { + int64_t ndim = static_cast(strides.size()); + int64_t expected_stride = 1; + for (int64_t i = ndim - 1; i >= 0; i--) { + if (strides[i] != expected_stride) { + return false; + } + expected_stride *= sizes[i]; + } + return true; +} + +} // namespace metal +} // namespace backends +} // namespace executorch From 7e11615aa43033df7f5988d5c1adb6d923c78297 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:34 -0400 Subject: [PATCH 05/16] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 271 ++++++ .../apple/metal/runtime/shims/et_metal.mm | 872 ++++++++++++++++++ 2 files changed, 1143 insertions(+) create mode 100644 backends/apple/metal/runtime/shims/et_metal.h create mode 100644 backends/apple/metal/runtime/shims/et_metal.mm diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h new file mode 100644 index 00000000000..c18ad513a3a --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -0,0 +1,271 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#ifdef __OBJC__ +#import +#import +#include +// Forward declarations for MetalPerformanceShadersGraph types +@class MPSGraph; +@class MPSCommandBuffer; +// Metal type definitions for Objective-C compilation +typedef id MTLDevice_t; +typedef id MTLCommandQueue_t; +typedef id MTLCommandBuffer_t; +typedef id MTLComputeCommandEncoder_t; +typedef id MTLComputePipelineState_t; +typedef id MTLFunction_t; +typedef id MTLLibrary_t; +typedef id MTLBuffer_t; +typedef dispatch_queue_t dispatch_queue_t; +typedef MPSGraph* MPSGraph_t; +typedef MPSCommandBuffer* MPSCommandBuffer_t; +typedef NSDictionary* NSDictionary_t; +#else +// Forward declarations for C++ compilation +typedef void* MTLDevice_t; +typedef void* MTLCommandQueue_t; +typedef void* MTLCommandBuffer_t; +typedef void* MTLComputeCommandEncoder_t; +typedef void* MTLComputePipelineState_t; +typedef void* MTLFunction_t; +typedef void* MTLLibrary_t; +typedef void* MTLBuffer_t; +typedef void* dispatch_queue_t; +typedef void* MPSGraph_t; +typedef void* MPSCommandBuffer_t; +typedef void* NSDictionary_t; +#endif + +#include +#include +#include +#include +#include + +namespace executorch::runtime::etensor { +class Tensor; +} + +namespace executorch { +namespace backends { +namespace metal { + +// Forward declarations +class ETMetalKernelFunction; +class ETMetalStream; + +// ======================= +// SyncType - Metal synchronization options +// ======================= +enum class SyncType { + NONE, // no commit to command buffer + COMMIT, // commit and flush the command buffer + COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish + COMMIT_AND_CONTINUE, // commit and continue with a new underlying command + // buffer + COMMIT_ADAPTIVE, // commit adaptively based on available memory +}; + +// ======================= +// ETMetalShaderLibrary - ExecuTorch Metal shader library management +// ======================= +class ETMetalShaderLibrary { + public: + ETMetalShaderLibrary(const std::string& source); + ~ETMetalShaderLibrary(); + + std::shared_ptr getKernelFunction( + const std::string& name); + + private: + void compileLibrary(); + std::pair getLibraryPipelineState( + const std::string& functionName); + + friend class ETMetalKernelFunction; + + std::string shaderSource_; + MTLLibrary_t library_; + std::unordered_map< + std::string, + std::pair> + pipelineStates_; +}; + +// ======================= +// ETMetalKernelFunction - ExecuTorch Metal kernel function execution +// ======================= +class ETMetalKernelFunction { + public: + ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func); + ~ETMetalKernelFunction(); + + void startEncoding(); + void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor); + void setArg(unsigned idx, int64_t val); + + void dispatchSingle(uint64_t length); + void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size); + void dispatchArray(const uint64_t* length, size_t length_size); + void dispatchArrayWithGroupSize( + const uint64_t* length, + size_t length_size, + const uint64_t* group_size, + size_t group_size_size); + + void runCommandBlock(std::function f); + + private: + MTLComputePipelineState_t cps_; + MTLFunction_t func_; + MTLComputeCommandEncoder_t encoder_; +}; + +// ======================= +// ETMetalStream - Metal command buffer and synchronization management +// ======================= +class ETMetalStream { + public: + ETMetalStream(); + ~ETMetalStream(); + + // Get the default stream (singleton) + static 
ETMetalStream* getDefaultStream(); + + // Device and queue access + MTLDevice_t device() const { + return device_; + } + MTLCommandQueue_t commandQueue() const { + return commandQueue_; + } + dispatch_queue_t queue() const { + return serialQueue_; + } + + // Synchronization methods + void synchronize(SyncType syncType = SyncType::COMMIT_AND_WAIT); + void synchronize(); // Overload for backward compatibility + bool isEmpty() const; + + // Command buffer management with lazy creation + MPSCommandBuffer_t commandBuffer(); + MTLComputeCommandEncoder_t commandEncoder(); + + void endKernelCoalescing(); + + // MPSGraph execution + void executeMPSGraph( + MPSGraph_t mpsGraph, + NSDictionary_t feeds, + NSDictionary_t results, + SyncType syncType = SyncType::COMMIT_ADAPTIVE); + + // Command buffer lifecycle management + void commitCommandBuffer(MTLCommandBuffer_t commandBuffer); + void flush(); + + // Memory operations + void fill( + MTLBuffer_t buffer, + uint8_t value, + size_t length, + size_t offset, + SyncType syncType = SyncType::NONE); + void copy( + MTLBuffer_t srcBuffer, + MTLBuffer_t dstBuffer, + size_t length, + size_t srcOffset, + size_t dstOffset, + SyncType syncType = SyncType::NONE); + + private: + // Private synchronization methods + void commit(); + void commitAndWait(); + void commitAndContinue(); + + private: + // Private members + MTLDevice_t device_; + MTLCommandQueue_t commandQueue_; + MPSCommandBuffer_t commandBuffer_; + MPSCommandBuffer_t prevCommandBuffer_; // For commit-and-continue pattern + MTLComputeCommandEncoder_t commandEncoder_; + dispatch_queue_t serialQueue_; // For thread safety + + // Configuration + bool enableCommitAndContinue_; + + // Singleton instance + static ETMetalStream* defaultStream_; +}; + +// ======================= +// Global storage management functions +// ======================= +void storeFunctionHandle( + ETMetalKernelFunction* raw_function, + std::shared_ptr function_shared_ptr); +void storeLibraryHandle( + ETMetalShaderLibrary* raw_library, + std::unique_ptr library); +bool removeFunctionHandle(ETMetalKernelFunction* raw_function); +bool removeLibraryHandle(ETMetalShaderLibrary* raw_library); + +// ======================= +// Global stream access functions +// ======================= +ETMetalStream* getCurrentMetalStream(); +void setCurrentMetalStream(ETMetalStream* stream); + +// ======================= +// Metal stream synchronization functions (C++ interface with exceptions) +// ======================= +void synchronize_metal_stream(); +void synchronize_metal_stream_with_type(int sync_type); + +// ======================= +// Metal helper functions (C interface) +// ======================= +#ifdef __cplusplus +extern "C" { +#endif + +// Memory management functions for Metal +void* metal_allocate_buffer(long bytes); +bool metal_is_device_pointer(void* ptr); +int metal_copy_memory( + void* dst, + const void* src, + size_t nbytes, + bool src_is_device, + bool dst_is_device); +void metal_cleanup_resources(); + +// Helper functions to access Metal objects +MTLDevice_t get_metal_device(); +MTLCommandQueue_t get_metal_command_queue(); + +#ifdef __cplusplus +} + +// C++ only - expose the Metal buffer mapping +#ifdef __OBJC__ +extern std::unordered_map ptr_to_mtl_buffer; +#endif + +#endif + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm new file mode 100644 index 00000000000..5afcf761d56 --- /dev/null +++ 
b/backends/apple/metal/runtime/shims/et_metal.mm @@ -0,0 +1,872 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#import +#import +#import +#import +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// ======================= +// Exception-Safe Dispatch Function (similar to PyTorch MPS) +// ======================= + +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) { + __block std::optional block_exception; + dispatch_sync(queue, ^() { + try { + block(); + } catch (...) { + block_exception = std::current_exception(); + } + }); + if (block_exception) { + std::rethrow_exception(*block_exception); + } +} + +// ======================= +// Global Variables and Storage +// ================ + + +// Global Metal buffer mapping - accessible for MPS shim +std::unordered_map> ptr_to_mtl_buffer; + +// Global storage to keep shared_ptr alive while raw pointers are used +static std::unordered_map> function_storage; +static std::unordered_map> library_storage; + +// Static singleton instance for default stream +ETMetalStream* ETMetalStream::defaultStream_ = nullptr; + +// Thread-local current stream +static thread_local ETMetalStream* currentStream_ = nullptr; + +// ======================= +// Metal Helper Functions (C Interface) +// ======================= + +extern "C" { + +void* metal_allocate_buffer(long bytes) { + ETMetalStream* stream = getCurrentMetalStream(); + id device = stream->device(); + if (!device) { + ET_LOG(Error, "Failed to get Metal device from stream"); + return nullptr; + } + + @autoreleasepool { + id buffer = [device newBufferWithLength:bytes options:MTLResourceStorageModeShared]; + if (!buffer) { + ET_LOG(Error, "Failed to allocate %ld bytes on Metal device", bytes); + return nullptr; + } + + void* ptr = [buffer contents]; + ptr_to_mtl_buffer[ptr] = buffer; + + ET_LOG(Debug, "Allocated %ld bytes on Metal device", bytes); + return ptr; + } +} + +void metal_cleanup_resources() { + if (!ptr_to_mtl_buffer.empty()) { + @autoreleasepool { + for (auto& pair : ptr_to_mtl_buffer) { + pair.second = nil; + } + ptr_to_mtl_buffer.clear(); + } + } +} + +bool metal_is_device_pointer(void* ptr) { + return ptr_to_mtl_buffer.find(ptr) != ptr_to_mtl_buffer.end(); +} + +int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_device, bool dst_is_device) { + if (!src || !dst || nbytes == 0) { + ET_LOG(Error, "Metal copy: Invalid parameters"); + return -1; + } + + @autoreleasepool { + // Case 1: Device-to-device copy - use GPU blit encoder (most efficient) + if (src_is_device && dst_is_device) { + auto src_it = ptr_to_mtl_buffer.find(const_cast(src)); + auto dst_it = ptr_to_mtl_buffer.find(dst); + + if (src_it != ptr_to_mtl_buffer.end() && dst_it != ptr_to_mtl_buffer.end()) { + id srcBuffer = src_it->second; + id dstBuffer = dst_it->second; + + // Calculate offsets relative to buffer base + size_t srcOffset = static_cast(src) - static_cast([srcBuffer contents]); + size_t dstOffset = static_cast(dst) - static_cast([dstBuffer contents]); + + // Use Metal's blit encoder for GPU-accelerated copy + ETMetalStream* stream = getCurrentMetalStream(); + stream->copy(srcBuffer, dstBuffer, nbytes, srcOffset, dstOffset, SyncType::NONE); + + ET_LOG(Debug, "Metal device-to-device copy (GPU blit): %zu bytes", nbytes); 
+ return 0; + } + + ET_LOG(Error, "Metal copy: Device pointers not found in buffer map"); + return -1; + } + + // Case 2: Host-to-device or device-to-host - use memcpy with shared memory + // Since Metal uses shared storage mode, CPU and GPU access the same memory + std::memcpy(dst, src, nbytes); + + // Synchronize only if we need to ensure GPU operations complete before CPU reads + // (device-to-host case where GPU may have written data) + if (src_is_device && !dst_is_device) { + // Ensure any pending GPU writes to source complete before CPU reads + ETMetalStream* stream = getCurrentMetalStream(); + stream->synchronize(SyncType::COMMIT_AND_WAIT); + } + + ET_LOG(Debug, "Metal memory copy (memcpy): %zu bytes, src_device=%d, dst_device=%d", + nbytes, src_is_device, dst_is_device); + } + + return 0; +} + +id get_metal_device() { + // Use stream-based device access + ETMetalStream* stream = getCurrentMetalStream(); + return stream->device(); +} + +id get_metal_command_queue() { + // Use stream-based queue access + ETMetalStream* stream = getCurrentMetalStream(); + return stream->commandQueue(); +} + +} // extern "C" + +// ======================= +// ETMetalShaderLibrary Implementation +// ======================= + +ETMetalShaderLibrary::ETMetalShaderLibrary(const std::string& source) : shaderSource_(source) { + compileLibrary(); +} + +ETMetalShaderLibrary::~ETMetalShaderLibrary() { + @autoreleasepool { + if (library_) { + [library_ release]; + library_ = nil; + } + + for (auto& pair : pipelineStates_) { + [pair.second.first release]; + [pair.second.second release]; + } + pipelineStates_.clear(); + } +} + +void ETMetalShaderLibrary::compileLibrary() { + @autoreleasepool { + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get Metal device"); + return; + } + + NSString* sourceString = [NSString stringWithUTF8String:shaderSource_.c_str()]; + NSError* error = nil; + + library_ = [device newLibraryWithSource:sourceString options:nil error:&error]; + if (!library_ || error) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to compile shader library: %s", + error ? [[error localizedDescription] UTF8String] : "unknown error"); + return; + } + + [library_ retain]; + ET_LOG(Debug, "ETMetalShaderLibrary: Successfully compiled shader library"); + } +} + +std::pair, id> ETMetalShaderLibrary::getLibraryPipelineState(const std::string& functionName) { + auto it = pipelineStates_.find(functionName); + if (it != pipelineStates_.end()) { + return it->second; + } + + @autoreleasepool { + if (!library_) { + ET_LOG(Error, "ETMetalShaderLibrary: Library not compiled"); + return {nil, nil}; + } + + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get Metal device"); + return {nil, nil}; + } + + NSString* funcName = [NSString stringWithUTF8String:functionName.c_str()]; + id function = [library_ newFunctionWithName:funcName]; + if (!function) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get function '%s'", functionName.c_str()); + return {nil, nil}; + } + + NSError* error = nil; + id pipelineState = [device newComputePipelineStateWithFunction:function error:&error]; + if (!pipelineState || error) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to create pipeline state for '%s': %s", + functionName.c_str(), error ? 
[[error localizedDescription] UTF8String] : "unknown error"); + [function release]; + return {nil, nil}; + } + + [pipelineState retain]; + [function retain]; + pipelineStates_[functionName] = {pipelineState, function}; + + ET_LOG(Debug, "ETMetalShaderLibrary: Created pipeline state for function '%s'", functionName.c_str()); + return {pipelineState, function}; + } +} + +std::shared_ptr ETMetalShaderLibrary::getKernelFunction(const std::string& name) { + auto pipelineStatePair = getLibraryPipelineState(name); + if (!pipelineStatePair.first || !pipelineStatePair.second) { + ET_LOG(Error, "ETMetalShaderLibrary::getKernelFunction: Failed to get pipeline state for '%s'", name.c_str()); + return nullptr; + } + + return std::make_shared(pipelineStatePair.first, pipelineStatePair.second); +} + +// ======================= +// ETMetalKernelFunction Implementation +// ======================= + +ETMetalKernelFunction::ETMetalKernelFunction(id cps, id func) + : cps_(cps), func_(func), encoder_(nil) { + if (cps_) [cps_ retain]; + if (func_) [func_ retain]; +} + +ETMetalKernelFunction::~ETMetalKernelFunction() { + @autoreleasepool { + // Don't release encoder_ here - the stream owns it + // Only clean up our own references + if (cps_) { + [cps_ release]; + cps_ = nil; + } + if (func_) { + [func_ release]; + func_ = nil; + } + + encoder_ = nil; // Clear reference without releasing + } +} + +void ETMetalKernelFunction::startEncoding() { + @autoreleasepool { + // Don't retain/release the encoder - just get reference from stream + ETMetalStream* stream = getCurrentMetalStream(); + encoder_ = stream->commandEncoder(); // Use stream's managed encoder + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction: Failed to get encoder from stream"); + return; + } + + // Don't retain - stream owns the encoder + [encoder_ setComputePipelineState:cps_]; + + ET_LOG(Debug, "ETMetalKernelFunction: Started encoding with stream-managed encoder"); + } +} + +void ETMetalKernelFunction::setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + void* data_ptr = tensor.mutable_data_ptr(); + size_t totalSize = tensor.numel() * tensor.element_size(); + + auto it = ptr_to_mtl_buffer.find(data_ptr); + if (it != ptr_to_mtl_buffer.end()) { + // Use existing Metal buffer + id mtlBuffer = it->second; + [encoder_ setBuffer:mtlBuffer offset:0 atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set Metal buffer at index %u (size: %zu)", idx, totalSize); + } else { + // Handle CPU tensor data + if (totalSize <= 4096) { + // Use setBytes for small data (more efficient) + [encoder_ setBytes:data_ptr length:totalSize atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set CPU tensor via setBytes at index %u (size: %zu)", idx, totalSize); + } else { + // Create temporary buffer for large data (should be rare) + @autoreleasepool { + id device = get_metal_device(); + if (device) { + id tempBuffer = [device newBufferWithBytes:data_ptr + length:totalSize + options:MTLResourceStorageModeShared]; + if (tempBuffer) { + [encoder_ setBuffer:tempBuffer offset:0 atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set large CPU tensor via temporary buffer at index %u (size: %zu)", idx, totalSize); + } else { + ET_LOG(Error, "ETMetalKernelFunction::setArg: Failed to create temporary buffer for index %u", idx); + } + } else { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No Metal device available for index 
%u", idx); + } + } + } + } +} + +void ETMetalKernelFunction::setArg(unsigned idx, int64_t val) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:&val length:sizeof(int64_t) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set int64_t value %lld at index %u", val, idx); +} + +void ETMetalKernelFunction::dispatchSingle(uint64_t length) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchSingle: No active encoder"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + uint64_t actualGroupSize = std::min(maxThreadsPerGroup, length); + + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchSingle: Dispatched with length %llu, group size %llu", length, actualGroupSize); + +} + +void ETMetalKernelFunction::dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchSingleWithGroupSize: No active encoder"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + uint64_t actualGroupSize = group_size > 0 ? std::min(group_size, maxThreadsPerGroup) : std::min(maxThreadsPerGroup, length); + + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchSingleWithGroupSize: Dispatched with length %llu, group size %llu", length, actualGroupSize); + +} + +void ETMetalKernelFunction::dispatchArray(const uint64_t* length, size_t length_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArray: No active encoder"); + return; + } + + if (!length || length_size == 0) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArray: Invalid length array"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + + MTLSize size, threadGroupSize; + + if (length_size == 1) { + size = MTLSizeMake(length[0], 1, 1); + uint64_t actualGroupSize = std::min(maxThreadsPerGroup, length[0]); + threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + } else if (length_size == 2) { + size = MTLSizeMake(length[0], length[1], 1); + uint64_t groupX = std::min(static_cast(32), length[0]); + uint64_t groupY = maxThreadsPerGroup / groupX; + threadGroupSize = MTLSizeMake(groupX, groupY, 1); + } else { + size = MTLSizeMake(length[0], length[1], length_size > 2 ? 
length[2] : 1); + uint64_t groupX = std::min(static_cast(8), length[0]); + uint64_t groupY = std::min(static_cast(8), length[1]); + uint64_t groupZ = maxThreadsPerGroup / (groupX * groupY); + threadGroupSize = MTLSizeMake(groupX, groupY, groupZ); + } + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchArray: Dispatched %zuD with size [%lu, %lu, %lu], group [%lu, %lu, %lu]", + length_size, size.width, size.height, size.depth, + threadGroupSize.width, threadGroupSize.height, threadGroupSize.depth); + +} + +void ETMetalKernelFunction::dispatchArrayWithGroupSize(const uint64_t* length, size_t length_size, + const uint64_t* group_size, size_t group_size_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArrayWithGroupSize: No active encoder"); + return; + } + + if (!length || length_size == 0) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArrayWithGroupSize: Invalid length array"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + + MTLSize size, threadGroupSize; + + if (length_size == 1) { + size = MTLSizeMake(length[0], 1, 1); + uint64_t actualGroupSize = maxThreadsPerGroup; + if (group_size && group_size_size > 0) { + actualGroupSize = std::min(maxThreadsPerGroup, group_size[0]); + } + threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + } else if (length_size == 2) { + size = MTLSizeMake(length[0], length[1], 1); + uint64_t groupX = std::min(static_cast(32), length[0]); + uint64_t groupY = maxThreadsPerGroup / groupX; + if (group_size && group_size_size >= 2) { + groupX = std::min(static_cast(group_size[0]), length[0]); + groupY = std::min(static_cast(group_size[1]), length[1]); + } + threadGroupSize = MTLSizeMake(groupX, groupY, 1); + } else { + size = MTLSizeMake(length[0], length[1], length_size > 2 ? length[2] : 1); + uint64_t groupX = std::min(static_cast(8), length[0]); + uint64_t groupY = std::min(static_cast(8), length[1]); + uint64_t groupZ = maxThreadsPerGroup / (groupX * groupY); + if (group_size && group_size_size >= 3) { + groupX = std::min(static_cast(group_size[0]), length[0]); + groupY = std::min(static_cast(group_size[1]), length[1]); + groupZ = std::min(static_cast(group_size[2]), length_size > 2 ? length[2] : 1); + } + threadGroupSize = MTLSizeMake(groupX, groupY, groupZ); + } + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchArrayWithGroupSize: Dispatched %zuD with size [%lu, %lu, %lu], group [%lu, %lu, %lu]", + length_size, size.width, size.height, size.depth, + threadGroupSize.width, threadGroupSize.height, threadGroupSize.depth); + +} + +void ETMetalKernelFunction::runCommandBlock(std::function f) { + // Use dispatch_sync with the stream's serial queue for thread safety and synchronization + // This matches PyTorch's approach: dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ...) 
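+ // The std::function argument is the C++ end of the pure-C trampoline in shim_mps.mm: + // aoti_torch_mps_run_command_block wraps its (callback, user_data) pair in a lambda, + // e.g. [callback, func, user_data]() { callback(func, user_data); }, and passes it here, + // so the caller's encoding work runs inside a single dispatch_sync on the serial queue.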
+ ETMetalStream* stream = getCurrentMetalStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + f(); + } + }); + + ET_LOG(Debug, "ETMetalKernelFunction::runCommandBlock: Executed command block with dispatch_sync"); +} + +// ======================= +// ETMetalStream Implementation +// ======================= + +ETMetalStream::ETMetalStream() + : device_(nil), commandQueue_(nil), commandBuffer_(nil), prevCommandBuffer_(nil), + commandEncoder_(nil), serialQueue_(nullptr), enableCommitAndContinue_(true) { + @autoreleasepool { + // Create device and command queue + device_ = MTLCreateSystemDefaultDevice(); + if (!device_) { + ET_LOG(Error, "ETMetalStream: Failed to create Metal device"); + return; + } + [device_ retain]; + + commandQueue_ = [device_ newCommandQueue]; + if (!commandQueue_) { + ET_LOG(Error, "ETMetalStream: Failed to create Metal command queue"); + return; + } + [commandQueue_ retain]; + + // Create serial queue for thread safety + serialQueue_ = dispatch_queue_create("metal gpu stream", nullptr); + + ET_LOG(Debug, "ETMetalStream: Created stream with device %p, queue %p", device_, commandQueue_); + } +} + +ETMetalStream::~ETMetalStream() { + @autoreleasepool { + // Synchronize before cleanup + synchronize(SyncType::COMMIT_AND_WAIT); + + // Clean up command encoder + if (commandEncoder_) { + [commandEncoder_ release]; + commandEncoder_ = nil; + } + + // Clean up command buffers + if (commandBuffer_) { + [commandBuffer_ release]; + commandBuffer_ = nil; + } + if (prevCommandBuffer_) { + [prevCommandBuffer_ release]; + prevCommandBuffer_ = nil; + } + + // Clean up command queue and device + if (commandQueue_) { + [commandQueue_ release]; + commandQueue_ = nil; + } + if (device_) { + [device_ release]; + device_ = nil; + } + + // Clean up serial queue + if (serialQueue_) { + dispatch_release(serialQueue_); + serialQueue_ = nullptr; + } + + ET_LOG(Debug, "ETMetalStream: Destroyed stream"); + } +} + +ETMetalStream* ETMetalStream::getDefaultStream() { + if (!defaultStream_) { + defaultStream_ = new ETMetalStream(); + } + return defaultStream_; +} + +// Lazy command buffer creation (use MPSCommandBuffer like PyTorch) +MPSCommandBuffer* ETMetalStream::commandBuffer() { + if (!commandBuffer_) { + if (!commandQueue_) { + ET_LOG(Error, "ETMetalStream::commandBuffer: No command queue available"); + return nil; + } + + commandBuffer_ = [MPSCommandBuffer commandBufferFromCommandQueue:commandQueue_]; + if (!commandBuffer_) { + ET_LOG(Error, "ETMetalStream::commandBuffer: Failed to create command buffer"); + return nil; + } + [commandBuffer_ retain]; + + ET_LOG(Debug, "ETMetalStream::commandBuffer: Created lazy command buffer %p", commandBuffer_); + } + + return commandBuffer_; +} + +// Lazy command encoder creation +id ETMetalStream::commandEncoder() { + if (!commandEncoder_) { + MPSCommandBuffer* cmdBuffer = commandBuffer(); + if (!cmdBuffer) { + ET_LOG(Error, "ETMetalStream::commandEncoder: Failed to get command buffer"); + return nil; + } + + commandEncoder_ = [cmdBuffer computeCommandEncoder]; + if (!commandEncoder_) { + ET_LOG(Error, "ETMetalStream::commandEncoder: Failed to create command encoder"); + return nil; + } + [commandEncoder_ retain]; + + ET_LOG(Debug, "ETMetalStream::commandEncoder: Created lazy command encoder %p", commandEncoder_); + } + + return commandEncoder_; +} + +// Synchronization with SyncType - matches PyTorch's approach (no dispatch_sync here) +void ETMetalStream::synchronize(SyncType syncType) { + endKernelCoalescing(); + + switch 
(syncType) { + case SyncType::NONE: + // Do nothing - no commit + break; + case SyncType::COMMIT: + commit(); + break; + case SyncType::COMMIT_AND_WAIT: + commitAndWait(); + break; + case SyncType::COMMIT_AND_CONTINUE: + if (enableCommitAndContinue_) { + commitAndContinue(); + } else { + ET_LOG(Error, "ETMetalStream::synchronize: CommitAndContinue requested but disabled"); + commit(); + } + break; + case SyncType::COMMIT_ADAPTIVE: + // Simple adaptive policy - could be enhanced with memory pressure detection + // TODO: Could add memory pressure detection like PyTorch does + commit(); + break; + } + + ET_LOG(Debug, "ETMetalStream::synchronize: Completed with SyncType %d", static_cast(syncType)); +} + +// Encoder coalescing management +void ETMetalStream::endKernelCoalescing() { + if (commandEncoder_) { + [commandEncoder_ endEncoding]; + [commandEncoder_ release]; + commandEncoder_ = nil; + ET_LOG(Debug, "ETMetalStream::endKernelCoalescing: Ended encoder coalescing"); + } +} + +// Commit methods +void ETMetalStream::commit() { + if (enableCommitAndContinue_ && commandBuffer_) { + // Use commit-and-continue for better performance + commitAndContinue(); + } else { + flush(); + } +} + +void ETMetalStream::commitAndWait() { + // Handle previous command buffer first + if (prevCommandBuffer_) { + [prevCommandBuffer_ waitUntilCompleted]; + [prevCommandBuffer_ release]; + prevCommandBuffer_ = nil; + } + + // Handle current command buffer + if (commandBuffer_) { + [commandBuffer_ commit]; + [commandBuffer_ waitUntilCompleted]; + [commandBuffer_ release]; + commandBuffer_ = nil; + } + + ET_LOG(Debug, "ETMetalStream::commitAndWait: Committed and waited for completion"); +} + +void ETMetalStream::commitAndContinue() { + if (!commandBuffer_) { + ET_LOG(Error, "ETMetalStream::commitAndContinue: No command buffer to commit"); + return; + } + + // Commit buffer and allow immediate reuse for better performance + [commandBuffer_ commit]; + ET_LOG(Debug, "ETMetalStream::commitAndContinue: Committed buffer %p with continue", commandBuffer_); + + // The buffer handles synchronization internally for commit-and-continue +} + +void ETMetalStream::flush() { + if (commandBuffer_) { + [commandBuffer_ commit]; + + if (!enableCommitAndContinue_) { + // Keep the command buffer for later waiting if commit-and-continue is disabled + prevCommandBuffer_ = commandBuffer_; + } else { + [commandBuffer_ release]; + } + commandBuffer_ = nil; + + ET_LOG(Debug, "ETMetalStream::flush: Flushed command buffer"); + } +} + +// Memory operations +void ETMetalStream::fill(id buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) { + if (length == 0) { + return; + } + + dispatch_sync(serialQueue_, ^{ + @autoreleasepool { + endKernelCoalescing(); + id blitEncoder = [commandBuffer() blitCommandEncoder]; + + [blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value]; + [blitEncoder endEncoding]; + synchronize(syncType); + + ET_LOG(Debug, "ETMetalStream::fill: Filled buffer with value %u, length %zu, offset %zu", value, length, offset); + } + }); +} + +void ETMetalStream::copy(id srcBuffer, id dstBuffer, size_t length, + size_t srcOffset, size_t dstOffset, SyncType syncType) { + dispatch_sync(serialQueue_, ^{ + @autoreleasepool { + endKernelCoalescing(); + id blitEncoder = [commandBuffer() blitCommandEncoder]; + + // Handle large copies in chunks + constexpr size_t max_copy_size = 0x80000000; // 2GB + size_t bytes_copied = 0; + size_t bytes_remaining = length; + + while (bytes_remaining > 0) { + NSUInteger 
bytes_to_copy = std::min(max_copy_size, bytes_remaining); + [blitEncoder copyFromBuffer:srcBuffer + sourceOffset:(NSUInteger)srcOffset + bytes_copied + toBuffer:dstBuffer + destinationOffset:(NSUInteger)dstOffset + bytes_copied + size:bytes_to_copy]; + bytes_copied += bytes_to_copy; + bytes_remaining -= bytes_to_copy; + } + + [blitEncoder endEncoding]; + synchronize(syncType); + + ET_LOG(Debug, "ETMetalStream::copy: Copied %zu bytes from offset %zu to offset %zu", length, srcOffset, dstOffset); + } + }); +} + + +void ETMetalStream::synchronize() { + synchronize(SyncType::COMMIT_AND_WAIT); +} + +bool ETMetalStream::isEmpty() const { + return !commandBuffer_ && !commandEncoder_; +} + +void ETMetalStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) { + // Use dispatch_sync_with_rethrow exactly like PyTorch does for MPSGraph execution + dispatch_sync_with_rethrow(serialQueue_, ^() { + @autoreleasepool { + endKernelCoalescing(); + + [mpsGraph encodeToCommandBuffer:commandBuffer() + feeds:feeds + targetOperations:nil + resultsDictionary:results + executionDescriptor:nil]; + + //synchronize(syncType); + } + }); +} + +// ======================= +// Global Storage Management Functions +// ======================= + +void storeFunctionHandle(ETMetalKernelFunction* raw_function, std::shared_ptr function_shared_ptr) { + function_storage[raw_function] = function_shared_ptr; +} + +void storeLibraryHandle(ETMetalShaderLibrary* raw_library, std::unique_ptr library) { + library_storage[raw_library] = std::move(library); +} + +bool removeFunctionHandle(ETMetalKernelFunction* raw_function) { + auto it = function_storage.find(raw_function); + if (it != function_storage.end()) { + function_storage.erase(it); + return true; + } + return false; +} + +bool removeLibraryHandle(ETMetalShaderLibrary* raw_library) { + auto it = library_storage.find(raw_library); + if (it != library_storage.end()) { + library_storage.erase(it); + return true; + } + return false; +} + +// ======================= +// Global Stream Access Functions +// ======================= + +ETMetalStream* getCurrentMetalStream() { + if (!currentStream_) { + currentStream_ = ETMetalStream::getDefaultStream(); + } + return currentStream_; +} + +void setCurrentMetalStream(ETMetalStream* stream) { + currentStream_ = stream; +} + +// ======================= +// Metal Stream Synchronization Functions +// ======================= + +void synchronize_metal_stream() { + @autoreleasepool { + // Use the ETMetalStream for proper synchronization + ETMetalStream* stream = getCurrentMetalStream(); + stream->synchronize(SyncType::COMMIT_AND_WAIT); + + ET_LOG(Debug, "synchronize_metal_stream: Stream synchronized with COMMIT_AND_WAIT"); + } +} + +void synchronize_metal_stream_with_type(int sync_type) { + @autoreleasepool { + ETMetalStream* stream = getCurrentMetalStream(); + SyncType syncTypeEnum = static_cast(sync_type); + stream->synchronize(syncTypeEnum); + + ET_LOG(Debug, "synchronize_metal_stream_with_type: Stream synchronized with SyncType %d", sync_type); + } +} + +} // namespace metal +} // namespace backends +} // namespace executorch From dfa435ada3c8b2eacf0245eb55e3d946e894d75f Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:39 -0400 Subject: [PATCH 06/16] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/memory.cpp | 453 ++++++++++++++++++ backends/apple/metal/runtime/shims/memory.h | 73 +++ 2 files changed, 526 insertions(+) create mode 100644 
backends/apple/metal/runtime/shims/memory.cpp create mode 100644 backends/apple/metal/runtime/shims/memory.h diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp new file mode 100644 index 00000000000..2bda93e18a4 --- /dev/null +++ b/backends/apple/metal/runtime/shims/memory.cpp @@ -0,0 +1,453 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include // Ensure we have int64_t, int32_t definitions +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Import all from aoti namespace +using namespace executorch::backends::aoti; + +// Global storage for tensors and their metadata +std::unordered_set> tensors; +std::unordered_map is_tensor_own_memory; + +extern "C" { + +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { + ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: entered"); + + (void)device_type; + (void)opaque_metadata; + (void)layout; + (void)opaque_metadata_size; + + // Validate input parameters first + ET_CHECK_OR_RETURN_ERROR( + data != nullptr, + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null"); + + ET_CHECK_OR_RETURN_ERROR( + !(sizes_ptr == nullptr && ndim > 0), + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret_new_tensor != nullptr, + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null"); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + // Handle storage offset by adjusting the data pointer + void* adjusted_data = static_cast(data) + + (storage_offset * dtype_to_element_size(dtype)); + + ET_LOG( + Debug, + "aoti_torch_create_tensor_from_blob_v2: original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p", + data, + storage_offset, + dtype_to_element_size(dtype), + adjusted_data); + + // ETensor sizes + auto sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // ETensor strides + auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Log if the tensor is contiguous + if (is_contiguous_tensor(sizes, strides)) { + ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: contiguous tensor"); + } else { + ET_LOG( + Debug, "aoti_torch_create_tensor_from_blob_v2: non-contiguous tensor"); + } + + // ETensor creation + // Note: We're NOT copying the data, just wrapping it + auto tensor = executorch::extension::from_blob( + adjusted_data, sizes, strides, dtype_to_scalar_type(dtype)); + + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, InvalidArgument, "Failed to create tensor from blob"); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + *ret_new_tensor = tensor.get(); + is_tensor_own_memory[tensor.get()] = false; + + ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: successfull"); + return Error::Ok; +} + +AOTITorchError aoti_torch_empty_strided( + int64_t 
ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor) { + ET_LOG(Debug, "aoti_torch_empty_strided: entered"); + + // This requires us to reserve device memory and put it into an ETensor + void* ptr; + int64_t numel = 1; + for (int i = 0; i < ndim; i++) { + numel *= sizes_ptr[i]; + } + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + size_t element_size = dtype_to_element_size(dtype); + ET_CHECK_OR_RETURN_ERROR( + element_size != 0, + InvalidArgument, + "Invalid element size for dtype: %d", + dtype); + int64_t nbytes = numel * element_size; + + if (device_type == 2) { // Metal/MPS + ptr = metal_allocate_buffer(nbytes); + if (!ptr) { + ET_LOG(Error, "Failed to allocate %lld bytes on Metal device", nbytes); + return Error::MemoryAllocationFailed; + } + } else if (device_type == 0) { // cpu + // Ensure 16-byte alignment for CPU memory to match device requirements + int result = posix_memalign(&ptr, 16, nbytes); + ET_CHECK_OR_RETURN_ERROR( + result == 0, + MemoryAllocationFailed, + "Failed to allocate aligned CPU memory"); + ET_CHECK_OR_RETURN_ERROR( + ptr != nullptr, + MemoryAllocationFailed, + "Failed to call posix_memalign"); + ET_LOG(Debug, "Allocated %lld bytes on CPU", nbytes); + } else { + ET_CHECK_OR_RETURN_ERROR( + false, + NotImplemented, + "Need to implement empty_strided for non-Metal, non-CPU device type %d", + device_type); + } + + // ETensor sizes + auto sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // ETensor strides + auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Log if the tensor is contiguous + if (is_contiguous_tensor(sizes, strides)) { + ET_LOG(Debug, "aoti_torch_empty_strided: contiguous tensor"); + } else { + ET_LOG(Debug, "aoti_torch_empty_strided: non-contiguous tensor"); + } + + // ETensor creation + // Note: the tensor wraps the buffer we just allocated; no copy is made + executorch::aten::ScalarType scalar_type = dtype_to_scalar_type(dtype); + auto tensor = + executorch::extension::from_blob(ptr, sizes, strides, scalar_type); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + *ret_new_tensor = tensor.get(); + is_tensor_own_memory[tensor.get()] = true; + + ET_LOG(Debug, "aoti_torch_empty_strided: successful"); + return Error::Ok; +} + +AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) { + ET_LOG(Debug, "aoti_torch_delete_tensor_object: entered"); + // Find tensor in the set + for (auto it = tensors.begin(); it != tensors.end(); ++it) { + if (it->get() == tensor) { + auto tensor_ptr = *it; + + // Check ownership before cleaning up + auto ownership_it = is_tensor_own_memory.find(tensor); + bool owns_memory = (ownership_it != is_tensor_own_memory.end()) + ?
ownership_it->second + : false; + + // Clean up ownership metadata + is_tensor_own_memory.erase(tensor); + + if (owns_memory) { + // et tensor owns the memory; need to free it manually + void* data_ptr = tensor_ptr->mutable_data_ptr(); + + // Check if it's Metal GPU memory + if (metal_is_device_pointer(data_ptr)) { + // This is Metal GPU memory - the Metal helper will handle cleanup + // Metal buffers are automatically managed by ARC when the buffer is + // released + tensors.erase(it); + ET_LOG( + Debug, + "aoti_torch_delete_tensor_object: successfull (Metal GPU memory)"); + return Error::Ok; + } + + // This is CPU memory - free immediately + free(data_ptr); + } + // else: Don't free memory since the tensor doesn't own it + + // Remove from set (this will call the destructor if it's the last + // reference) + tensors.erase(it); + ET_LOG( + Debug, "aoti_torch_delete_tensor_object: successfull (CPU memory)"); + return Error::Ok; + } + } + ET_LOG(Error, "Didn't find tensor %p", tensor); + return Error::InvalidArgument; +} + +AOTITorchError aoti_torch_copy_( + AOTITensorHandle self, + AOTITensorHandle src, + int32_t non_blocking) { + ET_LOG(Debug, "aoti_torch_copy_: entered"); + + (void)non_blocking; + + // Check for null pointers first + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch_copy_ failed: self tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + src != nullptr, + InvalidArgument, + "aoti_torch_copy_ failed: src tensor is null"); + + // Get dtype information and validate compatibility + int32_t self_dtype, src_dtype; + aoti_torch_get_dtype(self, &self_dtype); + aoti_torch_get_dtype(src, &src_dtype); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(self_dtype)); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(src_dtype)); + + // Check dtype compatibility - both tensors must have the same dtype + ET_CHECK_OR_RETURN_ERROR( + self_dtype == src_dtype, + InvalidArgument, + "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes", + self_dtype, + src_dtype); + + // Check total number of elements compatibility (PyTorch copy_ behavior) + int64_t self_numel = self->numel(); + int64_t src_numel = src->numel(); + + ET_CHECK_OR_RETURN_ERROR( + self_numel == src_numel, + InvalidArgument, + "numel mismatch. 
self.numel()=%ld, src.numel()=%ld", + self_numel, + src_numel); + + // Get tensor metadata + int64_t* self_strides; + int64_t* src_strides; + aoti_torch_get_strides(self, &self_strides); + aoti_torch_get_strides(src, &src_strides); + + int64_t* self_sizes; + int64_t* src_sizes; + aoti_torch_get_sizes(self, &self_sizes); + aoti_torch_get_sizes(src, &src_sizes); + + // Determine device locations + bool srcIsDevice = false; + bool dstIsDevice = false; + + // Check if pointers are Metal device pointers + if (!srcIsDevice) { + srcIsDevice = metal_is_device_pointer(const_cast(src->data_ptr())); + } + if (!dstIsDevice) { + dstIsDevice = metal_is_device_pointer(self->mutable_data_ptr()); + } + + // Check if tensors have the same schema (sizes, strides, dtype) for fast path + // TODO: This should be improved to catch cases like (4, 1, 5) -> (4, 5) + bool same_schema = true; + for (int i = 0; i < self->dim(); i++) { + if (self_strides[i] != src_strides[i]) { + same_schema = false; + break; + } + } + + size_t total_bytes = src->nbytes(); + int64_t total_elements = self->numel(); + + if (same_schema) { + int result = metal_copy_memory( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + srcIsDevice, + dstIsDevice); + if (result != 0) { + ET_LOG(Error, "metal_copy_memory failed with status %d", result); + return Error::Internal; + } + } else { + ET_LOG(Error, "Layout conversion not supported"); + return Error::NotImplemented; + } + + ET_LOG(Debug, "aoti_torch_copy_: successfull"); + return Error::Ok; +} + +AOTITorchError aoti_torch__reinterpret_tensor( + AOTITensorHandle self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + AOTITensorHandle* ret_new_tensor) { + ET_LOG(Debug, "aoti_torch__reinterpret_tensor: entered"); + + // Validate input parameters first + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: self tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + !(sizes_ptr == nullptr && ndim > 0), + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: sizes_ptr is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret_new_tensor != nullptr, + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null"); + + // Get the dtype from the source tensor + int32_t dtype = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(self, &dtype)); + + // Validate dtype using SupportedDTypes + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + int32_t device_type = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type)); + + int32_t device_index = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index)); + + // Get the base data pointer from the source tensor + void* base_data_ptr = self->mutable_data_ptr(); + ET_CHECK_OR_RETURN_ERROR( + base_data_ptr != nullptr, + InvalidArgument, + "Source tensor has null data pointer"); + + // Calculate new tensor size in elements for logging + int64_t new_numel = 1; + for (int64_t i = 0; i < ndim; i++) { + new_numel *= sizes_ptr[i]; + } + + ET_LOG( + Debug, + "aoti_torch__reinterpret_tensor: base_data_ptr=%p, new_numel=%lld, storage_offset=%lld", + base_data_ptr, + new_numel, + storage_offset); + + // Create a new tensor view that shares the same underlying storage + // This is the correct way to implement reinterpret_tensor - as a view, not a + // copy + AOTITorchError create_err = aoti_torch_create_tensor_from_blob_v2( + base_data_ptr, // Same underlying data pointer + ndim, // New 
dimensions + sizes_ptr, // New sizes + strides_ptr, // New strides + storage_offset, // Storage offset (will be handled properly now) + dtype, + device_type, + device_index, + ret_new_tensor, + 0, // layout (default) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_err != Error::Ok) { + ET_LOG(Error, "failed to create reinterpreted tensor view"); + return create_err; + } + + ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull"); + return Error::Ok; +} + +// Cleanup function for clearing global state +void cleanup_memory() { + is_tensor_own_memory.clear(); + if (!tensors.empty()) { + ET_LOG(Error, "Warning: tensors not empty during cleanup"); + } + + // Clean up Metal resources + metal_cleanup_resources(); +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/memory.h b/backends/apple/metal/runtime/shims/memory.h new file mode 100644 index 00000000000..47fb6352b50 --- /dev/null +++ b/backends/apple/metal/runtime/shims/memory.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Global storage declarations +extern std::unordered_map is_tensor_own_memory; +extern std::unordered_set> tensors; + +// Memory-related operations +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + +AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor); + +AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor); + +AOTITorchError aoti_torch_copy_( + AOTITensorHandle self, + AOTITensorHandle src, + int32_t non_blocking); + +AOTITorchError aoti_torch__reinterpret_tensor( + AOTITensorHandle self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + AOTITensorHandle* ret_new_tensor); + +void cleanup_memory(); + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch From 648ee077cce4f779c3a7353e43642b680c962f94 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:44 -0400 Subject: [PATCH 07/16] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/shim_mps.h | 118 ++++ .../apple/metal/runtime/shims/shim_mps.mm | 540 ++++++++++++++++++ 2 files changed, 658 insertions(+) create mode 100644 backends/apple/metal/runtime/shims/shim_mps.h create mode 100644 backends/apple/metal/runtime/shims/shim_mps.mm diff --git a/backends/apple/metal/runtime/shims/shim_mps.h b/backends/apple/metal/runtime/shims/shim_mps.h new file mode 100644 index 00000000000..94611b016ae --- /dev/null +++ b/backends/apple/metal/runtime/shims/shim_mps.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace metal { + +struct AOTIMetalKernelFunctionOpaque; +using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*; + +struct AOTIMetalShaderLibraryOpaque; +using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*; + +#ifdef __cplusplus +extern "C" { +#endif + +// MetalShaderLibrary functions +AOTITorchError aoti_torch_mps_create_shader_library( + const char* metal_shader_source, + AOTIMetalShaderLibraryHandle* library_handle); + +AOTITorchError aoti_torch_mps_delete_shader_library( + AOTIMetalShaderLibraryHandle library_handle); + +AOTITorchError aoti_torch_mps_get_kernel_function( + AOTIMetalShaderLibraryHandle library_handle, + const char* kernel_name, + AOTIMetalKernelFunctionHandle* function_handle); + +// MetalKernelFunction functions +AOTITorchError aoti_torch_mps_start_encoding( + AOTIMetalKernelFunctionHandle func); + +AOTITorchError aoti_torch_mps_set_arg_tensor( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + AOTITensorHandle tensor); + +AOTITorchError aoti_torch_mps_set_arg_int( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + int64_t val); + +// Pure C dispatch functions - single value versions +AOTITorchError aoti_torch_mps_dispatch_single( + AOTIMetalKernelFunctionHandle func, + uint64_t length); + +AOTITorchError aoti_torch_mps_dispatch_single_with_group_size( + AOTIMetalKernelFunctionHandle func, + uint64_t length, + uint64_t group_size); + +// Pure C dispatch functions - array versions +AOTITorchError aoti_torch_mps_dispatch_array( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size); + +AOTITorchError aoti_torch_mps_dispatch_array_with_group_size( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size, + const uint64_t* group_size, + size_t group_size_size); + +// Memory management functions +AOTITorchError aoti_torch_mps_malloc(void** buffer, size_t num_bytes); + +AOTITorchError aoti_torch_mps_free(void* ptr); + +AOTITorchError aoti_torch_mps_memcpy( + void* buffer, + size_t constant_offset, + size_t bytes_read, + size_t data_size, + uint8_t* constants_start); + +AOTITorchError aoti_torch_mps_copy_buffer( + void* src_buffer, + void* dst_buffer, + size_t data_size, + size_t src_offset, + size_t dst_offset); + +// C callback function type for command block execution +typedef void (*aoti_torch_mps_command_block_callback_t)( + AOTIMetalKernelFunctionHandle func, + void* user_data); + +// Shared callback function for std::function trampoline +void aoti_torch_mps_shared_callback( + AOTIMetalKernelFunctionHandle func, + void* user_data); + +// Pure C version using function pointer and user data for trampoline pattern +AOTITorchError aoti_torch_mps_run_command_block( + AOTIMetalKernelFunctionHandle func, + aoti_torch_mps_command_block_callback_t callback, + void* user_data); + +#ifdef __cplusplus +} // extern "C" +#endif + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/shim_mps.mm b/backends/apple/metal/runtime/shims/shim_mps.mm new file mode 100644 index 00000000000..e5e7d8c0dc9 --- /dev/null +++ b/backends/apple/metal/runtime/shims/shim_mps.mm @@ -0,0 +1,540 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#import +#import +#import +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Declare the global mapping from et_metal.mm +extern std::unordered_map> ptr_to_mtl_buffer; + +extern "C" { + +// MetalShaderLibrary functions +AOTITorchError aoti_torch_mps_create_shader_library( + const char* metal_shader_source, + AOTIMetalShaderLibraryHandle* library_handle) { + + if (!metal_shader_source || !library_handle) { + ET_LOG(Error, "aoti_torch_mps_create_shader_library: null arguments"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto library = std::make_unique(std::string(metal_shader_source)); + auto* raw_library = library.get(); + + // Store the unique_ptr to keep the object alive + storeLibraryHandle(raw_library, std::move(library)); + + // Return raw pointer to match existing API + *library_handle = reinterpret_cast(raw_library); + + ET_LOG(Debug, "aoti_torch_mps_create_shader_library: Created shader library %p", raw_library); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_create_shader_library exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_create_shader_library: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_delete_shader_library( + AOTIMetalShaderLibraryHandle library_handle) { + + if (!library_handle) { + ET_LOG(Error, "aoti_torch_mps_delete_shader_library: null library handle"); + return Error::InvalidArgument; + } + + try { + auto* library = reinterpret_cast(library_handle); + if (removeLibraryHandle(library)) { + ET_LOG(Debug, "aoti_torch_mps_delete_shader_library: Deleted shader library %p", library); + } else { + ET_LOG(Error, "aoti_torch_mps_delete_shader_library: Library not found in storage"); + return Error::InvalidArgument; + } + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_delete_shader_library exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_delete_shader_library: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_get_kernel_function( + AOTIMetalShaderLibraryHandle library_handle, + const char* kernel_name, + AOTIMetalKernelFunctionHandle* function_handle) { + + if (!library_handle || !kernel_name || !function_handle) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function: null arguments"); + return Error::InvalidArgument; + } + + try { + auto* library = reinterpret_cast(library_handle); + auto function_shared_ptr = library->getKernelFunction(std::string(kernel_name)); + if (!function_shared_ptr) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function: Failed to get kernel function '%s'", kernel_name); + return Error::Internal; + } + + auto* raw_function = function_shared_ptr.get(); + + // Store the shared_ptr to keep the object alive + storeFunctionHandle(raw_function, function_shared_ptr); + + // Return raw pointer to match existing API + *function_handle = reinterpret_cast(raw_function); + + ET_LOG(Debug, "aoti_torch_mps_get_kernel_function: Got kernel function '%s' -> %p", kernel_name, raw_function); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_start_encoding( + AOTIMetalKernelFunctionHandle func) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_start_encoding: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->startEncoding(); + + ET_LOG(Debug, "aoti_torch_mps_start_encoding: Started encoding for function %p", function); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_start_encoding exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_start_encoding: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_set_arg_tensor( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + AOTITensorHandle tensor) { + + if (!func || !tensor) { + ET_LOG(Error, "aoti_torch_mps_set_arg_tensor: null function handle or tensor"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto* function = reinterpret_cast(func); + auto* et_tensor = reinterpret_cast(tensor); + + function->setArg(idx, *et_tensor); + + ET_LOG(Debug, "aoti_torch_mps_set_arg_tensor: Set tensor argument at index %u", idx); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_set_arg_tensor exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_set_arg_tensor: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_set_arg_int( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + int64_t val) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_set_arg_int: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->setArg(idx, val); + + ET_LOG(Debug, "aoti_torch_mps_set_arg_int: Set int64_t value %lld at index %u", val, idx); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_set_arg_int exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_set_arg_int: unknown exception"); + return Error::Internal; + } +} + +// Pure C dispatch functions - single value versions +AOTITorchError aoti_torch_mps_dispatch_single( + AOTIMetalKernelFunctionHandle func, + uint64_t length) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchSingle(length); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_single: Dispatched function %p with length %llu", function, length); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_dispatch_single_with_group_size( + AOTIMetalKernelFunctionHandle func, + uint64_t length, + uint64_t group_size) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single_with_group_size: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchSingleWithGroupSize(length, group_size); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_single_with_group_size: Dispatched function %p with length %llu, group size %llu", function, length, group_size); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single_with_group_size exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single_with_group_size: unknown exception"); + return Error::Internal; + } +} + +// Pure C dispatch functions - array versions +AOTITorchError aoti_torch_mps_dispatch_array( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchArray(length, length_size); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_array: Dispatched function %p with %zu dimensions", function, length_size); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_dispatch_array_with_group_size( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size, + const uint64_t* group_size, + size_t group_size_size) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchArrayWithGroupSize(length, length_size, group_size, group_size_size); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_array_with_group_size: Dispatched function %p with %zu dimensions", function, length_size); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_malloc(void** buffer, size_t num_bytes) { + if (num_bytes == 0) { + *buffer = nullptr; + return Error::Ok; + } + + if (!buffer) { + ET_LOG(Error, "aoti_torch_mps_malloc: null buffer pointer"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_malloc: Failed to get Metal device"); + return Error::Internal; + } + + id metal_buffer = [device newBufferWithLength:num_bytes + options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared]; + if (!metal_buffer) { + ET_LOG(Error, "aoti_torch_mps_malloc: Failed to allocate Metal buffer of size %zu", num_bytes); + return Error::Internal; + } + + // FIX: Return contents pointer, not buffer object + void* contents_ptr = [metal_buffer contents]; + ptr_to_mtl_buffer[contents_ptr] = metal_buffer; // Map contents to buffer + *buffer = contents_ptr; // Return contents pointer + + ET_LOG(Debug, "aoti_torch_mps_malloc: Allocated Metal buffer %p with contents %p of size %zu", + metal_buffer, contents_ptr, num_bytes); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_malloc exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_malloc: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_free(void* ptr) { + if (!ptr) { + return Error::Ok; // Nothing to free + } + + @autoreleasepool { + try { + // FIX: ptr is now the contents pointer, not the buffer object + // Look up the buffer from the mapping and clean up + auto it = ptr_to_mtl_buffer.find(ptr); + if (it != ptr_to_mtl_buffer.end()) { + id metal_buffer = it->second; + [metal_buffer release]; + ptr_to_mtl_buffer.erase(it); + ET_LOG(Debug, "aoti_torch_mps_free: Freed Metal buffer for contents %p", ptr); + } else { + ET_LOG(Error, "aoti_torch_mps_free: Buffer not found for contents pointer %p", ptr); + return Error::InvalidArgument; + } + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_free exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_free: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_memcpy( + void* buffer, + size_t constant_offset, + size_t bytes_read, + size_t data_size, + uint8_t* constants_start) { + + if (!buffer || !constants_start) { + ET_LOG(Error, "aoti_torch_mps_memcpy: null buffer or constants_start"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // FIX: buffer is now the contents pointer, not the buffer object + auto buffer_pointer = static_cast(buffer); + + memcpy(buffer_pointer + constant_offset, constants_start + bytes_read, data_size); + + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_memcpy: Failed to get Metal device"); + return Error::Internal; + } + id subBuffer = [device newBufferWithBytesNoCopy:buffer_pointer + constant_offset + length:data_size + options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared + deallocator:nil]; + + if (constant_offset != 0) { + ptr_to_mtl_buffer[buffer_pointer + constant_offset] = subBuffer; // Map contents to buffer + } + + ET_LOG(Debug, "aoti_torch_mps_memcpy: Copied %zu bytes from offset %zu to buffer offset %zu", + data_size, bytes_read, constant_offset); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_memcpy exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_memcpy: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_copy_buffer( + void* src_buffer, + void* dst_buffer, + size_t data_size, + size_t src_offset, + size_t dst_offset) { + + if (!src_buffer || !dst_buffer) { + ET_LOG(Error, "aoti_torch_mps_copy_buffer: null buffer"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto src_mtl_buffer = (id)src_buffer; + auto dst_mtl_buffer = (id)dst_buffer; + + uint8_t* src_contents = static_cast([src_mtl_buffer contents]); + uint8_t* dst_contents = static_cast([dst_mtl_buffer contents]); + + if (!src_contents || !dst_contents) { + ET_LOG(Error, "aoti_torch_mps_copy_buffer: Failed to get buffer contents"); + return Error::Internal; + } + + memcpy(dst_contents + dst_offset, src_contents + src_offset, data_size); + + ET_LOG(Debug, "aoti_torch_mps_copy_buffer: Copied %zu bytes from src+%zu to dst+%zu", + data_size, src_offset, dst_offset); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_copy_buffer exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_copy_buffer: unknown exception"); + return Error::Internal; + } + } +} + +// Shared callback function for std::function trampoline +void aoti_torch_mps_shared_callback( + AOTIMetalKernelFunctionHandle func, + void* user_data) { + ET_LOG(Debug, "aoti_torch_mps_shared_callback: Called with func=%p, user_data=%p", func, user_data); + + auto* function_wrapper = static_cast*>(user_data); + if (function_wrapper) { + ET_LOG(Debug, "aoti_torch_mps_shared_callback: Calling function wrapper"); + (*function_wrapper)(func); + ET_LOG(Debug, "aoti_torch_mps_shared_callback: Function wrapper completed"); + } else { + ET_LOG(Error, "aoti_torch_mps_shared_callback: null function wrapper"); + } +} + +// Pure C version using function pointer and user data for trampoline pattern +AOTITorchError aoti_torch_mps_run_command_block( + AOTIMetalKernelFunctionHandle func, + aoti_torch_mps_command_block_callback_t callback, + void* user_data) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_run_command_block: null function handle"); + return Error::InvalidArgument; + } + + if (!callback) { + ET_LOG(Error, "aoti_torch_mps_run_command_block: null callback"); + return Error::InvalidArgument; + } + + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Starting command block for function %p, callback %p, user_data %p", + func, callback, user_data); + + try { + auto* function = reinterpret_cast(func); + function->runCommandBlock([callback, func, user_data]() { + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Inside lambda, calling callback"); + callback(func, user_data); + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Callback completed"); + }); + + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Executed command block for function %p", function); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_run_command_block exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_run_command_block: unknown exception"); + return Error::Internal; + } +} + +} // extern "C" + + +} // namespace metal +} // namespace backends +} // namespace executorch From 3bea537276bfea5ebbf1a254b2345cac34c959c3 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:48 -0400 Subject: [PATCH 08/16] Update [ghstack-poisoned] --- .../apple/metal/runtime/shims/et_metal_ops.h | 86 ++ .../apple/metal/runtime/shims/et_metal_ops.mm | 1255 +++++++++++++++++ 2 files changed, 1341 insertions(+) create mode 100644 backends/apple/metal/runtime/shims/et_metal_ops.h create mode 100644 backends/apple/metal/runtime/shims/et_metal_ops.mm diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.h b/backends/apple/metal/runtime/shims/et_metal_ops.h new file mode 100644 index 00000000000..a334e439333 --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal_ops.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace metal { + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * ExecutorTorch implementation of aoti_torch_mps_addmm_out. 
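+ * Shapes follow the addmm convention: mat1 is [M, K], mat2 is [K, N], and out is [M, N].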
+ * Performs matrix multiplication with bias: out = beta * self + alpha * (mat1 @ + * mat2) + */ +AOTITorchError aoti_torch_mps_addmm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat1, + AOTITensorHandle mat2, + double beta, + double alpha); + +/** + * ExecutorTorch implementation of aoti_torch_mps_mm_out. + * Performs simple matrix multiplication: out = self @ mat2 + */ +AOTITorchError aoti_torch_mps_mm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat2); + +/** + * ExecutorTorch implementation of aoti_torch_mps_convolution. + * Performs 2D convolution operation - matches PyTorch AOTI signature + */ +AOTITorchError aoti_torch_mps_convolution( + AOTITensorHandle input, + AOTITensorHandle weight, + AOTITensorHandle* bias, + const int64_t* stride, + int64_t stride_len_, + const int64_t* padding, + int64_t padding_len_, + const int64_t* dilation, + int64_t dilation_len_, + int32_t transposed, + const int64_t* output_padding, + int64_t output_padding_len_, + int64_t groups, + AOTITensorHandle* ret0); + +/** + * ExecutorTorch implementation of + * aoti_torch_mps__scaled_dot_product_attention_math_for_mps. Performs scaled + * dot product attention calculation - matches PyTorch AOTI signature + */ +AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( + AOTITensorHandle query, + AOTITensorHandle key, + AOTITensorHandle value, + AOTITensorHandle* attn_mask, + double dropout_p, + int32_t is_causal, + AOTITensorHandle* dropout_mask, + double* scale, + AOTITensorHandle* ret0, + AOTITensorHandle* ret1); + +#ifdef __cplusplus +} // extern "C" +#endif + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm new file mode 100644 index 00000000000..111263de972 --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -0,0 +1,1255 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#import +#import +#import +#import +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Forward declaration of dispatch_sync_with_rethrow from et_metal.mm +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()); + +// Declare the global mapping from et_metal.mm +extern std::unordered_map> ptr_to_mtl_buffer; + +extern "C" { + +AOTITorchError aoti_torch_mps_mm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat2) { + ET_LOG(Debug, "aoti_torch_mps_mm_out: Starting with out=%p, self=%p, mat2=%p", + out, self, mat2); + + if (!out || !self || !mat2) { + ET_LOG(Error, "aoti_torch_mps_mm_out: null tensor handles"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto out_tensor = reinterpret_cast(out); + auto self_tensor = reinterpret_cast(self); + auto mat2_tensor = reinterpret_cast(mat2); + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Converted tensor handles to ET tensors"); + + // Validate tensor dimensions + if (self_tensor->dim() != 2 || mat2_tensor->dim() != 2) { + std::string error_msg = "aoti_torch_mps_mm_out: tensors must be 2-D, got " + + std::to_string(self_tensor->dim()) + " and " + + std::to_string(mat2_tensor->dim()); + ET_LOG(Error, "%s", error_msg.c_str()); + throw std::runtime_error(error_msg); + } + + int64_t M = self_tensor->sizes()[0]; // rows of self + int64_t K = self_tensor->sizes()[1]; // cols of self / rows of mat2 + int64_t N = mat2_tensor->sizes()[1]; // cols of mat2 + + // Check matrix multiplication compatibility + if (self_tensor->sizes()[1] != mat2_tensor->sizes()[0]) { + std::string error_msg = "aoti_torch_mps_mm_out: incompatible matrix sizes for mm (" + + std::to_string(M) + "x" + std::to_string(K) + " and " + + std::to_string(mat2_tensor->sizes()[0]) + "x" + std::to_string(N) + ")"; + ET_LOG(Error, "%s", error_msg.c_str()); + throw std::runtime_error(error_msg); + } + + // Log tensor shapes for debugging + ET_LOG(Debug, "aoti_torch_mps_mm_out: self shape: [%d, %d], mat2 shape: [%d, %d], out shape: [%d, %d]", + (int)M, (int)K, (int)mat2_tensor->sizes()[0], (int)N, + out_tensor->dim() > 0 ? (int)out_tensor->sizes()[0] : 0, + out_tensor->dim() > 1 ? 
(int)out_tensor->sizes()[1] : 0); + + // Use the same dispatch pattern as other MPS operations for consistent synchronization + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_mm_out: Failed to get current Metal stream"); + return Error::Internal; + } + + // Get Metal device + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_mm_out: Failed to get Metal device"); + throw std::runtime_error("Failed to get Metal device"); + } + + // Get Metal buffers from tensors using the global mapping + void* self_data_ptr = self_tensor->mutable_data_ptr(); + void* mat2_data_ptr = mat2_tensor->mutable_data_ptr(); + void* out_data_ptr = out_tensor->mutable_data_ptr(); + + // Look up Metal buffers from the global mapping + auto self_it = ptr_to_mtl_buffer.find(self_data_ptr); + auto mat2_it = ptr_to_mtl_buffer.find(mat2_data_ptr); + auto out_it = ptr_to_mtl_buffer.find(out_data_ptr); + + if (self_it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "aoti_torch_mps_mm_out: self tensor not found in Metal buffer mapping"); + throw std::runtime_error("self tensor not found in Metal buffer mapping"); + } + if (mat2_it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "aoti_torch_mps_mm_out: mat2 tensor not found in Metal buffer mapping"); + throw std::runtime_error("mat2 tensor not found in Metal buffer mapping"); + } + if (out_it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "aoti_torch_mps_mm_out: out tensor not found in Metal buffer mapping"); + throw std::runtime_error("out tensor not found in Metal buffer mapping"); + } + + id self_buffer = self_it->second; + id mat2_buffer = mat2_it->second; + id out_buffer = out_it->second; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Using existing Metal buffers - self=%p, mat2=%p, out=%p", + self_buffer, mat2_buffer, out_buffer); + + // End any existing kernel coalescing to ensure a clean state for MPS + stream->endKernelCoalescing(); + + // Determine data type and element size + int32_t dtype = static_cast(self_tensor->scalar_type()); + MPSDataType mps_dtype; + size_t element_size; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: self_tensor scalar_type=%d, SupportedDTypes::FLOAT32=%d, SupportedDTypes::BFLOAT16=%d", + dtype, static_cast(SupportedDTypes::FLOAT32), static_cast(SupportedDTypes::BFLOAT16)); + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + element_size = sizeof(float); + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + element_size = sizeof(uint16_t); // bfloat16 is 16 bits + } else { + ET_LOG(Error, "aoti_torch_mps_mm_out: Unsupported data type: %d", dtype); + throw std::runtime_error("Unsupported data type for matrix multiplication"); + } + + ET_LOG(Debug, "aoti_torch_mps_mm_out: dtype=%d, element_size=%zu", dtype, element_size); + ET_LOG(Debug, "aoti_torch_mps_mm_out: M=%lld, K=%lld, N=%lld", M, K, N); + + // Create MPSGraph for matrix multiplication + MPSGraph* mpsGraph = [MPSGraph new]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created MPSGraph instance"); + + // Define tensor shapes for placeholders + NSArray* selfShape = @[@(M), @(K)]; + NSArray* mat2Shape = @[@(K), @(N)]; + NSArray* outShape = @[@(M), @(N)]; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Creating placeholders with shapes self:[%d,%d] mat2:[%d,%d]", + (int)M, (int)K, (int)K, (int)N); + + // Create placeholders for input tensors + MPSGraphTensor* selfPlaceholder = [mpsGraph placeholderWithShape:selfShape + dataType:mps_dtype + 
name:@"self"]; + MPSGraphTensor* mat2Placeholder = [mpsGraph placeholderWithShape:mat2Shape + dataType:mps_dtype + name:@"mat2"]; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created input placeholders"); + + // Perform matrix multiplication using MPSGraph + MPSGraphTensor* mmOutput = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder + secondaryTensor:mat2Placeholder + name:@"matrix_multiplication"]; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Successfully created matrix multiplication tensor"); + + // Create feeds dictionary for graph execution + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + + // Create MPSGraphTensorData objects for input tensors + MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer + shape:selfShape + dataType:mps_dtype]; + MPSGraphTensorData* mat2Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer + shape:mat2Shape + dataType:mps_dtype]; + + feeds[selfPlaceholder] = selfData; + feeds[mat2Placeholder] = mat2Data; + + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created feeds dictionary"); + + // Create results dictionary + MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer + shape:outShape + dataType:mps_dtype]; + + NSDictionary* results = @{mmOutput: outputData}; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created results dictionary"); + + // Execute the MPSGraph + ET_LOG(Debug, "aoti_torch_mps_mm_out: Executing MPSGraph"); + + @try { + // Use stream helper to encode and synchronize correctly + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE); + } @catch (NSException *exception) { + ET_LOG(Error, "aoti_torch_mps_mm_out: NSException caught during executeMPSGraph: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + throw std::runtime_error("MPSGraph execution failed with NSException"); + } + + ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully"); + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_mm_out exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_mm_out: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_addmm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat1, + AOTITensorHandle mat2, + double beta, + double alpha) { + ET_LOG(Debug, "aoti_torch_mps_addmm_out: Starting with out=%p, self=%p, mat1=%p, mat2=%p, beta=%f, alpha=%f", + out, self, mat1, mat2, beta, alpha); + + if (!out || !self || !mat1 || !mat2) { + ET_LOG(Error, "aoti_torch_mps_addmm_out: null tensor handles"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto out_tensor = reinterpret_cast(out); + auto self_tensor = reinterpret_cast(self); + auto mat1_tensor = reinterpret_cast(mat1); + auto mat2_tensor = reinterpret_cast(mat2); + + ET_LOG(Debug, "aoti_torch_mps_addmm_out: Converted tensor handles to ET tensors"); + + // For now, just zero out the output tensor to get the right shape + // TODO: Implement actual matrix multiplication: out = beta * self + alpha * (mat1 @ mat2) + + // Get output data pointer and size + float* out_data = static_cast(out_tensor->mutable_data_ptr()); + size_t out_numel = out_tensor->numel(); + + if (!out_data) { + ET_LOG(Error, "aoti_torch_mps_addmm_out: null output data pointer"); + return Error::InvalidArgument; + } + + // Zero out the output tensor + std::memset(out_data, 0, out_numel * sizeof(float)); + + ET_LOG(Debug, "aoti_torch_mps_addmm_out: Zeroed output tensor with %zu elements", out_numel); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_addmm_out exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_addmm_out: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_convolution( + AOTITensorHandle input, + AOTITensorHandle weight, + AOTITensorHandle* bias, + const int64_t* stride, + int64_t stride_len_, + const int64_t* padding, + int64_t padding_len_, + const int64_t* dilation, + int64_t dilation_len_, + int32_t transposed, + const int64_t* output_padding, + int64_t output_padding_len_, + int64_t groups, + AOTITensorHandle* ret0) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Starting with input=%p, weight=%p, bias=%p, groups=%lld, transposed=%d", + input, weight, bias, groups, transposed); + + if (!input || !weight || !ret0) { + ET_LOG(Error, "aoti_torch_mps_convolution: null required handles (input, weight, or ret0)"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto input_tensor = reinterpret_cast(input); + auto weight_tensor = reinterpret_cast(weight); + + // bias can be null for convolutions without bias + executorch::runtime::etensor::Tensor* bias_tensor = nullptr; + if (bias && *bias) { + bias_tensor = reinterpret_cast(*bias); + ET_LOG(Debug, "aoti_torch_mps_convolution: Has bias tensor"); + } else { + ET_LOG(Debug, "aoti_torch_mps_convolution: No bias tensor"); + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: Converted tensor handles to ET tensors"); + + // Log tensor shapes for debugging + ET_LOG(Debug, "aoti_torch_mps_convolution: input shape: [%d, %d, %d, %d]", + input_tensor->dim() > 0 ? (int)input_tensor->sizes()[0] : 0, + input_tensor->dim() > 1 ? (int)input_tensor->sizes()[1] : 0, + input_tensor->dim() > 2 ? (int)input_tensor->sizes()[2] : 0, + input_tensor->dim() > 3 ? 
(int)input_tensor->sizes()[3] : 0); + + ET_LOG(Debug, "aoti_torch_mps_convolution: weight shape: [%d, %d, %d, %d]", + weight_tensor->dim() > 0 ? (int)weight_tensor->sizes()[0] : 0, + weight_tensor->dim() > 1 ? (int)weight_tensor->sizes()[1] : 0, + weight_tensor->dim() > 2 ? (int)weight_tensor->sizes()[2] : 0, + weight_tensor->dim() > 3 ? (int)weight_tensor->sizes()[3] : 0); + + // Log convolution parameters + if (stride && stride_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: stride: [%lld, %lld]", stride[0], stride[1]); + } + if (padding && padding_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: padding: [%lld, %lld]", padding[0], padding[1]); + } + if (dilation && dilation_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: dilation: [%lld, %lld]", dilation[0], dilation[1]); + } + if (output_padding && output_padding_len_ >= 2) { + ET_LOG(Debug, "aoti_torch_mps_convolution: output_padding: [%lld, %lld]", output_padding[0], output_padding[1]); + } + + // Support conv1d and conv2d by inspecting weight rank. + // conv1d: weight dims = [C_out, C_in, K] + // conv2d: weight dims = [C_out, C_in, Kh, Kw] + bool is_conv1d = (weight_tensor->dim() == 3); + + // Accept input ranks: + // conv1d: 2D (C,W) or 3D (N,C,W) + // conv2d: 3D (C,H,W) or 4D (N,C,H,W) + bool has_batch_dim = false; + bool is_input_4d = false; + int64_t N = 1, C_in = 0, H_in = 1, W_in = 0; + if (is_conv1d) { + if (input_tensor->dim() == 2) { + // (C, W) + has_batch_dim = false; + C_in = input_tensor->sizes()[0]; + W_in = input_tensor->sizes()[1]; + H_in = 1; + } else if (input_tensor->dim() == 3) { + // (N, C, W) + has_batch_dim = true; + N = input_tensor->sizes()[0]; + C_in = input_tensor->sizes()[1]; + W_in = input_tensor->sizes()[2]; + H_in = 1; + } else { + ET_LOG(Error, "aoti_torch_mps_convolution: conv1d expects 2D or 3D input, got %d", (int)input_tensor->dim()); + return Error::InvalidArgument; + } + } else { + is_input_4d = (input_tensor->dim() == 4); + if (is_input_4d) { + // (N, C, H, W) + has_batch_dim = true; + N = input_tensor->sizes()[0]; + C_in = input_tensor->sizes()[1]; + H_in = input_tensor->sizes()[2]; + W_in = input_tensor->sizes()[3]; + } else if (input_tensor->dim() == 3) { + // (C, H, W) + has_batch_dim = false; + N = 1; + C_in = input_tensor->sizes()[0]; + H_in = input_tensor->sizes()[1]; + W_in = input_tensor->sizes()[2]; + } else { + ET_LOG(Error, "aoti_torch_mps_convolution: conv2d expects 3D or 4D input, got %d", (int)input_tensor->dim()); + return Error::InvalidArgument; + } + } + + // Get weight dimensions + int64_t C_out = weight_tensor->sizes()[0]; // output channels + int64_t kernel_h = is_conv1d ? 1 : weight_tensor->sizes()[2]; // kernel height + int64_t kernel_w = is_conv1d ? weight_tensor->sizes()[2] : weight_tensor->sizes()[3]; // kernel width + + // Calculate output spatial dimensions + int64_t stride_h = is_conv1d ? 1 : (stride && stride_len_ > 0 ? stride[0] : 1); + int64_t stride_w = is_conv1d ? (stride && stride_len_ > 0 ? stride[0] : 1) + : (stride && stride_len_ > 1 ? stride[1] : 1); + int64_t pad_h = is_conv1d ? 0 : (padding && padding_len_ > 0 ? padding[0] : 0); + int64_t pad_w = is_conv1d ? (padding && padding_len_ > 0 ? padding[0] : 0) + : (padding && padding_len_ > 1 ? padding[1] : 0); + int64_t dil_h = is_conv1d ? 1 : (dilation && dilation_len_ > 0 ? dilation[0] : 1); + int64_t dil_w = is_conv1d ? (dilation && dilation_len_ > 0 ? dilation[0] : 1) + : (dilation && dilation_len_ > 1 ? 
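+      // Worked example for the non-transposed case computed below: with H_in=32,
+      // pad_h=1, dil_h=1, kernel_h=3, stride_h=2,
+      // H_out = (32 + 2*1 - 1*(3-1) - 1) / 2 + 1 = 16.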
dilation[1] : 1); + + int64_t H_out, W_out; + if (transposed) { + // For transposed convolution, output size calculation is different + int64_t output_pad_h = is_conv1d ? 0 : (output_padding && output_padding_len_ > 0 ? output_padding[0] : 0); + int64_t output_pad_w = is_conv1d ? (output_padding && output_padding_len_ > 0 ? output_padding[0] : 0) + : (output_padding && output_padding_len_ > 1 ? output_padding[1] : 0); + H_out = is_conv1d ? 1 : ((H_in - 1) * stride_h - 2 * pad_h + dil_h * (kernel_h - 1) + output_pad_h + 1); + W_out = (W_in - 1) * stride_w - 2 * pad_w + dil_w * (kernel_w - 1) + output_pad_w + 1; + } else { + // Regular convolution output size calculation + H_out = is_conv1d ? 1 : ((H_in + 2 * pad_h - dil_h * (kernel_h - 1) - 1) / stride_h + 1); + W_out = (W_in + 2 * pad_w - dil_w * (kernel_w - 1) - 1) / stride_w + 1; + } + + if (!is_conv1d && is_input_4d) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 4D output shape: [%lld, %lld, %lld, %lld]", N, C_out, H_out, W_out); + } else if (!is_conv1d) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 3D output shape: [%lld, %lld, %lld]", C_out, H_out, W_out); + } else if (is_conv1d && has_batch_dim) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 3D (1D conv) output shape: [%lld, %lld, %lld]", N, C_out, W_out); + } else { + ET_LOG(Debug, "aoti_torch_mps_convolution: Calculated 2D (1D conv) output shape: [%lld, %lld]", C_out, W_out); + } + + // Validate output dimensions are positive + if (N <= 0 || C_out <= 0 || H_out <= 0 || W_out <= 0) { + ET_LOG(Error, "aoti_torch_mps_convolution: Invalid output dimensions N=%lld, C_out=%lld, H_out=%lld, W_out=%lld", + N, C_out, H_out, W_out); + return Error::InvalidArgument; + } + + // Use the same dispatch pattern as other MPS operations for consistent synchronization + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_convolution: Failed to get current Metal stream"); + return Error::Internal; + } + + // Get Metal device + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_convolution: Failed to get Metal device"); + throw std::runtime_error("Failed to get Metal device"); + } + + // Get Metal buffers from tensors using the global mapping + void* input_data_ptr = input_tensor->mutable_data_ptr(); + void* weight_data_ptr = weight_tensor->mutable_data_ptr(); + + // Look up Metal buffers from the global mapping + auto input_it = ptr_to_mtl_buffer.find(input_data_ptr); + auto weight_it = ptr_to_mtl_buffer.find(weight_data_ptr); + + if (input_it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "aoti_torch_mps_convolution: input tensor not found in Metal buffer mapping"); + throw std::runtime_error("input tensor not found in Metal buffer mapping"); + } + if (weight_it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "aoti_torch_mps_convolution: weight tensor not found in Metal buffer mapping"); + throw std::runtime_error("weight tensor not found in Metal buffer mapping"); + } + + id input_buffer = input_it->second; + id weight_buffer = weight_it->second; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Using existing Metal buffers - input=%p, weight=%p", + input_buffer, weight_buffer); + + // End any existing kernel coalescing to ensure a clean state for MPS + stream->endKernelCoalescing(); + + // Ensure stream is ready; command buffer handled internally by stream helpers + + // Determine data type and element size + int32_t dtype = static_cast(input_tensor->scalar_type()); + MPSDataType mps_dtype; + 
size_t element_size; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + element_size = sizeof(float); + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + element_size = sizeof(uint16_t); // bfloat16 is 16 bits + } else { + ET_LOG(Error, "aoti_torch_mps_convolution: Unsupported data type: %d", dtype); + throw std::runtime_error("Unsupported data type for convolution"); + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: mps_dtype=%d, element_size=%zu", mps_dtype, element_size); + + // Create MPSGraph for convolution + MPSGraph* mpsGraph = [MPSGraph new]; + ET_LOG(Debug, "aoti_torch_mps_convolution: Created MPSGraph instance"); + + // Define tensor shapes for placeholders (always 4D NCHW for MPSGraph) + NSArray* inputShape = @[@(N), @(C_in), @(H_in), @(W_in)]; + NSArray* weightShape = @[@(C_out), @(C_in), @(kernel_h), @(kernel_w)]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Creating placeholders with shapes input:[%d,%d,%d,%d] weight:[%d,%d,%d,%d]", + (int)N, (int)C_in, (int)H_in, (int)W_in, + (int)C_out, (int)C_in, (int)kernel_h, (int)kernel_w); + + // Create placeholders for input tensors + MPSGraphTensor* inputPlaceholder = [mpsGraph placeholderWithShape:inputShape + dataType:mps_dtype + name:@"input"]; + MPSGraphTensor* weightPlaceholder = [mpsGraph placeholderWithShape:weightShape + dataType:mps_dtype + name:@"weight"]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created input and weight placeholders"); + + // Create convolution descriptor + MPSGraphConvolution2DOpDescriptor* convDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w + strideInY:stride_h + dilationRateInX:dil_w + dilationRateInY:dil_h + groups:groups + paddingLeft:pad_w + paddingRight:pad_w + paddingTop:pad_h + paddingBottom:pad_h + paddingStyle:MPSGraphPaddingStyleExplicit + dataLayout:MPSGraphTensorNamedDataLayoutNCHW + weightsLayout:MPSGraphTensorNamedDataLayoutOIHW]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created convolution descriptor with stride=[%lld,%lld], padding=[%lld,%lld], dilation=[%lld,%lld], groups=%lld", + stride_w, stride_h, pad_w, pad_h, dil_w, dil_h, groups); + + // Perform convolution using MPSGraph + MPSGraphTensor* convOutput = nil; + if (transposed) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Using transposed convolution"); + // For transposed convolution, we need to handle output padding + int64_t output_pad_h = output_padding && output_padding_len_ > 0 ? output_padding[0] : 0; + int64_t output_pad_w = output_padding && output_padding_len_ > 1 ? 
output_padding[1] : 0; + + // For transposed convolution, we need to adjust the padding calculation + // In transposed convolution, the effective padding is typically negative + // and we use output_padding to control the final output size + int64_t transposed_pad_h = pad_h - output_pad_h; + int64_t transposed_pad_w = pad_w - output_pad_w; + + // Create transposed convolution descriptor with adjusted padding + MPSGraphConvolution2DOpDescriptor* transposedConvDesc = [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:stride_w + strideInY:stride_h + dilationRateInX:dil_w + dilationRateInY:dil_h + groups:groups + paddingLeft:transposed_pad_w + paddingRight:transposed_pad_w + paddingTop:transposed_pad_h + paddingBottom:transposed_pad_h + paddingStyle:MPSGraphPaddingStyleExplicit + dataLayout:MPSGraphTensorNamedDataLayoutNCHW + weightsLayout:MPSGraphTensorNamedDataLayoutOIHW]; + + convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder + weightsTensor:weightPlaceholder + descriptor:transposedConvDesc + name:@"transposed_convolution"]; + } else { + ET_LOG(Debug, "aoti_torch_mps_convolution: Using regular convolution"); + convOutput = [mpsGraph convolution2DWithSourceTensor:inputPlaceholder + weightsTensor:weightPlaceholder + descriptor:convDesc + name:@"convolution"]; + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: Successfully created convolution tensor"); + + // Handle bias if provided + MPSGraphTensor* finalOutput = convOutput; + MPSGraphTensor* biasPlaceholder = nil; + if (bias_tensor) { + ET_LOG(Debug, "aoti_torch_mps_convolution: Adding bias to convolution output"); + + // Get bias tensor data + void* bias_data_ptr = bias_tensor->mutable_data_ptr(); + auto bias_it = ptr_to_mtl_buffer.find(bias_data_ptr); + + if (bias_it != ptr_to_mtl_buffer.end()) { + id bias_buffer = bias_it->second; + + // Create bias placeholder + NSArray* biasShape = @[@(C_out)]; + biasPlaceholder = [mpsGraph placeholderWithShape:biasShape + dataType:mps_dtype + name:@"bias"]; + + // Add bias to convolution output + finalOutput = [mpsGraph additionWithPrimaryTensor:convOutput + secondaryTensor:biasPlaceholder + name:@"add_bias"]; + + ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias placeholder to graph"); + } else { + ET_LOG(Debug, "aoti_torch_mps_convolution: Bias tensor not found in Metal buffer mapping, skipping bias"); + } + } + + // Create feeds dictionary for graph execution + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + + // Create MPSGraphTensorData objects for input tensors + MPSGraphTensorData* inputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:input_buffer + shape:inputShape + dataType:mps_dtype]; + MPSGraphTensorData* weightData = [[MPSGraphTensorData alloc] initWithMTLBuffer:weight_buffer + shape:weightShape + dataType:mps_dtype]; + + feeds[inputPlaceholder] = inputData; + feeds[weightPlaceholder] = weightData; + + // Add bias data to feeds if provided + if (bias_tensor && biasPlaceholder) { + void* bias_data_ptr = bias_tensor->mutable_data_ptr(); + auto bias_it = ptr_to_mtl_buffer.find(bias_data_ptr); + + if (bias_it != ptr_to_mtl_buffer.end()) { + id bias_buffer = bias_it->second; + NSArray* biasShape = @[@(C_out)]; + MPSGraphTensorData* biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer + shape:biasShape + dataType:mps_dtype]; + + feeds[biasPlaceholder] = biasData; + ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds"); + } + } + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created feeds dictionary"); + + // 
Create or reuse output Metal buffer via AOTI API; keeps GPU residency + size_t output_size_bytes = N * C_out * H_out * W_out * element_size; + void* output_contents_ptr = nullptr; + AOTITorchError malloc_err = aoti_torch_mps_malloc(&output_contents_ptr, output_size_bytes); + if (malloc_err != Error::Ok || !output_contents_ptr) { + ET_LOG(Error, "aoti_torch_mps_convolution: Failed to allocate Metal buffer via aoti_torch_mps_malloc"); + throw std::runtime_error("Failed to allocate output Metal buffer"); + } + + auto out_it = ptr_to_mtl_buffer.find(output_contents_ptr); + if (out_it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "aoti_torch_mps_convolution: aoti_torch_mps_malloc did not register buffer in map"); + throw std::runtime_error("Failed to look up allocated Metal buffer"); + } + id output_buffer = out_it->second; + + // Create results dictionary (MPSGraph output is 4D) + NSArray* outputShape = @[@(N), @(C_out), @(H_out), @(W_out)]; + MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:output_buffer + shape:outputShape + dataType:mps_dtype]; + + NSDictionary* results = @{finalOutput: outputData}; + ET_LOG(Debug, "aoti_torch_mps_convolution: Created results dictionary"); + + // Execute the MPSGraph + ET_LOG(Debug, "aoti_torch_mps_convolution: Executing MPSGraph"); + + @try { + // Use stream helper to encode and synchronize correctly + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE); + } @catch (NSException *exception) { + ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + throw std::runtime_error("MPSGraph execution failed with NSException"); + } + // } @catch (const std::exception& e) { + // ET_LOG(Error, "aoti_torch_mps_convolution exception: %s", e.what()); + // throw std::runtime_error("MPSGraph execution failed"); + // } + + ET_LOG(Debug, "aoti_torch_mps_convolution: MPSGraph execution completed successfully"); + + // Create output tensor handle on device (MPS) that points to GPU buffer + std::vector output_sizes_int64; + std::vector output_strides; + if (!is_conv1d && is_input_4d) { + output_sizes_int64 = {N, C_out, H_out, W_out}; + // Contiguous NCHW strides + output_strides = { + C_out * H_out * W_out, + H_out * W_out, + W_out, + 1 + }; + } else if (!is_conv1d) { + output_sizes_int64 = {C_out, H_out, W_out}; + // Contiguous CHW strides + output_strides = { + H_out * W_out, + W_out, + 1 + }; + } else if (is_conv1d && has_batch_dim) { + output_sizes_int64 = {N, C_out, W_out}; + // Contiguous NCW strides + output_strides = { + C_out * W_out, + W_out, + 1 + }; + } else { + output_sizes_int64 = {C_out, W_out}; + // Contiguous CW strides + output_strides = { + W_out, + 1 + }; + } + + // Use the GPU buffer contents pointer directly for the tensor storage + void* tensor_data = output_contents_ptr; + + AOTITensorHandle output_tensor_handle = nullptr; + + AOTITorchError create_result = aoti_torch_create_tensor_from_blob_v2( + tensor_data, + static_cast(output_sizes_int64.size()), // ndim + output_sizes_int64.data(), + output_strides.data(), + 0, // storage_offset + dtype, // dtype + 2, // device_type (MPS) + 0, // device_index + &output_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_result != Error::Ok || !output_tensor_handle) { + ET_LOG(Error, "aoti_torch_mps_convolution: Failed to create output tensor, error code: %d", 
static_cast(create_result)); + aoti_torch_mps_free(tensor_data); // Free the allocated GPU memory on failure + throw std::runtime_error("Failed to create output tensor"); + } + + // Verify the tensor was created with the correct size + auto* et_tensor = reinterpret_cast(output_tensor_handle); + size_t actual_numel = et_tensor->numel(); + size_t expected_numel = static_cast(N * C_out * H_out * W_out); + + if (actual_numel != expected_numel) { + ET_LOG(Error, "aoti_torch_mps_convolution: Tensor size mismatch. Expected %zu, got %zu", expected_numel, actual_numel); + aoti_torch_mps_free(tensor_data); // Free the allocated GPU memory on failure + throw std::runtime_error("Tensor size mismatch"); + } + + // Store the tensor handle - mark that we own the memory since we manually allocated it with malloc + *ret0 = output_tensor_handle; + is_tensor_own_memory[et_tensor] = true; // We allocated the GPU memory + + ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel); + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_convolution exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_convolution: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( + AOTITensorHandle query, + AOTITensorHandle key, + AOTITensorHandle value, + AOTITensorHandle* attn_mask, + double dropout_p, + int32_t is_causal, + AOTITensorHandle* dropout_mask, + double* scale, + AOTITensorHandle* ret0, + AOTITensorHandle* ret1) { + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Starting with MPSGraph implementation"); + + if (!query || !key || !value || !ret0 || !ret1) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: null required tensor handles"); + return Error::InvalidArgument; + } + + // Use the same dispatch pattern as other MPS operations for consistent synchronization + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get current Metal stream"); + return Error::Internal; + } + + try { + @autoreleasepool { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto* query_tensor = reinterpret_cast(query); + auto* key_tensor = reinterpret_cast(key); + auto* value_tensor = reinterpret_cast(value); + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Converted tensor handles to ET tensors"); + + // Validate tensor dimensions + if (query_tensor->dim() < 3 || key_tensor->dim() < 3 || value_tensor->dim() < 3) { + std::string error_msg = "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: tensors must be at least 3-D, got " + + std::to_string(query_tensor->dim()) + ", " + + std::to_string(key_tensor->dim()) + ", " + + std::to_string(value_tensor->dim()); + ET_LOG(Error, "%s", error_msg.c_str()); + throw std::runtime_error(error_msg); + } + + // Get tensor dimensions (assuming [batch, num_heads, seq_len, head_dim] format) + int64_t batchSize = query_tensor->sizes()[0]; + int64_t num_heads = query_tensor->sizes()[1]; + int64_t qSize = query_tensor->sizes()[2]; + int64_t headSize = query_tensor->sizes()[3]; + int64_t kvSeqLength = key_tensor->sizes()[2]; + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: batchSize=%lld, num_heads=%lld, qSize=%lld, headSize=%lld, kvSeqLength=%lld", + 
batchSize, num_heads, qSize, headSize, kvSeqLength); + + // Determine data type and element size + int32_t dtype = static_cast(query_tensor->scalar_type()); + MPSDataType mps_dtype; + size_t element_size; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + element_size = sizeof(float); + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + element_size = sizeof(uint16_t); // bfloat16 is 16 bits + } else { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported data type: %d", dtype); + throw std::runtime_error("Unsupported data type for scaled dot product attention"); + } + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: mps_dtype=%d, element_size=%zu", mps_dtype, element_size); + + // Calculate scale factor + double scale_factor = scale ? *scale : (1.0 / sqrt(static_cast(headSize))); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scale_factor=%f", scale_factor); + + // Get Metal device + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to get Metal device"); + throw std::runtime_error("Failed to get Metal device"); + } + + // Get Metal buffers for input tensors + void* query_data_ptr = query_tensor->mutable_data_ptr(); + void* key_data_ptr = key_tensor->mutable_data_ptr(); + void* value_data_ptr = value_tensor->mutable_data_ptr(); + + id query_buffer = nullptr; + id key_buffer = nullptr; + id value_buffer = nullptr; + + // Look up Metal buffers from the global mapping + auto query_it = ptr_to_mtl_buffer.find(query_data_ptr); + auto key_it = ptr_to_mtl_buffer.find(key_data_ptr); + auto value_it = ptr_to_mtl_buffer.find(value_data_ptr); + + if (query_it != ptr_to_mtl_buffer.end()) { + query_buffer = query_it->second; + } + if (key_it != ptr_to_mtl_buffer.end()) { + key_buffer = key_it->second; + } + if (value_it != ptr_to_mtl_buffer.end()) { + value_buffer = value_it->second; + } + + // Create temporary Metal buffers if not found in mapping + if (!query_buffer) { + size_t query_size = query_tensor->numel() * element_size; + query_buffer = [device newBufferWithBytes:query_data_ptr + length:query_size + options:MTLResourceStorageModeShared]; + if (!query_buffer) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create Metal buffer for query tensor"); + throw std::runtime_error("Failed to create Metal buffer for query tensor"); + } + } + + if (!key_buffer) { + size_t key_size = key_tensor->numel() * element_size; + key_buffer = [device newBufferWithBytes:key_data_ptr + length:key_size + options:MTLResourceStorageModeShared]; + if (!key_buffer) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create Metal buffer for key tensor"); + throw std::runtime_error("Failed to create Metal buffer for key tensor"); + } + } + + if (!value_buffer) { + size_t value_size = value_tensor->numel() * element_size; + value_buffer = [device newBufferWithBytes:value_data_ptr + length:value_size + options:MTLResourceStorageModeShared]; + if (!value_buffer) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create Metal buffer for value tensor"); + throw std::runtime_error("Failed to create Metal buffer for value tensor"); + } + } + + // Calculate output tensor dimensions + std::vector output_sizes = {batchSize, num_heads, qSize, headSize}; + 
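+        // Illustrative values: with batchSize=1, num_heads=8, qSize=128, headSize=64,
+        // output_sizes is {1, 8, 128, 64} and the contiguous strides computed below
+        // are {8*128*64, 128*64, 64, 1} = {65536, 8192, 64, 1}.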
std::vector attn_sizes = {batchSize, num_heads, qSize, kvSeqLength}; + + // Calculate strides for contiguous tensors + std::vector out_strides = { + num_heads * qSize * headSize, + qSize * headSize, + headSize, + 1 + }; + + std::vector attn_strides = { + num_heads * qSize * kvSeqLength, + qSize * kvSeqLength, + kvSeqLength, + 1 + }; + + // Allocate output Metal buffers via AOTI API to keep GPU residency and reuse + size_t out_size_bytes = batchSize * num_heads * qSize * headSize * element_size; + size_t attn_size_bytes = batchSize * num_heads * qSize * kvSeqLength * element_size; + + void* out_contents_ptr = nullptr; + AOTITorchError out_malloc_err = aoti_torch_mps_malloc(&out_contents_ptr, out_size_bytes); + if (out_malloc_err != Error::Ok || !out_contents_ptr) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to allocate out buffer via aoti_torch_mps_malloc"); + throw std::runtime_error("Failed to allocate output buffer"); + } + auto out_map_it = ptr_to_mtl_buffer.find(out_contents_ptr); + if (out_map_it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: out buffer not found in mapping after malloc"); + aoti_torch_mps_free(out_contents_ptr); + throw std::runtime_error("Mapping for out buffer missing"); + } + id out_buffer = out_map_it->second; + + void* attn_contents_ptr = nullptr; + AOTITorchError attn_malloc_err = aoti_torch_mps_malloc(&attn_contents_ptr, attn_size_bytes); + if (attn_malloc_err != Error::Ok || !attn_contents_ptr) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to allocate attn buffer via aoti_torch_mps_malloc"); + aoti_torch_mps_free(out_contents_ptr); + throw std::runtime_error("Failed to allocate attn buffer"); + } + auto attn_map_it = ptr_to_mtl_buffer.find(attn_contents_ptr); + if (attn_map_it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: attn buffer not found in mapping after malloc"); + aoti_torch_mps_free(out_contents_ptr); + aoti_torch_mps_free(attn_contents_ptr); + throw std::runtime_error("Mapping for attn buffer missing"); + } + id attn_weights_buffer = attn_map_it->second; + + // End any existing kernel coalescing to ensure a clean state for MPS + stream->endKernelCoalescing(); + + // Method 1: Using MPSGraph scaledDotProductAttention API - with detailed error handling + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Implementing using MPSGraph scaledDotProductAttention"); + + @try { + // Check if scaledDotProductAttentionWithQueryTensor is available + MPSGraph* testGraph = [MPSGraph new]; + if (![testGraph respondsToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)]) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API not available on this system"); + throw std::runtime_error("scaledDotProductAttentionWithQueryTensor API not available on this system"); + } + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scaledDotProductAttentionWithQueryTensor API is available"); + + // Create MPSGraph for scaled dot product attention + MPSGraph* mpsGraph = [MPSGraph new]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance"); + + // Define tensor shapes for placeholders + NSArray* queryShape = @[@(batchSize), @(num_heads), @(qSize), @(headSize)]; + NSArray* 
keyShape = @[@(batchSize), @(num_heads), @(kvSeqLength), @(headSize)]; + NSArray* valueShape = @[@(batchSize), @(num_heads), @(kvSeqLength), @(headSize)]; + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Creating placeholders with shapes Q:[%d,%d,%d,%d] K:[%d,%d,%d,%d] V:[%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)qSize, (int)headSize, + (int)batchSize, (int)num_heads, (int)kvSeqLength, (int)headSize, + (int)batchSize, (int)num_heads, (int)kvSeqLength, (int)headSize); + + // Create placeholders for input tensors + MPSGraphTensor* queryPlaceholder = [mpsGraph placeholderWithShape:queryShape + dataType:mps_dtype + name:@"query"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created query placeholder"); + + MPSGraphTensor* keyPlaceholder = [mpsGraph placeholderWithShape:keyShape + dataType:mps_dtype + name:@"key"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created key placeholder"); + + MPSGraphTensor* valuePlaceholder = [mpsGraph placeholderWithShape:valueShape + dataType:mps_dtype + name:@"value"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created value placeholder"); + + MPSGraphTensor* maskTensor = nil; + + // Handle causal mask + if (is_causal) { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Creating causal mask"); + + // Create a causal mask: lower triangular matrix filled with 0s, upper triangle with -inf + // Shape should be [qSize, kvSeqLength] + NSArray* maskShape = @[@(qSize), @(kvSeqLength)]; + + // Create ones tensor + MPSGraphTensor* onesTensor = [mpsGraph constantWithScalar:1.0f + shape:maskShape + dataType:mps_dtype]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created ones tensor for causal mask"); + + // Create lower triangular mask (including diagonal) + MPSGraphTensor* causalMask = [mpsGraph bandPartWithTensor:onesTensor + numLower:-1 + numUpper:0 + name:@"causal_mask"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created causal mask using bandPartWithTensor"); + + // Convert mask to attention weights format: 0 for allowed positions, -inf for masked + MPSGraphTensor* zerosTensor = [mpsGraph constantWithScalar:0.0f + shape:maskShape + dataType:mps_dtype]; + + MPSGraphTensor* negInfTensor = [mpsGraph constantWithScalar:-1e9f + shape:maskShape + dataType:mps_dtype]; + + // Select: where causal_mask == 1, use 0.0, else use -inf + maskTensor = [mpsGraph selectWithPredicateTensor:causalMask + truePredicateTensor:zerosTensor + falsePredicateTensor:negInfTensor + name:@"causal_mask_final"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created final causal mask using selectWithPredicateTensor"); + } + + // Handle explicit attention mask if provided + MPSGraphTensor* explicitMaskPlaceholder = nil; + if (attn_mask && *attn_mask) { + auto* mask_tensor = reinterpret_cast(*attn_mask); + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Adding explicit attention mask"); + + // Create mask placeholder + NSMutableArray* maskShapeArray = [NSMutableArray array]; + for (int i = 0; i < mask_tensor->dim(); i++) { + [maskShapeArray addObject:@(mask_tensor->sizes()[i])]; + } + + explicitMaskPlaceholder = [mpsGraph placeholderWithShape:maskShapeArray + dataType:mps_dtype + name:@"attention_mask"]; + + if (maskTensor) { + // Combine causal and explicit masks + maskTensor = [mpsGraph 
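+          // e.g. with qSize=3 and kvSeqLength=3 the causal mask built above is
+          // [[0, -1e9, -1e9], [0, 0, -1e9], [0, 0, 0]]; an explicit attention mask,
+          // when supplied, is simply added elementwise to it here.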
additionWithPrimaryTensor:maskTensor + secondaryTensor:explicitMaskPlaceholder + name:@"combined_mask"]; + } else { + maskTensor = explicitMaskPlaceholder; + } + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created explicit mask placeholder"); + } + + // Perform scaled dot product attention using MPSGraph + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Calling scaledDotProductAttentionWithQueryTensor with scale=%f", scale_factor); + + MPSGraphTensor* outputTensor = [mpsGraph scaledDotProductAttentionWithQueryTensor:queryPlaceholder + keyTensor:keyPlaceholder + valueTensor:valuePlaceholder + maskTensor:maskTensor + scale:scale_factor + name:@"scaled_dot_product_attention"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Successfully created SDPA tensor"); + + // Create feeds dictionary for graph execution + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created feeds dictionary"); + + // Create MPSGraphTensorData objects for input tensors + MPSGraphTensorData* queryData = [[MPSGraphTensorData alloc] initWithMTLBuffer:query_buffer + shape:queryShape + dataType:mps_dtype]; + MPSGraphTensorData* keyData = [[MPSGraphTensorData alloc] initWithMTLBuffer:key_buffer + shape:keyShape + dataType:mps_dtype]; + MPSGraphTensorData* valueData = [[MPSGraphTensorData alloc] initWithMTLBuffer:value_buffer + shape:valueShape + dataType:mps_dtype]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraphTensorData objects for inputs"); + + feeds[queryPlaceholder] = queryData; + feeds[keyPlaceholder] = keyData; + feeds[valuePlaceholder] = valueData; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added input tensors to feeds"); + + // Add explicit mask data to feeds if provided + if (explicitMaskPlaceholder && attn_mask && *attn_mask) { + auto* mask_tensor = reinterpret_cast(*attn_mask); + void* mask_data_ptr = mask_tensor->mutable_data_ptr(); + + // Get or create Metal buffer for mask + id mask_buffer = nullptr; + auto mask_it = ptr_to_mtl_buffer.find(mask_data_ptr); + if (mask_it != ptr_to_mtl_buffer.end()) { + mask_buffer = mask_it->second; + } else { + size_t mask_size = mask_tensor->numel() * element_size; + mask_buffer = [device newBufferWithBytes:mask_data_ptr + length:mask_size + options:MTLResourceStorageModeShared]; + if (!mask_buffer) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create Metal buffer for attention mask"); + throw std::runtime_error("Failed to create Metal buffer for attention mask"); + } + } + + NSMutableArray* maskShapeArray = [NSMutableArray array]; + for (int i = 0; i < mask_tensor->dim(); i++) { + [maskShapeArray addObject:@(mask_tensor->sizes()[i])]; + } + + MPSGraphTensorData* maskData = [[MPSGraphTensorData alloc] initWithMTLBuffer:mask_buffer + shape:maskShapeArray + dataType:mps_dtype]; + feeds[explicitMaskPlaceholder] = maskData; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Added explicit mask tensor to feeds"); + } + + // Create results dictionary + NSArray* outputShape = @[@(batchSize), @(num_heads), @(qSize), @(headSize)]; + MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer + shape:outputShape + dataType:mps_dtype]; + + NSDictionary* results = @{outputTensor: outputData}; + ET_LOG(Debug, 
"aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created results dictionary"); + + // Execute via shared stream and keep results on GPU + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executing MPSGraph using stream"); + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_CONTINUE); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph execution completed successfully"); + + } @catch (NSException *exception) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: NSException caught: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + throw std::runtime_error("MPSGraph operation failed with NSException"); + } + + // For attention weights, zero-fill the GPU buffer (shared memory allows CPU memset) + std::memset(attn_contents_ptr, 0, attn_size_bytes); + + // Create output tensor handles + AOTITensorHandle out_tensor_handle = nullptr; + AOTITensorHandle attn_tensor_handle = nullptr; + + AOTITorchError create_out_result = aoti_torch_create_tensor_from_blob_v2( + out_contents_ptr, + 4, // ndim + output_sizes.data(), + out_strides.data(), + 0, // storage_offset + dtype, + 2, // device_type (MPS) + 0, // device_index + &out_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + AOTITorchError create_attn_result = aoti_torch_create_tensor_from_blob_v2( + attn_contents_ptr, + 4, // ndim + attn_sizes.data(), + attn_strides.data(), + 0, // storage_offset + dtype, + 2, // device_type (MPS) + 0, // device_index + &attn_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_out_result != Error::Ok || create_attn_result != Error::Ok || + !out_tensor_handle || !attn_tensor_handle) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create output tensors"); + aoti_torch_mps_free(out_contents_ptr); + aoti_torch_mps_free(attn_contents_ptr); + throw std::runtime_error("Failed to create output tensors"); + } + + // Mark that we own the memory for these tensors + auto* out_et_tensor = reinterpret_cast(out_tensor_handle); + auto* attn_et_tensor = reinterpret_cast(attn_tensor_handle); + is_tensor_own_memory[out_et_tensor] = true; + is_tensor_own_memory[attn_et_tensor] = true; + + // Set output tensor handles + *ret0 = out_tensor_handle; + *ret1 = attn_tensor_handle; + + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph implementation completed successfully"); + } + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: unknown exception"); + return Error::Internal; + } +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch From ca5f1e52300560ba3dab33ed247afdb0ea36a30a Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Sat, 11 Oct 2025 15:47:38 -0400 Subject: [PATCH 09/16] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/utils.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp index 484158e9027..bc8c0483e9d 100644 --- a/backends/apple/metal/runtime/shims/utils.cpp +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -65,13 +65,12 @@ std::vector convert_strides_to_vector( std::vector strides(ndim); if (strides_ptr != nullptr) { - // Use provided strides. it is ok if provided strides here is not contiguous - // strides since it will be used internally in CUDA delegate. + // Use provided strides. for (int64_t i = 0; i < ndim; i++) { strides[i] = static_cast(strides_ptr[i]); } } else { - // Calculate strides from sizes using ExecutorTorch's algorithm + // Calculate strides from sizes. if (ndim > 0) { strides[ndim - 1] = static_cast( 1); // Last dimension has stride 1 From 81c45885a294719d96cef8d6484a4a36f647734f Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 13 Oct 2025 18:47:10 -0400 Subject: [PATCH 10/16] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/memory.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp index 2bda93e18a4..4b8b6cda4e0 100644 --- a/backends/apple/metal/runtime/shims/memory.cpp +++ b/backends/apple/metal/runtime/shims/memory.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include // Ensure we have int64_t, int32_t definitions #include #include @@ -144,7 +143,8 @@ AOTITorchError aoti_torch_empty_strided( dtype); int64_t nbytes = numel * element_size; - if (device_type == 2) { // Metal/MPS + int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13 + if (device_type == mps_device_type) { ptr = metal_allocate_buffer(nbytes); if (!ptr) { ET_LOG(Error, "Failed to allocate %lld bytes on Metal device", nbytes); From 422e4baf66b007ec529f78e44d4691ffd9e530ed Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 13 Oct 2025 19:37:57 -0400 Subject: [PATCH 11/16] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/memory.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp index 4b8b6cda4e0..83250f308bb 100644 --- a/backends/apple/metal/runtime/shims/memory.cpp +++ b/backends/apple/metal/runtime/shims/memory.cpp @@ -143,7 +143,7 @@ AOTITorchError aoti_torch_empty_strided( dtype); int64_t nbytes = numel * element_size; - int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13 + int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13 if (device_type == mps_device_type) { ptr = metal_allocate_buffer(nbytes); if (!ptr) { From f46adc5149a168fec69d18bfebe7fcd39be37ce5 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Oct 2025 21:36:07 -0400 Subject: [PATCH 12/16] Update [ghstack-poisoned] --- backends/cuda/runtime/shims/memory.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backends/cuda/runtime/shims/memory.cpp 
b/backends/cuda/runtime/shims/memory.cpp index 6fe315ba8ee..fe8ccf07281 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -27,6 +27,8 @@ using executorch::backends::aoti::aoti_torch_get_device_index; using executorch::backends::aoti::aoti_torch_get_dtype; using executorch::backends::aoti::aoti_torch_get_sizes; using executorch::backends::aoti::aoti_torch_get_strides; +using executorch::backends::aoti::convert_sizes_to_vector; +using executorch::backends::aoti::convert_strides_to_vector; using executorch::backends::aoti::dtype_to_element_size; using executorch::backends::aoti::dtype_to_scalar_type; using executorch::backends::aoti::validate_storage_offset; From 750badf4231a56295daf0c4f5f8ef934bf8d8a23 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 15 Oct 2025 13:21:42 -0400 Subject: [PATCH 13/16] Update [ghstack-poisoned] --- .../apple/metal/runtime/shims/et_metal_ops.mm | 640 +++++++++++------- 1 file changed, 395 insertions(+), 245 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm index 0e547bc16db..0aa90650a1d 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.mm +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -24,12 +24,45 @@ namespace backends { namespace metal { +using executorch::runtime::etensor::Tensor; + // Forward declaration of dispatch_sync_with_rethrow from et_metal.mm void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()); // Declare the global mapping from et_metal.mm extern std::unordered_map> ptr_to_mtl_buffer; +namespace { + +// Helper function to get Metal buffer from the global mapping +static id get_mtl_buffer(Tensor* tensor, const char* op_name, const char* tensor_name) { + void* data_ptr = tensor->mutable_data_ptr(); + auto it = ptr_to_mtl_buffer.find(data_ptr); + if (it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "%s: %s tensor not found in Metal buffer mapping", op_name, tensor_name); + throw std::runtime_error(std::string(tensor_name) + " tensor not found in Metal buffer mapping"); + } + return it->second; +} + +// Helper function to allocate a Metal buffer and register it in the global mapping. +static id allocate_mtl_buffer(void** data_ptr, size_t size_bytes) { + AOTITorchError malloc_err = aoti_torch_mps_malloc(data_ptr, size_bytes); + if (malloc_err != Error::Ok) { + ET_LOG(Error, "allocate_and_register_mtl_buffer: Failed to allocate Metal buffer via aoti_torch_mps_malloc"); + throw std::runtime_error("Failed to allocate output Metal buffer"); + } + + auto it = ptr_to_mtl_buffer.find(*data_ptr); + if (it == ptr_to_mtl_buffer.end()) { + ET_LOG(Error, "allocate_and_register_mtl_buffer: aoti_torch_mps_malloc did not register buffer in map"); + throw std::runtime_error("Failed to look up allocated Metal buffer"); + } + return it->second; +} + +} // namespace + extern "C" { AOTITorchError aoti_torch_mps_mm_out( @@ -47,9 +80,9 @@ AOTITorchError aoti_torch_mps_mm_out( @autoreleasepool { try { // Convert AOTITensorHandle to ExecutorTorch tensors - auto out_tensor = reinterpret_cast(out); - auto self_tensor = reinterpret_cast(self); - auto mat2_tensor = reinterpret_cast(mat2); + auto out_tensor = reinterpret_cast(out); + auto self_tensor = reinterpret_cast(self); + auto mat2_tensor = reinterpret_cast(mat2); ET_LOG(Debug, "aoti_torch_mps_mm_out: Converted tensor handles to ET tensors"); @@ -81,6 +114,25 @@ AOTITorchError aoti_torch_mps_mm_out( out_tensor->dim() > 0 ? 
(int)out_tensor->sizes()[0] : 0, out_tensor->dim() > 1 ? (int)out_tensor->sizes()[1] : 0); + // Check if mat2 is transposed (non-contiguous due to transpose) + // A transposed matrix will have stride(-2) == 1 (column-major instead of row-major) + // For a 2D tensor with shape [K, N]: + // - Contiguous (row-major): strides = [N, 1] + // - Transposed (column-major): strides = [1, K] + bool mat2_is_transposed = false; + int64_t mat2_stride_0 = mat2_tensor->strides()[0]; // stride for dimension 0 + int64_t mat2_stride_1 = mat2_tensor->strides()[1]; // stride for dimension 1 + + // Detect transposed layout: stride(-2) == 1 indicates column-major layout + if (mat2_stride_0 == 1 && mat2_stride_1 != 1) { + mat2_is_transposed = true; + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 is transposed (strides=[%lld, %lld])", + mat2_stride_0, mat2_stride_1); + } else { + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 is contiguous (strides=[%lld, %lld])", + mat2_stride_0, mat2_stride_1); + } + // Use the same dispatch pattern as other MPS operations for consistent synchronization ETMetalStream* stream = getCurrentMetalStream(); if (!stream) { @@ -95,32 +147,10 @@ AOTITorchError aoti_torch_mps_mm_out( throw std::runtime_error("Failed to get Metal device"); } - // Get Metal buffers from tensors using the global mapping - void* self_data_ptr = self_tensor->mutable_data_ptr(); - void* mat2_data_ptr = mat2_tensor->mutable_data_ptr(); - void* out_data_ptr = out_tensor->mutable_data_ptr(); - - // Look up Metal buffers from the global mapping - auto self_it = ptr_to_mtl_buffer.find(self_data_ptr); - auto mat2_it = ptr_to_mtl_buffer.find(mat2_data_ptr); - auto out_it = ptr_to_mtl_buffer.find(out_data_ptr); - - if (self_it == ptr_to_mtl_buffer.end()) { - ET_LOG(Error, "aoti_torch_mps_mm_out: self tensor not found in Metal buffer mapping"); - throw std::runtime_error("self tensor not found in Metal buffer mapping"); - } - if (mat2_it == ptr_to_mtl_buffer.end()) { - ET_LOG(Error, "aoti_torch_mps_mm_out: mat2 tensor not found in Metal buffer mapping"); - throw std::runtime_error("mat2 tensor not found in Metal buffer mapping"); - } - if (out_it == ptr_to_mtl_buffer.end()) { - ET_LOG(Error, "aoti_torch_mps_mm_out: out tensor not found in Metal buffer mapping"); - throw std::runtime_error("out tensor not found in Metal buffer mapping"); - } - - id self_buffer = self_it->second; - id mat2_buffer = mat2_it->second; - id out_buffer = out_it->second; + // Get Metal buffers for input and output tensors + id self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_mm_out", "self"); + id mat2_buffer = get_mtl_buffer(mat2_tensor, "aoti_torch_mps_mm_out", "mat2"); + id out_buffer = get_mtl_buffer(out_tensor, "aoti_torch_mps_mm_out", "out"); ET_LOG(Debug, "aoti_torch_mps_mm_out: Using existing Metal buffers - self=%p, mat2=%p, out=%p", self_buffer, mat2_buffer, out_buffer); @@ -156,25 +186,56 @@ AOTITorchError aoti_torch_mps_mm_out( // Define tensor shapes for placeholders NSArray* selfShape = @[@(M), @(K)]; - NSArray* mat2Shape = @[@(K), @(N)]; NSArray* outShape = @[@(M), @(N)]; + // For mat2, we need to handle both contiguous and transposed cases + // If mat2 is transposed, its physical layout in memory is [N, K] (column-major) + // but logically we need [K, N] for the matrix multiplication + NSArray* mat2PhysicalShape; + if (mat2_is_transposed) { + // Physical shape reflects the actual memory layout (transposed) + mat2PhysicalShape = @[@(N), @(K)]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 physical shape (transposed): [%d,%d]", 
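+      // e.g. for K=4, N=5 a contiguous mat2 has strides [5, 1], while a mat2 obtained
+      // by transposing a contiguous [5, 4] tensor has strides [1, 4], which is the
+      // layout the stride check above detects.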
(int)N, (int)K); + } else { + // Physical shape is the logical shape (contiguous) + mat2PhysicalShape = @[@(K), @(N)]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: mat2 physical shape (contiguous): [%d,%d]", (int)K, (int)N); + } + ET_LOG(Debug, "aoti_torch_mps_mm_out: Creating placeholders with shapes self:[%d,%d] mat2:[%d,%d]", - (int)M, (int)K, (int)K, (int)N); + (int)M, (int)K, + mat2_is_transposed ? (int)N : (int)K, + mat2_is_transposed ? (int)K : (int)N); // Create placeholders for input tensors MPSGraphTensor* selfPlaceholder = [mpsGraph placeholderWithShape:selfShape dataType:mps_dtype name:@"self"]; - MPSGraphTensor* mat2Placeholder = [mpsGraph placeholderWithShape:mat2Shape + MPSGraphTensor* mat2Placeholder = [mpsGraph placeholderWithShape:mat2PhysicalShape dataType:mps_dtype - name:@"mat2"]; + name:@"mat2_physical"]; ET_LOG(Debug, "aoti_torch_mps_mm_out: Created input placeholders"); - // Perform matrix multiplication using MPSGraph + // If mat2 is transposed, apply transpose operation in the graph to get the logical shape + MPSGraphTensor* mat2Logical; + if (mat2_is_transposed) { + // Transpose from physical [N, K] to logical [K, N] + // MPSGraph transposeTensor swaps the last two dimensions for 2D tensors + mat2Logical = [mpsGraph transposeTensor:mat2Placeholder + dimension:-2 + withDimension:-1 + name:@"mat2_transposed"]; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Applied transpose operation to mat2 in graph"); + } else { + // No transpose needed, use placeholder directly + mat2Logical = mat2Placeholder; + ET_LOG(Debug, "aoti_torch_mps_mm_out: Using mat2 placeholder directly (no transpose needed)"); + } + + // Perform matrix multiplication using MPSGraph with the logical mat2 tensor MPSGraphTensor* mmOutput = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder - secondaryTensor:mat2Placeholder + secondaryTensor:mat2Logical name:@"matrix_multiplication"]; ET_LOG(Debug, "aoti_torch_mps_mm_out: Successfully created matrix multiplication tensor"); @@ -183,17 +244,18 @@ AOTITorchError aoti_torch_mps_mm_out( NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; // Create MPSGraphTensorData objects for input tensors + // Use physical shapes to match how data is actually laid out in memory MPSGraphTensorData* selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer shape:selfShape dataType:mps_dtype]; MPSGraphTensorData* mat2Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer - shape:mat2Shape + shape:mat2PhysicalShape dataType:mps_dtype]; feeds[selfPlaceholder] = selfData; feeds[mat2Placeholder] = mat2Data; - ET_LOG(Debug, "aoti_torch_mps_mm_out: Created feeds dictionary"); + ET_LOG(Debug, "aoti_torch_mps_mm_out: Created feeds dictionary with physical shapes"); // Create results dictionary MPSGraphTensorData* outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer @@ -217,6 +279,7 @@ AOTITorchError aoti_torch_mps_mm_out( ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully"); + ET_LOG(Debug, "aoti_torch_mps_mm_out: Executed successfully"); return Error::Ok; } catch (const std::exception& e) { @@ -255,13 +318,13 @@ AOTITorchError aoti_torch_mps_convolution( @autoreleasepool { try { // Convert AOTITensorHandle to ExecutorTorch tensors - auto input_tensor = reinterpret_cast(input); - auto weight_tensor = reinterpret_cast(weight); + auto input_tensor = reinterpret_cast(input); + auto weight_tensor = reinterpret_cast(weight); // bias can be null for convolutions without bias - 
executorch::runtime::etensor::Tensor* bias_tensor = nullptr; + Tensor* bias_tensor = nullptr; if (bias && *bias) { - bias_tensor = reinterpret_cast(*bias); + bias_tensor = reinterpret_cast(*bias); ET_LOG(Debug, "aoti_torch_mps_convolution: Has bias tensor"); } else { ET_LOG(Debug, "aoti_torch_mps_convolution: No bias tensor"); @@ -408,29 +471,6 @@ AOTITorchError aoti_torch_mps_convolution( throw std::runtime_error("Failed to get Metal device"); } - // Get Metal buffers from tensors using the global mapping - void* input_data_ptr = input_tensor->mutable_data_ptr(); - void* weight_data_ptr = weight_tensor->mutable_data_ptr(); - - // Look up Metal buffers from the global mapping - auto input_it = ptr_to_mtl_buffer.find(input_data_ptr); - auto weight_it = ptr_to_mtl_buffer.find(weight_data_ptr); - - if (input_it == ptr_to_mtl_buffer.end()) { - ET_LOG(Error, "aoti_torch_mps_convolution: input tensor not found in Metal buffer mapping"); - throw std::runtime_error("input tensor not found in Metal buffer mapping"); - } - if (weight_it == ptr_to_mtl_buffer.end()) { - ET_LOG(Error, "aoti_torch_mps_convolution: weight tensor not found in Metal buffer mapping"); - throw std::runtime_error("weight tensor not found in Metal buffer mapping"); - } - - id input_buffer = input_it->second; - id weight_buffer = weight_it->second; - - ET_LOG(Debug, "aoti_torch_mps_convolution: Using existing Metal buffers - input=%p, weight=%p", - input_buffer, weight_buffer); - // End any existing kernel coalescing to ensure a clean state for MPS stream->endKernelCoalescing(); @@ -541,33 +581,30 @@ AOTITorchError aoti_torch_mps_convolution( if (bias_tensor) { ET_LOG(Debug, "aoti_torch_mps_convolution: Adding bias to convolution output"); - // Get bias tensor data - void* bias_data_ptr = bias_tensor->mutable_data_ptr(); - auto bias_it = ptr_to_mtl_buffer.find(bias_data_ptr); + // Create bias placeholder + NSArray* biasShape = @[@(C_out)]; + biasPlaceholder = [mpsGraph placeholderWithShape:biasShape + dataType:mps_dtype + name:@"bias"]; - if (bias_it != ptr_to_mtl_buffer.end()) { - id bias_buffer = bias_it->second; + // Add bias to convolution output + finalOutput = [mpsGraph additionWithPrimaryTensor:convOutput + secondaryTensor:biasPlaceholder + name:@"add_bias"]; - // Create bias placeholder - NSArray* biasShape = @[@(C_out)]; - biasPlaceholder = [mpsGraph placeholderWithShape:biasShape - dataType:mps_dtype - name:@"bias"]; - - // Add bias to convolution output - finalOutput = [mpsGraph additionWithPrimaryTensor:convOutput - secondaryTensor:biasPlaceholder - name:@"add_bias"]; - - ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias placeholder to graph"); - } else { - ET_LOG(Debug, "aoti_torch_mps_convolution: Bias tensor not found in Metal buffer mapping, skipping bias"); - } + ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias placeholder to graph"); } // Create feeds dictionary for graph execution NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + // Get Metal buffers from tensors + id input_buffer = get_mtl_buffer(input_tensor, "aoti_torch_mps_convolution", "input"); + id weight_buffer = get_mtl_buffer(weight_tensor, "aoti_torch_mps_convolution", "weight"); + + ET_LOG(Debug, "aoti_torch_mps_convolution: Using existing Metal buffers - input=%p, weight=%p", + input_buffer, weight_buffer); + // Create MPSGraphTensorData objects for input tensors MPSGraphTensorData* inputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:input_buffer shape:inputShape @@ -581,38 +618,23 @@ AOTITorchError 
aoti_torch_mps_convolution( // Add bias data to feeds if provided if (bias_tensor && biasPlaceholder) { - void* bias_data_ptr = bias_tensor->mutable_data_ptr(); - auto bias_it = ptr_to_mtl_buffer.find(bias_data_ptr); - - if (bias_it != ptr_to_mtl_buffer.end()) { - id bias_buffer = bias_it->second; - NSArray* biasShape = @[@(C_out)]; - MPSGraphTensorData* biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer - shape:biasShape - dataType:mps_dtype]; + id bias_buffer = get_mtl_buffer(bias_tensor, "aoti_torch_mps_convolution", "bias"); - feeds[biasPlaceholder] = biasData; - ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds"); - } + NSArray* biasShape = @[@(C_out)]; + MPSGraphTensorData* biasData = [[MPSGraphTensorData alloc] initWithMTLBuffer:bias_buffer + shape:biasShape + dataType:mps_dtype]; + + feeds[biasPlaceholder] = biasData; + ET_LOG(Debug, "aoti_torch_mps_convolution: Added bias tensor to feeds"); } ET_LOG(Debug, "aoti_torch_mps_convolution: Created feeds dictionary"); - // Create or reuse output Metal buffer via AOTI API; keeps GPU residency + // Create Metal buffer for output tensor size_t output_size_bytes = N * C_out * H_out * W_out * element_size; void* output_contents_ptr = nullptr; - AOTITorchError malloc_err = aoti_torch_mps_malloc(&output_contents_ptr, output_size_bytes); - if (malloc_err != Error::Ok || !output_contents_ptr) { - ET_LOG(Error, "aoti_torch_mps_convolution: Failed to allocate Metal buffer via aoti_torch_mps_malloc"); - throw std::runtime_error("Failed to allocate output Metal buffer"); - } - - auto out_it = ptr_to_mtl_buffer.find(output_contents_ptr); - if (out_it == ptr_to_mtl_buffer.end()) { - ET_LOG(Error, "aoti_torch_mps_convolution: aoti_torch_mps_malloc did not register buffer in map"); - throw std::runtime_error("Failed to look up allocated Metal buffer"); - } - id output_buffer = out_it->second; + id output_buffer = allocate_mtl_buffer(&output_contents_ptr, output_size_bytes); // Create results dictionary (MPSGraph output is 4D) NSArray* outputShape = @[@(N), @(C_out), @(H_out), @(W_out)]; @@ -633,11 +655,10 @@ AOTITorchError aoti_torch_mps_convolution( ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s", [[exception name] UTF8String], [[exception reason] UTF8String]); throw std::runtime_error("MPSGraph execution failed with NSException"); + } @catch (...) 
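The refactor above replaces the repeated ptr_to_mtl_buffer lookups with get_mtl_buffer and allocate_mtl_buffer helpers whose definitions are not shown in this hunk. A plausible sketch of those helpers, reconstructed from the removed code, is below; the exact signatures, the map's declared type, and the includes for ET_LOG, the ExecuTorch tensor type, and the AOTI MPS shim declarations are assumptions.

```objective-c++
#import <Metal/Metal.h>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Declared in the Metal shim layer: host data pointer -> backing Metal buffer.
extern std::unordered_map<void*, id<MTLBuffer>> ptr_to_mtl_buffer;

// Look up the Metal buffer backing a tensor, logging and throwing on failure
// (mirrors the removed per-call-site lookups).
inline id<MTLBuffer> get_mtl_buffer(
    executorch::runtime::etensor::Tensor* tensor,
    const char* op_name,
    const char* arg_name) {
  auto it = ptr_to_mtl_buffer.find(tensor->mutable_data_ptr());
  if (it == ptr_to_mtl_buffer.end()) {
    ET_LOG(Error, "%s: %s tensor not found in Metal buffer mapping", op_name, arg_name);
    throw std::runtime_error(std::string(arg_name) + " tensor not found in Metal buffer mapping");
  }
  return it->second;
}

// Allocate device memory through the AOTI MPS allocator and return the Metal
// buffer it registered for the new allocation.
inline id<MTLBuffer> allocate_mtl_buffer(void** contents_ptr, size_t nbytes) {
  AOTITorchError err = aoti_torch_mps_malloc(contents_ptr, nbytes);
  if (err != Error::Ok || *contents_ptr == nullptr) {
    throw std::runtime_error("Failed to allocate Metal buffer");
  }
  auto it = ptr_to_mtl_buffer.find(*contents_ptr);
  if (it == ptr_to_mtl_buffer.end()) {
    throw std::runtime_error("Allocated Metal buffer missing from mapping");
  }
  return it->second;
}
```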
{ + ET_LOG(Error, "aoti_torch_mps_convolution: MPSGraph execution failed"); + throw std::runtime_error("MPSGraph execution failed"); } - // } @catch (const std::exception& e) { - // ET_LOG(Error, "aoti_torch_mps_convolution exception: %s", e.what()); - // throw std::runtime_error("MPSGraph execution failed"); - // } ET_LOG(Debug, "aoti_torch_mps_convolution: MPSGraph execution completed successfully"); @@ -705,7 +726,7 @@ AOTITorchError aoti_torch_mps_convolution( } // Verify the tensor was created with the correct size - auto* et_tensor = reinterpret_cast(output_tensor_handle); + auto* et_tensor = reinterpret_cast(output_tensor_handle); size_t actual_numel = et_tensor->numel(); size_t expected_numel = static_cast(N * C_out * H_out * W_out); @@ -721,6 +742,7 @@ AOTITorchError aoti_torch_mps_convolution( ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel); + ET_LOG(Debug, "aoti_torch_mps_convolution: Executed successfully"); return Error::Ok; } catch (const std::exception& e) { @@ -762,9 +784,9 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( try { @autoreleasepool { // Convert AOTITensorHandle to ExecutorTorch tensors - auto* query_tensor = reinterpret_cast(query); - auto* key_tensor = reinterpret_cast(key); - auto* value_tensor = reinterpret_cast(value); + auto* query_tensor = reinterpret_cast(query); + auto* key_tensor = reinterpret_cast(key); + auto* value_tensor = reinterpret_cast(value); ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Converted tensor handles to ET tensors"); @@ -788,6 +810,109 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: batchSize=%lld, num_heads=%lld, qSize=%lld, headSize=%lld, kvSeqLength=%lld", batchSize, num_heads, qSize, headSize, kvSeqLength); + // Detect non-contiguous layouts for query, key, and value tensors + // For a 4D tensor [batch, num_heads, seq_len, head_dim], common non-contiguous patterns: + // - Transposed last 2 dims (dims 2,3): strides[2] == 1 && strides[3] == seq_len (seq_len and head_dim swapped) + // - Transposed internal dims (dims 1,2): strides[1] == head_dim && strides[2] == num_heads*head_dim (num_heads and seq_len swapped) + // - Other permutations may exist depending on upstream operations + + bool query_is_transposed_last2 = false; // transpose of dims -2 and -1 + bool query_is_transposed_internal = false; // transpose of dims 1 and 2 + bool key_is_transposed_last2 = false; + bool key_is_transposed_internal = false; + bool value_is_transposed_last2 = false; + bool value_is_transposed_internal = false; + + // Expected contiguous strides for query [batch, num_heads, qSize, headSize] + int64_t expected_q_stride_3 = 1; + int64_t expected_q_stride_2 = headSize; + int64_t expected_q_stride_1 = qSize * headSize; + int64_t expected_q_stride_0 = num_heads * qSize * headSize; + + // Check query tensor layout + auto q_strides = query_tensor->strides(); + if (q_strides[3] != expected_q_stride_3 || q_strides[2] != expected_q_stride_2 || + q_strides[1] != expected_q_stride_1) { + // Check if it's a transpose of the last two dimensions (dims 2 and 3) + if (q_strides[2] == 1 && q_strides[3] == qSize && q_strides[1] == qSize * headSize) { + query_is_transposed_last2 = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor has transposed last 2 dims (dims 2,3) (strides=[%lld,%lld,%lld,%lld])", + 
(int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); + } + // Check if it's a transpose of the internal dimensions (dims 1 and 2) + else if (q_strides[1] == headSize && q_strides[2] == num_heads * headSize && q_strides[3] == 1) { + query_is_transposed_internal = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor has transposed internal dims (dims 1,2) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor is non-contiguous with unusual layout (strides=[%lld,%lld,%lld,%lld])", + (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); + } + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query tensor is contiguous (strides=[%lld,%lld,%lld,%lld])", + (int64_t)q_strides[0], (int64_t)q_strides[1], (int64_t)q_strides[2], (int64_t)q_strides[3]); + } + + // Expected contiguous strides for key [batch, num_heads, kvSeqLength, headSize] + int64_t expected_k_stride_3 = 1; + int64_t expected_k_stride_2 = headSize; + int64_t expected_k_stride_1 = kvSeqLength * headSize; + int64_t expected_k_stride_0 = num_heads * kvSeqLength * headSize; + + // Check key tensor layout + auto k_strides = key_tensor->strides(); + if (k_strides[3] != expected_k_stride_3 || k_strides[2] != expected_k_stride_2 || + k_strides[1] != expected_k_stride_1) { + // Check if it's a transpose of the last two dimensions (dims 2 and 3) + if (k_strides[2] == 1 && k_strides[3] == kvSeqLength && k_strides[1] == kvSeqLength * headSize) { + key_is_transposed_last2 = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor has transposed last 2 dims (dims 2,3) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); + } + // Check if it's a transpose of the internal dimensions (dims 1 and 2) + else if (k_strides[1] == headSize && k_strides[2] == num_heads * headSize && k_strides[3] == 1) { + key_is_transposed_internal = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor has transposed internal dims (dims 1,2) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor is non-contiguous with unusual layout (strides=[%lld,%lld,%lld,%lld])", + (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); + } + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key tensor is contiguous (strides=[%lld,%lld,%lld,%lld])", + (int64_t)k_strides[0], (int64_t)k_strides[1], (int64_t)k_strides[2], (int64_t)k_strides[3]); + } + + // Expected contiguous strides for value [batch, num_heads, kvSeqLength, headSize] + int64_t expected_v_stride_3 = 1; + int64_t expected_v_stride_2 = headSize; + int64_t expected_v_stride_1 = kvSeqLength * headSize; + int64_t expected_v_stride_0 = num_heads * kvSeqLength * headSize; + + // Check value tensor layout + auto v_strides = value_tensor->strides(); + if (v_strides[3] != expected_v_stride_3 || v_strides[2] != expected_v_stride_2 || + v_strides[1] != expected_v_stride_1) { + // Check if it's a transpose of the last two 
dimensions (dims 2 and 3) + if (v_strides[2] == 1 && v_strides[3] == kvSeqLength && v_strides[1] == kvSeqLength * headSize) { + value_is_transposed_last2 = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor has transposed last 2 dims (dims 2,3) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); + } + // Check if it's a transpose of the internal dimensions (dims 1 and 2) + else if (v_strides[1] == headSize && v_strides[2] == num_heads * headSize && v_strides[3] == 1) { + value_is_transposed_internal = true; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor has transposed internal dims (dims 1,2) (strides=[%lld,%lld,%lld,%lld])", + (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor is non-contiguous with unusual layout (strides=[%lld,%lld,%lld,%lld])", + (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); + } + } else { + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value tensor is contiguous (strides=[%lld,%lld,%lld,%lld])", + (int64_t)v_strides[0], (int64_t)v_strides[1], (int64_t)v_strides[2], (int64_t)v_strides[3]); + } + // Determine data type and element size int32_t dtype = static_cast(query_tensor->scalar_type()); MPSDataType mps_dtype; @@ -823,63 +948,10 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( throw std::runtime_error("Failed to get Metal device"); } - // Get Metal buffers for input tensors - void* query_data_ptr = query_tensor->mutable_data_ptr(); - void* key_data_ptr = key_tensor->mutable_data_ptr(); - void* value_data_ptr = value_tensor->mutable_data_ptr(); - - id query_buffer = nullptr; - id key_buffer = nullptr; - id value_buffer = nullptr; - - // Look up Metal buffers from the global mapping - auto query_it = ptr_to_mtl_buffer.find(query_data_ptr); - auto key_it = ptr_to_mtl_buffer.find(key_data_ptr); - auto value_it = ptr_to_mtl_buffer.find(value_data_ptr); - - if (query_it != ptr_to_mtl_buffer.end()) { - query_buffer = query_it->second; - } - if (key_it != ptr_to_mtl_buffer.end()) { - key_buffer = key_it->second; - } - if (value_it != ptr_to_mtl_buffer.end()) { - value_buffer = value_it->second; - } - - // Create temporary Metal buffers if not found in mapping - if (!query_buffer) { - size_t query_size = query_tensor->numel() * element_size; - query_buffer = [device newBufferWithBytes:query_data_ptr - length:query_size - options:MTLResourceStorageModeShared]; - if (!query_buffer) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create Metal buffer for query tensor"); - throw std::runtime_error("Failed to create Metal buffer for query tensor"); - } - } - - if (!key_buffer) { - size_t key_size = key_tensor->numel() * element_size; - key_buffer = [device newBufferWithBytes:key_data_ptr - length:key_size - options:MTLResourceStorageModeShared]; - if (!key_buffer) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create Metal buffer for key tensor"); - throw std::runtime_error("Failed to create Metal buffer for key tensor"); - } - } - - if (!value_buffer) { - size_t value_size = value_tensor->numel() * element_size; - value_buffer = [device newBufferWithBytes:value_data_ptr - length:value_size - 
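The three per-tensor layout checks above share the same structure; a compact sketch of the classification follows. The names (Layout4D, classify_4d_layout) are illustrative only. For a logical shape [batch, heads, seq, head_dim] the expected contiguous strides are [heads*seq*head_dim, seq*head_dim, head_dim, 1], and the two supported transpose patterns are recognized purely from strides.

```objective-c++
#include <array>
#include <cstdint>

enum class Layout4D { Contiguous, TransposedLast2, TransposedInternal, Other };

inline Layout4D classify_4d_layout(int64_t seq, int64_t heads, int64_t head_dim,
                                   const std::array<int64_t, 4>& s) {
  // Contiguous [B, H, S, D]
  if (s[3] == 1 && s[2] == head_dim && s[1] == seq * head_dim) {
    return Layout4D::Contiguous;
  }
  // Dims 2 and 3 swapped: physical layout [B, H, D, S]
  if (s[2] == 1 && s[3] == seq && s[1] == seq * head_dim) {
    return Layout4D::TransposedLast2;
  }
  // Dims 1 and 2 swapped: physical layout [B, S, H, D]
  if (s[3] == 1 && s[1] == head_dim && s[2] == heads * head_dim) {
    return Layout4D::TransposedInternal;
  }
  return Layout4D::Other;
}

// e.g. for the query tensor:
//   classify_4d_layout(qSize, num_heads, headSize,
//                      {q_strides[0], q_strides[1], q_strides[2], q_strides[3]});
```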
options:MTLResourceStorageModeShared]; - if (!value_buffer) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create Metal buffer for value tensor"); - throw std::runtime_error("Failed to create Metal buffer for value tensor"); - } - } + // Get Metal buffers for query, key and value tensors + id query_buffer = get_mtl_buffer(query_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "query"); + id key_buffer = get_mtl_buffer(key_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "key"); + id value_buffer = get_mtl_buffer(value_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "value"); // Calculate output tensor dimensions std::vector output_sizes = {batchSize, num_heads, qSize, headSize}; @@ -905,34 +977,10 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( size_t attn_size_bytes = batchSize * num_heads * qSize * kvSeqLength * element_size; void* out_contents_ptr = nullptr; - AOTITorchError out_malloc_err = aoti_torch_mps_malloc(&out_contents_ptr, out_size_bytes); - if (out_malloc_err != Error::Ok || !out_contents_ptr) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to allocate out buffer via aoti_torch_mps_malloc"); - throw std::runtime_error("Failed to allocate output buffer"); - } - auto out_map_it = ptr_to_mtl_buffer.find(out_contents_ptr); - if (out_map_it == ptr_to_mtl_buffer.end()) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: out buffer not found in mapping after malloc"); - aoti_torch_mps_free(out_contents_ptr); - throw std::runtime_error("Mapping for out buffer missing"); - } - id out_buffer = out_map_it->second; + id out_buffer = allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); void* attn_contents_ptr = nullptr; - AOTITorchError attn_malloc_err = aoti_torch_mps_malloc(&attn_contents_ptr, attn_size_bytes); - if (attn_malloc_err != Error::Ok || !attn_contents_ptr) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to allocate attn buffer via aoti_torch_mps_malloc"); - aoti_torch_mps_free(out_contents_ptr); - throw std::runtime_error("Failed to allocate attn buffer"); - } - auto attn_map_it = ptr_to_mtl_buffer.find(attn_contents_ptr); - if (attn_map_it == ptr_to_mtl_buffer.end()) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: attn buffer not found in mapping after malloc"); - aoti_torch_mps_free(out_contents_ptr); - aoti_torch_mps_free(attn_contents_ptr); - throw std::runtime_error("Mapping for attn buffer missing"); - } - id attn_weights_buffer = attn_map_it->second; + id attn_weights_buffer = allocate_mtl_buffer(&attn_contents_ptr, attn_size_bytes); // End any existing kernel coalescing to ensure a clean state for MPS stream->endKernelCoalescing(); @@ -953,32 +1001,146 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( MPSGraph* mpsGraph = [MPSGraph new]; ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraph instance"); - // Define tensor shapes for placeholders - NSArray* queryShape = @[@(batchSize), @(num_heads), @(qSize), @(headSize)]; - NSArray* keyShape = @[@(batchSize), @(num_heads), @(kvSeqLength), @(headSize)]; - NSArray* valueShape = @[@(batchSize), @(num_heads), @(kvSeqLength), @(headSize)]; + // Define physical tensor shapes for placeholders (matching actual memory layout) + // Two transpose patterns supported: + // 1. 
Last 2 dims transposed (dims 2,3): [batch, num_heads, head_dim, seq_len] + // 2. Internal dims transposed (dims 1,2): [batch, seq_len, num_heads, head_dim] + NSArray* queryPhysicalShape; + NSArray* keyPhysicalShape; + NSArray* valuePhysicalShape; + + if (query_is_transposed_last2) { + // Physical layout: [batch, num_heads, headSize, qSize] (dims 2,3 swapped) + queryPhysicalShape = @[@(batchSize), @(num_heads), @(headSize), @(qSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query physical shape (transposed dims 2,3): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)headSize, (int)qSize); + } else if (query_is_transposed_internal) { + // Physical layout: [batch, qSize, num_heads, headSize] (dims 1,2 swapped) + queryPhysicalShape = @[@(batchSize), @(qSize), @(num_heads), @(headSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query physical shape (transposed dims 1,2): [%d,%d,%d,%d]", + (int)batchSize, (int)qSize, (int)num_heads, (int)headSize); + } else { + // Physical layout matches logical layout: [batch, num_heads, qSize, headSize] + queryPhysicalShape = @[@(batchSize), @(num_heads), @(qSize), @(headSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Query physical shape (contiguous): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)qSize, (int)headSize); + } + + if (key_is_transposed_last2) { + // Physical layout: [batch, num_heads, headSize, kvSeqLength] (dims 2,3 swapped) + keyPhysicalShape = @[@(batchSize), @(num_heads), @(headSize), @(kvSeqLength)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key physical shape (transposed dims 2,3): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)headSize, (int)kvSeqLength); + } else if (key_is_transposed_internal) { + // Physical layout: [batch, kvSeqLength, num_heads, headSize] (dims 1,2 swapped) + keyPhysicalShape = @[@(batchSize), @(kvSeqLength), @(num_heads), @(headSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key physical shape (transposed dims 1,2): [%d,%d,%d,%d]", + (int)batchSize, (int)kvSeqLength, (int)num_heads, (int)headSize); + } else { + // Physical layout matches logical layout: [batch, num_heads, kvSeqLength, headSize] + keyPhysicalShape = @[@(batchSize), @(num_heads), @(kvSeqLength), @(headSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Key physical shape (contiguous): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)kvSeqLength, (int)headSize); + } - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Creating placeholders with shapes Q:[%d,%d,%d,%d] K:[%d,%d,%d,%d] V:[%d,%d,%d,%d]", - (int)batchSize, (int)num_heads, (int)qSize, (int)headSize, - (int)batchSize, (int)num_heads, (int)kvSeqLength, (int)headSize, - (int)batchSize, (int)num_heads, (int)kvSeqLength, (int)headSize); + if (value_is_transposed_last2) { + // Physical layout: [batch, num_heads, headSize, kvSeqLength] (dims 2,3 swapped) + valuePhysicalShape = @[@(batchSize), @(num_heads), @(headSize), @(kvSeqLength)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value physical shape (transposed dims 2,3): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)headSize, (int)kvSeqLength); + } else if (value_is_transposed_internal) { + // Physical layout: [batch, kvSeqLength, num_heads, headSize] (dims 1,2 swapped) + valuePhysicalShape = @[@(batchSize), @(kvSeqLength), @(num_heads), @(headSize)]; + 
ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value physical shape (transposed dims 1,2): [%d,%d,%d,%d]", + (int)batchSize, (int)kvSeqLength, (int)num_heads, (int)headSize); + } else { + // Physical layout matches logical layout: [batch, num_heads, kvSeqLength, headSize] + valuePhysicalShape = @[@(batchSize), @(num_heads), @(kvSeqLength), @(headSize)]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Value physical shape (contiguous): [%d,%d,%d,%d]", + (int)batchSize, (int)num_heads, (int)kvSeqLength, (int)headSize); + } - // Create placeholders for input tensors - MPSGraphTensor* queryPlaceholder = [mpsGraph placeholderWithShape:queryShape + // Create placeholders for input tensors with physical shapes + MPSGraphTensor* queryPlaceholder = [mpsGraph placeholderWithShape:queryPhysicalShape dataType:mps_dtype - name:@"query"]; + name:@"query_physical"]; ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created query placeholder"); - MPSGraphTensor* keyPlaceholder = [mpsGraph placeholderWithShape:keyShape + MPSGraphTensor* keyPlaceholder = [mpsGraph placeholderWithShape:keyPhysicalShape dataType:mps_dtype - name:@"key"]; + name:@"key_physical"]; ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created key placeholder"); - MPSGraphTensor* valuePlaceholder = [mpsGraph placeholderWithShape:valueShape + MPSGraphTensor* valuePlaceholder = [mpsGraph placeholderWithShape:valuePhysicalShape dataType:mps_dtype - name:@"value"]; + name:@"value_physical"]; ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created value placeholder"); + // Apply transpose operations in the graph to convert physical to logical layout + // Logical shapes needed for SDPA: Q[batch, num_heads, qSize, headSize], + // K[batch, num_heads, kvSeqLength, headSize], + // V[batch, num_heads, kvSeqLength, headSize] + MPSGraphTensor* queryLogical; + MPSGraphTensor* keyLogical; + MPSGraphTensor* valueLogical; + + if (query_is_transposed_last2) { + // Transpose dims 2,3: [batch, num_heads, headSize, qSize] → [batch, num_heads, qSize, headSize] + queryLogical = [mpsGraph transposeTensor:queryPlaceholder + dimension:-2 + withDimension:-1 + name:@"query_transposed_last2"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 2,3) to query tensor in graph"); + } else if (query_is_transposed_internal) { + // Transpose dims 1,2: [batch, qSize, num_heads, headSize] → [batch, num_heads, qSize, headSize] + queryLogical = [mpsGraph transposeTensor:queryPlaceholder + dimension:1 + withDimension:2 + name:@"query_transposed_internal"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 1,2) to query tensor in graph"); + } else { + queryLogical = queryPlaceholder; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using query placeholder directly (no transpose needed)"); + } + + if (key_is_transposed_last2) { + // Transpose dims 2,3: [batch, num_heads, headSize, kvSeqLength] → [batch, num_heads, kvSeqLength, headSize] + keyLogical = [mpsGraph transposeTensor:keyPlaceholder + dimension:-2 + withDimension:-1 + name:@"key_transposed_last2"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 2,3) to key tensor in graph"); + } else if (key_is_transposed_internal) { + // Transpose dims 1,2: [batch, kvSeqLength, num_heads, headSize] → [batch, num_heads, 
kvSeqLength, headSize] + keyLogical = [mpsGraph transposeTensor:keyPlaceholder + dimension:1 + withDimension:2 + name:@"key_transposed_internal"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 1,2) to key tensor in graph"); + } else { + keyLogical = keyPlaceholder; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using key placeholder directly (no transpose needed)"); + } + + if (value_is_transposed_last2) { + // Transpose dims 2,3: [batch, num_heads, headSize, kvSeqLength] → [batch, num_heads, kvSeqLength, headSize] + valueLogical = [mpsGraph transposeTensor:valuePlaceholder + dimension:-2 + withDimension:-1 + name:@"value_transposed_last2"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 2,3) to value tensor in graph"); + } else if (value_is_transposed_internal) { + // Transpose dims 1,2: [batch, kvSeqLength, num_heads, headSize] → [batch, num_heads, kvSeqLength, headSize] + valueLogical = [mpsGraph transposeTensor:valuePlaceholder + dimension:1 + withDimension:2 + name:@"value_transposed_internal"]; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Applied transpose (dims 1,2) to value tensor in graph"); + } else { + valueLogical = valuePlaceholder; + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Using value placeholder directly (no transpose needed)"); + } + MPSGraphTensor* maskTensor = nil; // Handle causal mask @@ -1022,7 +1184,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( // Handle explicit attention mask if provided MPSGraphTensor* explicitMaskPlaceholder = nil; if (attn_mask && *attn_mask) { - auto* mask_tensor = reinterpret_cast(*attn_mask); + auto* mask_tensor = reinterpret_cast(*attn_mask); ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Adding explicit attention mask"); @@ -1047,12 +1209,13 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created explicit mask placeholder"); } - // Perform scaled dot product attention using MPSGraph + // Perform scaled dot product attention using MPSGraph with logical (possibly transposed) tensors + // The logical tensors have the correct shapes for attention computation regardless of input memory layout ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Calling scaledDotProductAttentionWithQueryTensor with scale=%f", scale_factor); - MPSGraphTensor* outputTensor = [mpsGraph scaledDotProductAttentionWithQueryTensor:queryPlaceholder - keyTensor:keyPlaceholder - valueTensor:valuePlaceholder + MPSGraphTensor* outputTensor = [mpsGraph scaledDotProductAttentionWithQueryTensor:queryLogical + keyTensor:keyLogical + valueTensor:valueLogical maskTensor:maskTensor scale:scale_factor name:@"scaled_dot_product_attention"]; @@ -1062,17 +1225,18 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created feeds dictionary"); - // Create MPSGraphTensorData objects for input tensors + // Create MPSGraphTensorData objects for input tensors using physical shapes + // Physical shapes match the actual memory layout of the tensors MPSGraphTensorData* queryData = [[MPSGraphTensorData alloc] initWithMTLBuffer:query_buffer - shape:queryShape + 
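The pattern used for Q, K and V above is: create the placeholder with the physical (in-memory) shape, feed the raw Metal buffer with that same shape, and insert a transposeTensor node to recover the logical [batch, heads, seq, head_dim] view. A minimal sketch for a single tensor follows; the helper name and boolean flags are illustrative, and explicit dimension indices 2/3 are used instead of the negative indices that appear in the patch.

```objective-c++
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

// Illustrative only: map a placeholder created with the physical layout back
// to the logical [batch, heads, seq, head_dim] layout inside the graph.
static MPSGraphTensor* logicalFromPhysical(MPSGraph* graph,
                                           MPSGraphTensor* physical,
                                           bool transposedLast2,
                                           bool transposedInternal) {
  if (transposedLast2) {
    // Physical [B, H, D, S] -> logical [B, H, S, D]
    return [graph transposeTensor:physical
                        dimension:2
                    withDimension:3
                             name:@"to_logical_last2"];
  }
  if (transposedInternal) {
    // Physical [B, S, H, D] -> logical [B, H, S, D]
    return [graph transposeTensor:physical
                        dimension:1
                    withDimension:2
                             name:@"to_logical_internal"];
  }
  // Already contiguous: the placeholder is the logical tensor.
  return physical;
}
```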
shape:queryPhysicalShape dataType:mps_dtype]; MPSGraphTensorData* keyData = [[MPSGraphTensorData alloc] initWithMTLBuffer:key_buffer - shape:keyShape + shape:keyPhysicalShape dataType:mps_dtype]; MPSGraphTensorData* valueData = [[MPSGraphTensorData alloc] initWithMTLBuffer:value_buffer - shape:valueShape + shape:valuePhysicalShape dataType:mps_dtype]; - ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraphTensorData objects for inputs"); + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Created MPSGraphTensorData objects with physical shapes"); feeds[queryPlaceholder] = queryData; feeds[keyPlaceholder] = keyData; @@ -1081,24 +1245,9 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( // Add explicit mask data to feeds if provided if (explicitMaskPlaceholder && attn_mask && *attn_mask) { - auto* mask_tensor = reinterpret_cast(*attn_mask); - void* mask_data_ptr = mask_tensor->mutable_data_ptr(); - - // Get or create Metal buffer for mask - id mask_buffer = nullptr; - auto mask_it = ptr_to_mtl_buffer.find(mask_data_ptr); - if (mask_it != ptr_to_mtl_buffer.end()) { - mask_buffer = mask_it->second; - } else { - size_t mask_size = mask_tensor->numel() * element_size; - mask_buffer = [device newBufferWithBytes:mask_data_ptr - length:mask_size - options:MTLResourceStorageModeShared]; - if (!mask_buffer) { - ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Failed to create Metal buffer for attention mask"); - throw std::runtime_error("Failed to create Metal buffer for attention mask"); - } - } + auto* mask_tensor = reinterpret_cast(*attn_mask); + // Get Metal buffer for mask + id mask_buffer = get_mtl_buffer(mask_tensor, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps", "mask"); NSMutableArray* maskShapeArray = [NSMutableArray array]; for (int i = 0; i < mask_tensor->dim(); i++) { @@ -1178,8 +1327,8 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( } // Mark that we own the memory for these tensors - auto* out_et_tensor = reinterpret_cast(out_tensor_handle); - auto* attn_et_tensor = reinterpret_cast(attn_tensor_handle); + auto* out_et_tensor = reinterpret_cast(out_tensor_handle); + auto* attn_et_tensor = reinterpret_cast(attn_tensor_handle); is_tensor_own_memory[out_et_tensor] = true; is_tensor_own_memory[attn_et_tensor] = true; @@ -1190,6 +1339,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: MPSGraph implementation completed successfully"); } + ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Executed successfully"); return Error::Ok; } catch (const std::exception& e) { From 71f87b691d65e222957b714e40e7d2e657075cb8 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 15 Oct 2025 15:32:04 -0400 Subject: [PATCH 14/16] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 107 ++++++++++++++++++ .../apple/metal/runtime/shims/et_metal.mm | 22 +++- 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h index c18ad513a3a..a1c8c684131 100644 --- a/backends/apple/metal/runtime/shims/et_metal.h +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -77,6 +77,35 @@ enum class SyncType { // ======================= // ETMetalShaderLibrary - ExecuTorch Metal shader library management // 
======================= + +/** + * @class ETMetalShaderLibrary + * @brief Manages Metal shader library compilation and kernel function retrieval. + * + * This class provides a high-level interface for compiling Metal shading language + * source code into a Metal library and creating compute pipeline states for + * kernel functions. It handles the creation and caching of Metal compute pipeline + * states and functions, which should be reused across multiple kernel dispatches. + * + * The class automatically compiles the provided shader source code upon construction + * and maintains an internal cache of compute pipeline states for different kernel + * functions to avoid redundant compilation. + * + * Example usage: + * @code + * std::string shaderSource = R"( + * #include + * using namespace metal; + * kernel void my_kernel(device float* data [[buffer(0)]], + * uint tid [[thread_position_in_grid]]) { + * data[tid] = data[tid] * 2.0; + * } + * )"; + * + * ETMetalShaderLibrary library(shaderSource); + * auto kernelFunction = library.getKernelFunction("my_kernel"); + * @endcode + */ class ETMetalShaderLibrary { public: ETMetalShaderLibrary(const std::string& source); @@ -103,6 +132,45 @@ class ETMetalShaderLibrary { // ======================= // ETMetalKernelFunction - ExecuTorch Metal kernel function execution // ======================= + +/** + * @class ETMetalKernelFunction + * @brief Represents a Metal compute kernel function ready for execution. + * + * This class encapsulates a Metal compute pipeline state and function, providing + * a high-level interface for setting kernel arguments and dispatching compute + * work to the GPU. It handles the encoding of compute commands and manages the + * interaction with Metal's compute command encoder. + * + * The class supports different dispatch patterns: + * - Single-dimension dispatch for linear workloads + * - Multi-dimensional dispatch for grid-based workloads + * - Custom thread group sizes for performance optimization + * + * Kernel arguments can be set using tensors (which will be mapped to Metal buffers) + * or scalar values. The class handles the encoding of these arguments + * into the compute command encoder. + * + * Example usage: + * @code + * // Get kernel function from library + * auto kernelFunction = library.getKernelFunction("vector_add"); + * + * // Start encoding commands + * kernelFunction->startEncoding(); + * + * // Set tensor arguments + * kernelFunction->setArg(0, inputTensorA); + * kernelFunction->setArg(1, inputTensorB); + * kernelFunction->setArg(2, outputTensor); + * + * // Set scalar argument + * kernelFunction->setArg(3, static_cast(numElements)); + * + * // Dispatch for linear workload + * kernelFunction->dispatchSingle(numElements); + * @endcode + */ class ETMetalKernelFunction { public: ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func); @@ -132,6 +200,45 @@ class ETMetalKernelFunction { // ======================= // ETMetalStream - Metal command buffer and synchronization management // ======================= + +/** + * @class ETMetalStream + * @brief Manages Metal compute command streams and provides GPU synchronization. + * + * This class serves as the central management hub for Metal GPU operations, providing + * a stream-based abstraction similar to CUDA streams. It handles command buffer lifecycle, + * compute command encoder management, and various synchronization patterns required for + * efficient GPU computation. 
+ * + * Key features: + * - Lazy command buffer and encoder creation for optimal resource usage + * - Thread-safe operations using serial dispatch queues + * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, COMMIT_AND_CONTINUE) + * - Kernel coalescing to batch multiple operations efficiently + * - MPSGraph integration for high-level neural network operations + * - Memory operations (copy, fill) with GPU acceleration via blit encoders + * + * The stream follows PyTorch's MPS stream design patterns, providing similar semantics + * for command buffer management and synchronization. + * + * Example usage: + * @code + * // Get current stream (typically the default stream) + * ETMetalStream* stream = getCurrentMetalStream(); + * + * // Execute kernel operations (handled automatically) + * auto kernelFunction = library.getKernelFunction("my_kernel"); + * kernelFunction->startEncoding(); + * kernelFunction->setArg(0, inputTensor); + * kernelFunction->dispatchSingle(numElements); + * + * // Synchronize to ensure completion + * stream->synchronize(SyncType::COMMIT_AND_WAIT); + * + * // Copy between GPU buffers using blit encoder + * stream->copy(srcBuffer, dstBuffer, numBytes, 0, 0, SyncType::COMMIT); + * @endcode + */ class ETMetalStream { public: ETMetalStream(); diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index 5afcf761d56..f76146ab783 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -743,6 +743,26 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev void ETMetalStream::copy(id srcBuffer, id dstBuffer, size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType) { + + if (length == 0) { + return; + + // Check that offsets are within buffer bounds before copying + if (!srcBuffer || !dstBuffer) { + ET_LOG(Error, "ETMetalStream::copy: Source or destination buffer is nil"); + return; + } + NSUInteger srcBufferLength = [srcBuffer length]; + NSUInteger dstBufferLength = [dstBuffer length]; + if (srcOffset + length > srcBufferLength) { + ET_LOG(Error, "ETMetalStream::copy: Source offset (%zu) + length (%zu) exceeds source buffer size (%zu)", srcOffset, length, srcBufferLength); + return; + } + if (dstOffset + length > dstBufferLength) { + ET_LOG(Error, "ETMetalStream::copy: Destination offset (%zu) + length (%zu) exceeds destination buffer size (%zu)", dstOffset, length, dstBufferLength); + return; + } + dispatch_sync(serialQueue_, ^{ @autoreleasepool { endKernelCoalescing(); @@ -792,8 +812,6 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev targetOperations:nil resultsDictionary:results executionDescriptor:nil]; - - //synchronize(syncType); } }); } From 95a70247fded8a432c69a08d64e9d891a2c8a2f4 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 15 Oct 2025 15:40:46 -0400 Subject: [PATCH 15/16] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 50 ++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h index a1c8c684131..75f79e5139c 100644 --- a/backends/apple/metal/runtime/shims/et_metal.h +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -80,16 +80,18 @@ enum class SyncType { /** * @class ETMetalShaderLibrary - * @brief Manages Metal shader library compilation and kernel function retrieval. 
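The per-class @code samples above can also be read as one end-to-end flow. The sketch below stitches them together under stated assumptions: the tensor argument type, the exact setArg overloads, and the shader source string are illustrative rather than taken from the patch.

```objective-c++
#include <string>

// Illustrative end-to-end use of ETMetalShaderLibrary, ETMetalKernelFunction
// and ETMetalStream, assembled from the class documentation above.
static const std::string kScaleSource = R"(
#include <metal_stdlib>
using namespace metal;
kernel void scale_by_two(device float* data [[buffer(0)]],
                         uint tid [[thread_position_in_grid]]) {
  data[tid] = data[tid] * 2.0;
}
)";

void scaleTensorByTwo(executorch::runtime::etensor::Tensor* dataTensor,
                      size_t numElements) {
  // Compile the shader source and fetch the compute pipeline for the kernel.
  ETMetalShaderLibrary library(kScaleSource);
  auto kernel = library.getKernelFunction("scale_by_two");

  // Encode arguments and dispatch a linear workload.
  kernel->startEncoding();
  kernel->setArg(0, dataTensor);   // mapped to the tensor's Metal buffer
  kernel->dispatchSingle(numElements);

  // Block until the GPU has finished so results are visible on the CPU.
  ETMetalStream* stream = getCurrentMetalStream();
  stream->synchronize(SyncType::COMMIT_AND_WAIT);
}
```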
+ * @brief Manages Metal shader library compilation and kernel function + * retrieval. * - * This class provides a high-level interface for compiling Metal shading language - * source code into a Metal library and creating compute pipeline states for - * kernel functions. It handles the creation and caching of Metal compute pipeline - * states and functions, which should be reused across multiple kernel dispatches. + * This class provides a high-level interface for compiling Metal shading + * language source code into a Metal library and creating compute pipeline + * states for kernel functions. It handles the creation and caching of Metal + * compute pipeline states and functions, which should be reused across multiple + * kernel dispatches. * - * The class automatically compiles the provided shader source code upon construction - * and maintains an internal cache of compute pipeline states for different kernel - * functions to avoid redundant compilation. + * The class automatically compiles the provided shader source code upon + * construction and maintains an internal cache of compute pipeline states for + * different kernel functions to avoid redundant compilation. * * Example usage: * @code @@ -137,18 +139,18 @@ class ETMetalShaderLibrary { * @class ETMetalKernelFunction * @brief Represents a Metal compute kernel function ready for execution. * - * This class encapsulates a Metal compute pipeline state and function, providing - * a high-level interface for setting kernel arguments and dispatching compute - * work to the GPU. It handles the encoding of compute commands and manages the - * interaction with Metal's compute command encoder. + * This class encapsulates a Metal compute pipeline state and function, + * providing a high-level interface for setting kernel arguments and dispatching + * compute work to the GPU. It handles the encoding of compute commands and + * manages the interaction with Metal's compute command encoder. * * The class supports different dispatch patterns: * - Single-dimension dispatch for linear workloads * - Multi-dimensional dispatch for grid-based workloads * - Custom thread group sizes for performance optimization * - * Kernel arguments can be set using tensors (which will be mapped to Metal buffers) - * or scalar values. The class handles the encoding of these arguments + * Kernel arguments can be set using tensors (which will be mapped to Metal + * buffers) or scalar values. The class handles the encoding of these arguments * into the compute command encoder. * * Example usage: @@ -203,23 +205,25 @@ class ETMetalKernelFunction { /** * @class ETMetalStream - * @brief Manages Metal compute command streams and provides GPU synchronization. + * @brief Manages Metal compute command streams and provides GPU + * synchronization. * - * This class serves as the central management hub for Metal GPU operations, providing - * a stream-based abstraction similar to CUDA streams. It handles command buffer lifecycle, - * compute command encoder management, and various synchronization patterns required for - * efficient GPU computation. + * This class serves as the central management hub for Metal GPU operations, + * providing a stream-based abstraction similar to CUDA streams. It handles + * command buffer lifecycle, compute command encoder management, and various + * synchronization patterns required for efficient GPU computation. 
* * Key features: * - Lazy command buffer and encoder creation for optimal resource usage * - Thread-safe operations using serial dispatch queues - * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, COMMIT_AND_CONTINUE) + * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, + * COMMIT_AND_CONTINUE, etc.) * - Kernel coalescing to batch multiple operations efficiently - * - MPSGraph integration for high-level neural network operations + * - MPSGraph integration for executing fall back operations (mm, conv, sdpa) * - Memory operations (copy, fill) with GPU acceleration via blit encoders * - * The stream follows PyTorch's MPS stream design patterns, providing similar semantics - * for command buffer management and synchronization. + * The stream follows PyTorch's MPS stream design patterns, providing similar + * semantics for command buffer management and synchronization. * * Example usage: * @code From d37e7efb19229e57e5dd314c5c06345bb68f4c59 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 15 Oct 2025 15:51:31 -0400 Subject: [PATCH 16/16] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.mm | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index f76146ab783..fdca0a28cf3 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -746,6 +746,7 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev if (length == 0) { return; + } // Check that offsets are within buffer bounds before copying if (!srcBuffer || !dstBuffer) {