From 64207122dbd838fa9d84a6b3f54768d45bf41fc0 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 13:29:55 -0400 Subject: [PATCH 01/14] Update [ghstack-poisoned] --- backends/aoti/aoti_model_container.cpp | 6 ++++++ backends/aoti/aoti_model_container.h | 16 ++++++++++++++++ backends/aoti/common_shims.cpp | 5 +++++ backends/aoti/common_shims.h | 3 +++ 4 files changed, 30 insertions(+) diff --git a/backends/aoti/aoti_model_container.cpp b/backends/aoti/aoti_model_container.cpp index 03be835a0c3..d1764451ab6 100644 --- a/backends/aoti/aoti_model_container.cpp +++ b/backends/aoti/aoti_model_container.cpp @@ -25,6 +25,12 @@ AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs = nullptr; AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr; +// Global function pointers needed by Metal backend +AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName = nullptr; +AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants = nullptr; + } // extern "C" } // namespace aoti diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h index 9b185327172..88d936d21ba 100644 --- a/backends/aoti/aoti_model_container.h +++ b/backends/aoti/aoti_model_container.h @@ -70,6 +70,22 @@ extern AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs; extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun; +// Function pointer types needed by Metal backend +using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** input_name); + +using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants); + +// Global function pointers needed by Metal backend +extern AOTInductorModelContainerGetInputNameFunc + AOTInductorModelContainerGetInputName; +extern AOTInductorModelContainerGetNumConstantsFunc + AOTInductorModelContainerGetNumConstants; + } // extern "C" // AOTI Delegate Handle structure diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index abc83779443..7802444e97e 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -145,6 +145,11 @@ void cleanup_tensor_metadata() { internal::tensor_to_strides.clear(); } +// Needed by Metal backend +size_t aoti_torch_dtype_element_size(int32_t dtype) { + return dtype_to_element_size(dtype); +} + } // extern "C" } // namespace aoti diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index 5f54cd1c878..97fcea1085c 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -68,6 +68,9 @@ void aoti_torch_grad_mode_set_enabled(bool enabled); // Cleanup functions for clearing global state void cleanup_tensor_metadata(); +// Needed by Metal backend +size_t aoti_torch_dtype_element_size(int32_t dtype); + } // extern "C" } // namespace aoti From d036c0713348e2482aea4d21405d30a51b629f76 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 16:08:15 -0400 Subject: [PATCH 02/14] Update [ghstack-poisoned] --- backends/apple/metal/metal_backend.py | 173 ++++++++++++++++++ backends/apple/metal/metal_partitioner.py | 77 ++++++++ .../metal/replace_slice_copy_with_slice.py | 118 ++++++++++++ backends/apple/metal/tests/__init__.py | 6 + .../apple/metal/tests/test_metal_backend.py | 80 ++++++++ 
.../metal/tests/test_metal_partitioner.py | 172 +++++++++++++++++ 6 files changed, 626 insertions(+) create mode 100644 backends/apple/metal/metal_backend.py create mode 100644 backends/apple/metal/metal_partitioner.py create mode 100644 backends/apple/metal/replace_slice_copy_with_slice.py create mode 100644 backends/apple/metal/tests/__init__.py create mode 100644 backends/apple/metal/tests/test_metal_backend.py create mode 100644 backends/apple/metal/tests/test_metal_partitioner.py diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py new file mode 100644 index 00000000000..782aa522084 --- /dev/null +++ b/backends/apple/metal/metal_backend.py @@ -0,0 +1,173 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import os +import typing +from enum import Enum + +from typing import Any, Dict, final, List, Optional, Set + +import torch +from executorch.backends.apple.metal.replace_slice_copy_with_slice import ( + ReplaceSliceCopyWithSlicePass, +) +from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_details import ( + BackendDetails, + ExportedProgram, + PreprocessResult, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch.export.passes import move_to_device_pass + + +# exist fallback operators in et namespace; +supported_fallback_kernels: Dict[str, Any] = { + "aoti_torch_mps_addmm_out": None, + "aoti_torch_mps_convolution": None, + "aoti_torch_mps_mm_out": None, + "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, +} + +# required fallback kernels but not supported +missing_fallback_kernels: Set[str] = set() + + +class COMPILE_SPEC_KEYS(Enum): + METHOD_NAME = "method_name" + + +# context manager for non-fallback guarantee +# it will raise exception when generating fallback kernels during aoti compile +@contextlib.contextmanager +def collect_unsupported_fallback_kernels(): + original_generate_c_shim_extern_kernel_call = ( + CppWrapperCpu.generate_c_shim_extern_kernel_call + ) + + def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + debug_handle: Optional[int] = None, + ): + if kernel not in supported_fallback_kernels: + missing_fallback_kernels.add(kernel) + + original_generate_c_shim_extern_kernel_call( + self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle + ) + + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels + ) + try: + yield + finally: + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + original_generate_c_shim_extern_kernel_call + ) + + +@final +@experimental( + "This API and all of Metal backend related functionality are experimental." 
+)
+class MetalBackend(BackendDetails):
+    @staticmethod
+    def preprocess(
+        edge_program: ExportedProgram,
+        compile_specs: List[CompileSpec],
+    ) -> PreprocessResult:
+        print("entering the lowerable parts in MetalBackend.preprocess....")
+        # Move the edge_program from CPU to MPS for aoti compile
+        mps_edge_program = move_to_device_pass(edge_program, "mps")
+
+        # replace slice_copy with slice
+        ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module)
+
+        edge_program_module = mps_edge_program.module()
+
+        # Grab all input placeholders from the graph
+        user_input_names = mps_edge_program.graph_signature.user_inputs
+        user_input_placeholders = []
+        for node in mps_edge_program.graph.nodes:
+            if node.op == "placeholder" and node.name in user_input_names:
+                user_input_placeholders.append(node.meta["val"])
+
+        # Base options for all devices
+        options: dict[str, typing.Any] = {
+            # Do not link against the full PyTorch/libtorch library
+            "aot_inductor.link_libtorch": False,
+            # Package model constants and other generated files directly in the shared object (.so) file
+            "aot_inductor.package_constants_in_so": True,
+            # Enable maximum automatic tuning for optimal performance
+            "max_autotune": True,
+            # "aot_inductor.debug_compile": True,
+            # "aot_inductor.force_mmap_weights": False,
+        }
+
+        with collect_unsupported_fallback_kernels():
+            so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
+            if len(missing_fallback_kernels) > 0:
+                formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
+                raise RuntimeError(
+                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
+                    "Please add them to the AOTI backend."
+                )
+
+        # pyre-ignorep[6]: Incompatible parameter type
+        with open(so_path, "rb") as f:
+            so_data = f.read()
+
+        named_data_store = NamedDataStore()
+        method_name = MetalBackend.method_name_from_compile_specs(compile_specs)
+        named_data_store.add_named_data(
+            method_name + "_so_blob", so_data, 1, "aoti_metal_blob"
+        )
+
+        # Clean up the generated so file; it has been packaged into the NamedDataStore
+        # pyre-ignorep[6]: Incompatible parameter type
+        os.remove(so_path)
+
+        return PreprocessResult(
+            processed_bytes=b"",
+            debug_handle_map={},
+            data_store_output=named_data_store.get_named_data_store_output(),
+        )
+
+    @staticmethod
+    def generate_method_name_compile_spec(
+        method_name: str,
+    ) -> CompileSpec:
+        """
+        Returns the compile spec that records the name of the method to be
+        lowered and executed by the Metal backend (e.g. "forward").
+        """
+        return CompileSpec(
+            COMPILE_SPEC_KEYS.METHOD_NAME.value,
+            method_name.encode("utf-8"),
+        )
+
+    @staticmethod
+    def method_name_from_compile_specs(
+        compile_specs: List[CompileSpec],
+    ) -> str:
+        """
+        Returns the method name from the compile specs.
+        """
+        for spec in compile_specs:
+            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
+                return spec.value.decode("utf-8")
+        raise RuntimeError(
+            f"Could not find method name in compile specs: {compile_specs}"
+        )
diff --git a/backends/apple/metal/metal_partitioner.py b/backends/apple/metal/metal_partitioner.py
new file mode 100644
index 00000000000..b103ac0f455
--- /dev/null
+++ b/backends/apple/metal/metal_partitioner.py
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Dict, final, List, Optional, Tuple
+
+import torch
+from executorch.backends.apple.metal.metal_backend import MetalBackend  # usort: skip
+from executorch.exir._warnings import experimental
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.partitioner import (
+    DelegationSpec,
+    Partitioner,
+    PartitionResult,
+)
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
+from torch.export.exported_program import ExportedProgram
+
+
+@final
+@experimental(
+    "This API and all of Metal backend related functionality are experimental."
+)
+class MetalPartitioner(Partitioner):
+    """
+    Metal partitioner for AOTInductor backend integration.
+
+    This partitioner creates a single partition containing all operators from the input graph.
+    It skips core ATen decomposition, allowing the Metal backend to handle decomposition using
+    AOTInductor's MPS-specific decomposition table.
+
+    Only operators that cannot be handled by the aoti-mps library will be excluded from
+    the partition and fall back to ExecuTorch's default or custom handling.
+    """
+
+    def __init__(self, compile_spec: List[CompileSpec]) -> None:
+        self.delegation_spec = DelegationSpec(MetalBackend.__name__, compile_spec)
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        """
+        Fully delegate the graph to AOTInductor by tagging all nodes as a single partition.
+        """
+
+        partition_tags: Dict[str, DelegationSpec] = {}
+        tag = "tag0"
+
+        for node in exported_program.graph.nodes:
+            if node.op != "call_function":
+                continue
+            node.meta["delegation_tag"] = tag
+
+        partition_tags[tag] = self.delegation_spec
+
+        tag_constant_data(exported_program)
+        tag_mutated_buffer(exported_program)
+
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
+
+    def ops_to_not_decompose(
+        self, ep: ExportedProgram
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        """
+        Return a list of operations that should not be decomposed, letting the AOT compiler handle them.
+        Currently we skip ATen decomposition for all ops, and let the Metal backend handle them.
+        """
+        do_not_decompose = set()
+
+        for node in ep.graph.nodes:
+            if node.op == "call_function" and isinstance(
+                node.target, torch._ops.OpOverload
+            ):
+                do_not_decompose.add(node.target)
+        return list(do_not_decompose), None
diff --git a/backends/apple/metal/replace_slice_copy_with_slice.py b/backends/apple/metal/replace_slice_copy_with_slice.py
new file mode 100644
index 00000000000..4f16759af35
--- /dev/null
+++ b/backends/apple/metal/replace_slice_copy_with_slice.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +# pyre-strict + +from typing import Dict, Iterable, Tuple + +import torch +from executorch.exir.dialects._ops import ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from torch import fx + + +_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = ( + torch.ops.aten.slice_copy.Tensor, + ops.edge.aten.slice_copy.Tensor, +) + +_SLICE_TARGETS: Dict[ + torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload +] = { + torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor, + ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor, +} + + +class ReplaceSliceCopyWithSlicePass(ExportPass): + """Replace non-mutated ``slice_copy`` results with ``slice`` views.""" + + def call(self, graph_module: fx.GraphModule) -> PassResult: + graph_changed = False + + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS: + continue + + if self._has_blocking_user(node, node.users.keys()): + continue + + node.target = _SLICE_TARGETS[node.target] + graph_changed = True + + if graph_changed: + graph_module.graph.lint() + graph_module.recompile() + + return PassResult(graph_module, graph_changed) + + def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool: + for user in users: + if self._is_mutating_user(node, user) or self._is_view_user(node, user): + return True + return False + + def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool: + if user.op == "call_method": + # Treat in-place tensor methods conservatively as mutations only when the + # method name ends with ``_`` which is the PyTorch convention for mutation. + return isinstance(user.target, str) and user.target.endswith("_") + + if user.op != "call_function": + return False + + target = user.target + if not hasattr(target, "_schema"): + return False + + schema = target._schema # pyre-ignore[16] + # Positional arguments + for index, arg in enumerate(user.args): + if arg is node and self._argument_mutates(schema, index): + return True + + # Keyword arguments + for name, arg in user.kwargs.items(): + if arg is node and self._argument_mutates(schema, name): + return True + + return False + + def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool: + if user.op == "call_method": + # Treat tensor methods conservatively and assume they may be view-producing. + return True + + if user.op != "call_function": + return False + + target = user.target + if getattr(target, "is_view", False): + for arg in user.args: + if arg is node: + return True + for arg in user.kwargs.values(): + if arg is node: + return True + + return False + + def _argument_mutates( + self, schema: torch._C.FunctionSchema, key: int | str + ) -> bool: + arguments = schema.arguments + if isinstance(key, int): + if key >= len(arguments): + return False + argument = arguments[key] + else: + argument = next((arg for arg in arguments if arg.name == key), None) + if argument is None: + return False + + alias_info = argument.alias_info + return bool(alias_info and alias_info.is_write) diff --git a/backends/apple/metal/tests/__init__.py b/backends/apple/metal/tests/__init__.py new file mode 100644 index 00000000000..fd6404c7f7b --- /dev/null +++ b/backends/apple/metal/tests/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + diff --git a/backends/apple/metal/tests/test_metal_backend.py b/backends/apple/metal/tests/test_metal_backend.py new file mode 100644 index 00000000000..26d2281c458 --- /dev/null +++ b/backends/apple/metal/tests/test_metal_backend.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from executorch.backends.apple.metal.metal_backend import ( + COMPILE_SPEC_KEYS, + MetalBackend, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec + + +class TestMetalBackend(unittest.TestCase): + """Test Metal backend utility functions.""" + + def test_generate_method_name_compile_spec(self): + """Test that compile spec is generated correctly with method name.""" + method_name = "forward" + compile_spec = MetalBackend.generate_method_name_compile_spec(method_name) + + # Verify compile spec structure + self.assertIsInstance(compile_spec, CompileSpec) + self.assertEqual(compile_spec.key, COMPILE_SPEC_KEYS.METHOD_NAME.value) + self.assertEqual(compile_spec.value, method_name.encode("utf-8")) + + def test_method_name_from_compile_specs(self): + """Test extracting method name from compile specs.""" + method_name = "forward" + compile_specs = [MetalBackend.generate_method_name_compile_spec(method_name)] + + # Extract method name + extracted_name = MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertEqual(extracted_name, method_name) + + def test_method_name_from_compile_specs_with_multiple_specs(self): + """Test extracting method name when there are multiple compile specs.""" + method_name = "forward" + compile_specs = [ + CompileSpec("other_key", b"other_value"), + MetalBackend.generate_method_name_compile_spec(method_name), + CompileSpec("another_key", b"another_value"), + ] + + # Extract method name + extracted_name = MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertEqual(extracted_name, method_name) + + def test_method_name_from_compile_specs_missing(self): + """Test that RuntimeError is raised when method name is missing.""" + compile_specs = [ + CompileSpec("other_key", b"other_value"), + ] + + # Should raise RuntimeError when method name is not found + with self.assertRaises(RuntimeError) as context: + MetalBackend.method_name_from_compile_specs(compile_specs) + + self.assertIn("Could not find method name", str(context.exception)) + + def test_compile_spec_roundtrip(self): + """Test that method name survives encode/decode roundtrip.""" + original_name = "my_custom_method" + + # Generate compile spec + compile_spec = MetalBackend.generate_method_name_compile_spec(original_name) + + # Extract from compile specs list + extracted_name = MetalBackend.method_name_from_compile_specs([compile_spec]) + + self.assertEqual(original_name, extracted_name) + + +if __name__ == "__main__": + unittest.main() + diff --git a/backends/apple/metal/tests/test_metal_partitioner.py b/backends/apple/metal/tests/test_metal_partitioner.py new file mode 100644 index 00000000000..97a073152f5 --- /dev/null +++ b/backends/apple/metal/tests/test_metal_partitioner.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Tuple + +import torch +from executorch.backends.apple.metal.metal_backend import MetalBackend +from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner +from executorch.exir.backend.partitioner import PartitionResult +from torch.export import export + + +class TestMetalPartitioner(unittest.TestCase): + """ + Test Metal partitioner functionality. + + After Metal partitioning, there should be exactly one partitioned graph that contains + all operators from the input graph. This means all operators should be tagged with + the same delegation tag, indicating they will all be executed by the Metal backend. + """ + + def _get_partition_result( + self, module: torch.nn.Module, inputs: Tuple[torch.Tensor, ...] + ) -> PartitionResult: + """Helper method to get partition result for a given module.""" + # Export the model + exported_program = export(module, inputs, strict=True) + + # Create partitioner with compile specs + compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")] + partitioner = MetalPartitioner(compile_specs) + + # Get partition result + partition_result = partitioner.partition(exported_program) + + # Verify partition result structure + self.assertIsNotNone(partition_result) + self.assertTrue(hasattr(partition_result, "tagged_exported_program")) + self.assertTrue(hasattr(partition_result, "partition_tags")) + + return partition_result + + def _check_fully_partitioned(self, partition_result: PartitionResult) -> bool: + """Check if the graph is fully partitioned (all operators have the same tag).""" + tagged_nodes = [] + untagged_ops = [] + + for node in partition_result.tagged_exported_program.graph.nodes: + if node.op == "call_function": + if hasattr(node, "meta") and "delegation_tag" in node.meta: + tagged_nodes.append(node) + else: + untagged_ops.append(node) + + # Check if we have any tagged nodes + if not tagged_nodes: + return False + + # Check if all tagged nodes have the same tag + first_tag = tagged_nodes[0].meta["delegation_tag"] + all_same_tag = all( + node.meta.get("delegation_tag") == first_tag for node in tagged_nodes + ) + + # Should have no untagged operations for full partitioning + fully_partitioned = len(untagged_ops) == 0 and all_same_tag + + return fully_partitioned + + def test_simple_add_partition(self): + """ + Test that Metal partitioner creates exactly one partition containing all operators. + Simple element-wise addition should result in a single graph with all ops tagged identically. + """ + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + # Create test inputs + x = torch.randn(2, 3) + y = torch.randn(2, 3) + + # Get partition result + partition_result = self._get_partition_result(AddModule(), (x, y)) + + # Verify it's fully partitioned + self.assertTrue( + self._check_fully_partitioned(partition_result), + "Expected all operations to be in a single partition", + ) + + # Verify exactly one partition tag exists + self.assertEqual( + len(partition_result.partition_tags), + 1, + "Expected exactly one partition tag for fully delegated graph", + ) + + def test_linear_partition(self): + """ + Test Metal partitioner with a linear layer. + All matrix operations should be in a single partition. 
+ """ + + class LinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + # Create test input + x = torch.randn(2, 10) + + # Get partition result + partition_result = self._get_partition_result(LinearModule(), (x,)) + + # Verify it's fully partitioned + self.assertTrue( + self._check_fully_partitioned(partition_result), + "Expected all operations to be in a single partition", + ) + + def test_ops_to_not_decompose(self): + """ + Test that ops_to_not_decompose returns all call_function ops. + Metal backend should handle decomposition via AOTInductor. + """ + + class SimpleModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.relu(x + 1.0) + + # Create test input + x = torch.randn(2, 3) + + # Export the model + exported_program = export(SimpleModule(), (x,), strict=True) + + # Create partitioner + compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")] + partitioner = MetalPartitioner(compile_specs) + + # Get ops to not decompose + ops_to_not_decompose, _ = partitioner.ops_to_not_decompose(exported_program) + + # Verify it returns a list + self.assertIsInstance(ops_to_not_decompose, list) + + # All call_function ops should be in the list + call_function_ops = [ + node.target + for node in exported_program.graph.nodes + if node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload) + ] + + self.assertEqual( + set(ops_to_not_decompose), + set(call_function_ops), + "ops_to_not_decompose should contain all call_function ops", + ) + + +if __name__ == "__main__": + unittest.main() + From 1a22c5e02f3a4d57bb3abf6b68c1279ec4bf58e4 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 16:21:06 -0400 Subject: [PATCH 03/14] Update [ghstack-poisoned] --- backends/apple/metal/tests/__init__.py | 1 - backends/apple/metal/tests/test_metal_backend.py | 1 - backends/apple/metal/tests/test_metal_partitioner.py | 4 ++-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/backends/apple/metal/tests/__init__.py b/backends/apple/metal/tests/__init__.py index fd6404c7f7b..2e41cd717f6 100644 --- a/backends/apple/metal/tests/__init__.py +++ b/backends/apple/metal/tests/__init__.py @@ -3,4 +3,3 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
- diff --git a/backends/apple/metal/tests/test_metal_backend.py b/backends/apple/metal/tests/test_metal_backend.py index 26d2281c458..5caf7a3adc6 100644 --- a/backends/apple/metal/tests/test_metal_backend.py +++ b/backends/apple/metal/tests/test_metal_backend.py @@ -77,4 +77,3 @@ def test_compile_spec_roundtrip(self): if __name__ == "__main__": unittest.main() - diff --git a/backends/apple/metal/tests/test_metal_partitioner.py b/backends/apple/metal/tests/test_metal_partitioner.py index 97a073152f5..1b29410ab6c 100644 --- a/backends/apple/metal/tests/test_metal_partitioner.py +++ b/backends/apple/metal/tests/test_metal_partitioner.py @@ -157,7 +157,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: call_function_ops = [ node.target for node in exported_program.graph.nodes - if node.op == "call_function" and isinstance(node.target, torch._ops.OpOverload) + if node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) ] self.assertEqual( @@ -169,4 +170,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": unittest.main() - From d6f0bc952a57a6ec14c5118fd1fe4da91d2d3194 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:30 -0400 Subject: [PATCH 04/14] Update [ghstack-poisoned] --- .../metal/runtime/shims/tensor_attribute.cpp | 38 ++++++++ .../metal/runtime/shims/tensor_attribute.h | 32 +++++++ backends/apple/metal/runtime/shims/types.h | 35 +++++++ backends/apple/metal/runtime/shims/utils.cpp | 93 +++++++++++++++++++ backends/apple/metal/runtime/shims/utils.h | 74 +++++++++++++++ 5 files changed, 272 insertions(+) create mode 100644 backends/apple/metal/runtime/shims/tensor_attribute.cpp create mode 100644 backends/apple/metal/runtime/shims/tensor_attribute.h create mode 100644 backends/apple/metal/runtime/shims/types.h create mode 100644 backends/apple/metal/runtime/shims/utils.cpp create mode 100644 backends/apple/metal/runtime/shims/utils.h diff --git a/backends/apple/metal/runtime/shims/tensor_attribute.cpp b/backends/apple/metal/runtime/shims/tensor_attribute.cpp new file mode 100644 index 00000000000..684e00ffe32 --- /dev/null +++ b/backends/apple/metal/runtime/shims/tensor_attribute.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Metal-specific device type constant +__attribute__((__visibility__("default"))) int32_t +aoti_torch_device_type_mps() { + // Let's use 2 for MPS + return 2; +} + +// Override aoti_torch_get_device_type to return MPS device type +AOTITorchError aoti_torch_get_device_type( + AOTITensorHandle tensor, + int32_t* ret_device_type) { + *ret_device_type = aoti_torch_device_type_mps(); + return Error::Ok; +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/tensor_attribute.h b/backends/apple/metal/runtime/shims/tensor_attribute.h new file mode 100644 index 00000000000..8d2a3dde361 --- /dev/null +++ b/backends/apple/metal/runtime/shims/tensor_attribute.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Metal-specific device type function +int32_t aoti_torch_device_type_mps(); + +// Override aoti_torch_get_device_type to return MPS device type +AOTITorchError aoti_torch_get_device_type( + AOTITensorHandle tensor, + int32_t* ret_device_type); + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/types.h b/backends/apple/metal/runtime/shims/types.h new file mode 100644 index 00000000000..07d377d7499 --- /dev/null +++ b/backends/apple/metal/runtime/shims/types.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Common using declarations for ExecutorTorch types +using executorch::runtime::Error; +using executorch::runtime::etensor::Tensor; + +extern "C" { + +// Common AOTI type aliases +// Note: AOTITensorHandle is aliased to Tensor* for ExecutorTorch compatibility +using AOTITensorHandle = Tensor*; +using AOTIRuntimeError = Error; +using AOTITorchError = Error; + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp new file mode 100644 index 00000000000..484158e9027 --- /dev/null +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Helper function to check if a dtype is supported in Metal backend +bool is_dtype_supported_in_et_metal(int32_t dtype) { + switch (dtype) { + case static_cast(SupportedDTypes::INT64): + case static_cast(SupportedDTypes::FLOAT32): + case static_cast(SupportedDTypes::BFLOAT16): + return true; + default: + return false; + } +} + +// Metal-specific dtype validation utility function +AOTITorchError validate_dtype(int32_t dtype) { + if (is_dtype_supported_in_et_metal(dtype)) { + return Error::Ok; + } + + ET_LOG( + Error, + "Unsupported dtype: %d. 
Supported dtypes: %d (int64), %d (float32), %d (bfloat16)",
+      dtype,
+      static_cast<int32_t>(SupportedDTypes::INT64),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16));
+  return Error::InvalidArgument;
+}
+
+} // extern "C"
+
+// Utility function to convert sizes pointer to vector
+std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr) {
+  std::vector<executorch::aten::SizesType> sizes(ndim);
+  for (int i = 0; i < ndim; i++) {
+    sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
+  }
+  return sizes;
+}
+
+// Utility function to convert strides pointer to vector or calculate from sizes
+std::vector<executorch::aten::StridesType> convert_strides_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr) {
+  std::vector<executorch::aten::StridesType> strides(ndim);
+
+  if (strides_ptr != nullptr) {
+    // Use the provided strides. It is OK if they are not contiguous strides,
+    // since they are only used internally by the Metal delegate.
+    for (int64_t i = 0; i < ndim; i++) {
+      strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
+    }
+  } else {
+    // Calculate strides from sizes using ExecuTorch's algorithm
+    if (ndim > 0) {
+      strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
+          1); // Last dimension has stride 1
+      for (int64_t i = ndim - 2; i >= 0; i--) {
+        if (sizes_ptr[i + 1] == 0) {
+          strides[i] = strides[i + 1]; // Copy stride when size is 0
+        } else {
+          strides[i] = static_cast<executorch::aten::StridesType>(
+              static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
+        }
+      }
+    }
+  }
+  return strides;
+}
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/utils.h b/backends/apple/metal/runtime/shims/utils.h
new file mode 100644
index 00000000000..5b9f9c5b3bb
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/utils.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Enum for supported data types in et-metal backend +enum class SupportedDTypes : int32_t { + // UINT8 = 0, // PyTorch's uint8 dtype code + // INT8 = 1, // PyTorch's int8 dtype code + // INT16 = 2, // PyTorch's int16 dtype code + // INT32 = 3, // PyTorch's int32 dtype code + INT64 = 4, // PyTorch's int64 dtype code + // FLOAT16 = 5, // PyTorch's float16 dtype code + FLOAT32 = 6, // PyTorch's float32 dtype code + // FLOAT64 = 7, // PyTorch's float64 dtype code + // BOOL = 11, // PyTorch's bool dtype code + BFLOAT16 = 15 // PyTorch's bfloat16 dtype code +}; + +extern "C" { + +// Helper function to check if a dtype is supported in Metal backend +bool is_dtype_supported_in_et_metal(int32_t dtype); + +// Metal-specific dtype validation utility function +AOTITorchError validate_dtype(int32_t dtype); + +} // extern "C" + +// Utility function to convert sizes pointer to vector +std::vector convert_sizes_to_vector( + int64_t ndim, + const int64_t* sizes_ptr); + +// Utility function to convert strides pointer to vector or calculate from sizes +std::vector convert_strides_to_vector( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr); + +// Check if tensor is in contiguous memory format (NCHW for 4D tensors) +// Contiguous format means strides decrease from left to right: +// For NCHW: strides = [C*H*W, H*W, W, 1] +inline bool is_contiguous_tensor( + std::vector sizes, + std::vector strides) { + int64_t ndim = static_cast(strides.size()); + int64_t expected_stride = 1; + for (int64_t i = ndim - 1; i >= 0; i--) { + if (strides[i] != expected_stride) { + return false; + } + expected_stride *= sizes[i]; + } + return true; +} + +} // namespace metal +} // namespace backends +} // namespace executorch From 7e11615aa43033df7f5988d5c1adb6d923c78297 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:34 -0400 Subject: [PATCH 05/14] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 271 ++++++ .../apple/metal/runtime/shims/et_metal.mm | 872 ++++++++++++++++++ 2 files changed, 1143 insertions(+) create mode 100644 backends/apple/metal/runtime/shims/et_metal.h create mode 100644 backends/apple/metal/runtime/shims/et_metal.mm diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h new file mode 100644 index 00000000000..c18ad513a3a --- /dev/null +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -0,0 +1,271 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#ifdef __OBJC__ +#import +#import +#include +// Forward declarations for MetalPerformanceShadersGraph types +@class MPSGraph; +@class MPSCommandBuffer; +// Metal type definitions for Objective-C compilation +typedef id MTLDevice_t; +typedef id MTLCommandQueue_t; +typedef id MTLCommandBuffer_t; +typedef id MTLComputeCommandEncoder_t; +typedef id MTLComputePipelineState_t; +typedef id MTLFunction_t; +typedef id MTLLibrary_t; +typedef id MTLBuffer_t; +typedef dispatch_queue_t dispatch_queue_t; +typedef MPSGraph* MPSGraph_t; +typedef MPSCommandBuffer* MPSCommandBuffer_t; +typedef NSDictionary* NSDictionary_t; +#else +// Forward declarations for C++ compilation +typedef void* MTLDevice_t; +typedef void* MTLCommandQueue_t; +typedef void* MTLCommandBuffer_t; +typedef void* MTLComputeCommandEncoder_t; +typedef void* MTLComputePipelineState_t; +typedef void* MTLFunction_t; +typedef void* MTLLibrary_t; +typedef void* MTLBuffer_t; +typedef void* dispatch_queue_t; +typedef void* MPSGraph_t; +typedef void* MPSCommandBuffer_t; +typedef void* NSDictionary_t; +#endif + +#include +#include +#include +#include +#include + +namespace executorch::runtime::etensor { +class Tensor; +} + +namespace executorch { +namespace backends { +namespace metal { + +// Forward declarations +class ETMetalKernelFunction; +class ETMetalStream; + +// ======================= +// SyncType - Metal synchronization options +// ======================= +enum class SyncType { + NONE, // no commit to command buffer + COMMIT, // commit and flush the command buffer + COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish + COMMIT_AND_CONTINUE, // commit and continue with a new underlying command + // buffer + COMMIT_ADAPTIVE, // commit adaptively based on available memory +}; + +// ======================= +// ETMetalShaderLibrary - ExecuTorch Metal shader library management +// ======================= +class ETMetalShaderLibrary { + public: + ETMetalShaderLibrary(const std::string& source); + ~ETMetalShaderLibrary(); + + std::shared_ptr getKernelFunction( + const std::string& name); + + private: + void compileLibrary(); + std::pair getLibraryPipelineState( + const std::string& functionName); + + friend class ETMetalKernelFunction; + + std::string shaderSource_; + MTLLibrary_t library_; + std::unordered_map< + std::string, + std::pair> + pipelineStates_; +}; + +// ======================= +// ETMetalKernelFunction - ExecuTorch Metal kernel function execution +// ======================= +class ETMetalKernelFunction { + public: + ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func); + ~ETMetalKernelFunction(); + + void startEncoding(); + void setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor); + void setArg(unsigned idx, int64_t val); + + void dispatchSingle(uint64_t length); + void dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size); + void dispatchArray(const uint64_t* length, size_t length_size); + void dispatchArrayWithGroupSize( + const uint64_t* length, + size_t length_size, + const uint64_t* group_size, + size_t group_size_size); + + void runCommandBlock(std::function f); + + private: + MTLComputePipelineState_t cps_; + MTLFunction_t func_; + MTLComputeCommandEncoder_t encoder_; +}; + +// ======================= +// ETMetalStream - Metal command buffer and synchronization management +// ======================= +class ETMetalStream { + public: + ETMetalStream(); + ~ETMetalStream(); + + // Get the default stream (singleton) + static 
ETMetalStream* getDefaultStream(); + + // Device and queue access + MTLDevice_t device() const { + return device_; + } + MTLCommandQueue_t commandQueue() const { + return commandQueue_; + } + dispatch_queue_t queue() const { + return serialQueue_; + } + + // Synchronization methods + void synchronize(SyncType syncType = SyncType::COMMIT_AND_WAIT); + void synchronize(); // Overload for backward compatibility + bool isEmpty() const; + + // Command buffer management with lazy creation + MPSCommandBuffer_t commandBuffer(); + MTLComputeCommandEncoder_t commandEncoder(); + + void endKernelCoalescing(); + + // MPSGraph execution + void executeMPSGraph( + MPSGraph_t mpsGraph, + NSDictionary_t feeds, + NSDictionary_t results, + SyncType syncType = SyncType::COMMIT_ADAPTIVE); + + // Command buffer lifecycle management + void commitCommandBuffer(MTLCommandBuffer_t commandBuffer); + void flush(); + + // Memory operations + void fill( + MTLBuffer_t buffer, + uint8_t value, + size_t length, + size_t offset, + SyncType syncType = SyncType::NONE); + void copy( + MTLBuffer_t srcBuffer, + MTLBuffer_t dstBuffer, + size_t length, + size_t srcOffset, + size_t dstOffset, + SyncType syncType = SyncType::NONE); + + private: + // Private synchronization methods + void commit(); + void commitAndWait(); + void commitAndContinue(); + + private: + // Private members + MTLDevice_t device_; + MTLCommandQueue_t commandQueue_; + MPSCommandBuffer_t commandBuffer_; + MPSCommandBuffer_t prevCommandBuffer_; // For commit-and-continue pattern + MTLComputeCommandEncoder_t commandEncoder_; + dispatch_queue_t serialQueue_; // For thread safety + + // Configuration + bool enableCommitAndContinue_; + + // Singleton instance + static ETMetalStream* defaultStream_; +}; + +// ======================= +// Global storage management functions +// ======================= +void storeFunctionHandle( + ETMetalKernelFunction* raw_function, + std::shared_ptr function_shared_ptr); +void storeLibraryHandle( + ETMetalShaderLibrary* raw_library, + std::unique_ptr library); +bool removeFunctionHandle(ETMetalKernelFunction* raw_function); +bool removeLibraryHandle(ETMetalShaderLibrary* raw_library); + +// ======================= +// Global stream access functions +// ======================= +ETMetalStream* getCurrentMetalStream(); +void setCurrentMetalStream(ETMetalStream* stream); + +// ======================= +// Metal stream synchronization functions (C++ interface with exceptions) +// ======================= +void synchronize_metal_stream(); +void synchronize_metal_stream_with_type(int sync_type); + +// ======================= +// Metal helper functions (C interface) +// ======================= +#ifdef __cplusplus +extern "C" { +#endif + +// Memory management functions for Metal +void* metal_allocate_buffer(long bytes); +bool metal_is_device_pointer(void* ptr); +int metal_copy_memory( + void* dst, + const void* src, + size_t nbytes, + bool src_is_device, + bool dst_is_device); +void metal_cleanup_resources(); + +// Helper functions to access Metal objects +MTLDevice_t get_metal_device(); +MTLCommandQueue_t get_metal_command_queue(); + +#ifdef __cplusplus +} + +// C++ only - expose the Metal buffer mapping +#ifdef __OBJC__ +extern std::unordered_map ptr_to_mtl_buffer; +#endif + +#endif + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm new file mode 100644 index 00000000000..5afcf761d56 --- /dev/null +++ 
b/backends/apple/metal/runtime/shims/et_metal.mm @@ -0,0 +1,872 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#import +#import +#import +#import +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// ======================= +// Exception-Safe Dispatch Function (similar to PyTorch MPS) +// ======================= + +void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) { + __block std::optional block_exception; + dispatch_sync(queue, ^() { + try { + block(); + } catch (...) { + block_exception = std::current_exception(); + } + }); + if (block_exception) { + std::rethrow_exception(*block_exception); + } +} + +// ======================= +// Global Variables and Storage +// ================ + + +// Global Metal buffer mapping - accessible for MPS shim +std::unordered_map> ptr_to_mtl_buffer; + +// Global storage to keep shared_ptr alive while raw pointers are used +static std::unordered_map> function_storage; +static std::unordered_map> library_storage; + +// Static singleton instance for default stream +ETMetalStream* ETMetalStream::defaultStream_ = nullptr; + +// Thread-local current stream +static thread_local ETMetalStream* currentStream_ = nullptr; + +// ======================= +// Metal Helper Functions (C Interface) +// ======================= + +extern "C" { + +void* metal_allocate_buffer(long bytes) { + ETMetalStream* stream = getCurrentMetalStream(); + id device = stream->device(); + if (!device) { + ET_LOG(Error, "Failed to get Metal device from stream"); + return nullptr; + } + + @autoreleasepool { + id buffer = [device newBufferWithLength:bytes options:MTLResourceStorageModeShared]; + if (!buffer) { + ET_LOG(Error, "Failed to allocate %ld bytes on Metal device", bytes); + return nullptr; + } + + void* ptr = [buffer contents]; + ptr_to_mtl_buffer[ptr] = buffer; + + ET_LOG(Debug, "Allocated %ld bytes on Metal device", bytes); + return ptr; + } +} + +void metal_cleanup_resources() { + if (!ptr_to_mtl_buffer.empty()) { + @autoreleasepool { + for (auto& pair : ptr_to_mtl_buffer) { + pair.second = nil; + } + ptr_to_mtl_buffer.clear(); + } + } +} + +bool metal_is_device_pointer(void* ptr) { + return ptr_to_mtl_buffer.find(ptr) != ptr_to_mtl_buffer.end(); +} + +int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_device, bool dst_is_device) { + if (!src || !dst || nbytes == 0) { + ET_LOG(Error, "Metal copy: Invalid parameters"); + return -1; + } + + @autoreleasepool { + // Case 1: Device-to-device copy - use GPU blit encoder (most efficient) + if (src_is_device && dst_is_device) { + auto src_it = ptr_to_mtl_buffer.find(const_cast(src)); + auto dst_it = ptr_to_mtl_buffer.find(dst); + + if (src_it != ptr_to_mtl_buffer.end() && dst_it != ptr_to_mtl_buffer.end()) { + id srcBuffer = src_it->second; + id dstBuffer = dst_it->second; + + // Calculate offsets relative to buffer base + size_t srcOffset = static_cast(src) - static_cast([srcBuffer contents]); + size_t dstOffset = static_cast(dst) - static_cast([dstBuffer contents]); + + // Use Metal's blit encoder for GPU-accelerated copy + ETMetalStream* stream = getCurrentMetalStream(); + stream->copy(srcBuffer, dstBuffer, nbytes, srcOffset, dstOffset, SyncType::NONE); + + ET_LOG(Debug, "Metal device-to-device copy (GPU blit): %zu bytes", nbytes); 
+ return 0; + } + + ET_LOG(Error, "Metal copy: Device pointers not found in buffer map"); + return -1; + } + + // Case 2: Host-to-device or device-to-host - use memcpy with shared memory + // Since Metal uses shared storage mode, CPU and GPU access the same memory + std::memcpy(dst, src, nbytes); + + // Synchronize only if we need to ensure GPU operations complete before CPU reads + // (device-to-host case where GPU may have written data) + if (src_is_device && !dst_is_device) { + // Ensure any pending GPU writes to source complete before CPU reads + ETMetalStream* stream = getCurrentMetalStream(); + stream->synchronize(SyncType::COMMIT_AND_WAIT); + } + + ET_LOG(Debug, "Metal memory copy (memcpy): %zu bytes, src_device=%d, dst_device=%d", + nbytes, src_is_device, dst_is_device); + } + + return 0; +} + +id get_metal_device() { + // Use stream-based device access + ETMetalStream* stream = getCurrentMetalStream(); + return stream->device(); +} + +id get_metal_command_queue() { + // Use stream-based queue access + ETMetalStream* stream = getCurrentMetalStream(); + return stream->commandQueue(); +} + +} // extern "C" + +// ======================= +// ETMetalShaderLibrary Implementation +// ======================= + +ETMetalShaderLibrary::ETMetalShaderLibrary(const std::string& source) : shaderSource_(source) { + compileLibrary(); +} + +ETMetalShaderLibrary::~ETMetalShaderLibrary() { + @autoreleasepool { + if (library_) { + [library_ release]; + library_ = nil; + } + + for (auto& pair : pipelineStates_) { + [pair.second.first release]; + [pair.second.second release]; + } + pipelineStates_.clear(); + } +} + +void ETMetalShaderLibrary::compileLibrary() { + @autoreleasepool { + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get Metal device"); + return; + } + + NSString* sourceString = [NSString stringWithUTF8String:shaderSource_.c_str()]; + NSError* error = nil; + + library_ = [device newLibraryWithSource:sourceString options:nil error:&error]; + if (!library_ || error) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to compile shader library: %s", + error ? [[error localizedDescription] UTF8String] : "unknown error"); + return; + } + + [library_ retain]; + ET_LOG(Debug, "ETMetalShaderLibrary: Successfully compiled shader library"); + } +} + +std::pair, id> ETMetalShaderLibrary::getLibraryPipelineState(const std::string& functionName) { + auto it = pipelineStates_.find(functionName); + if (it != pipelineStates_.end()) { + return it->second; + } + + @autoreleasepool { + if (!library_) { + ET_LOG(Error, "ETMetalShaderLibrary: Library not compiled"); + return {nil, nil}; + } + + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get Metal device"); + return {nil, nil}; + } + + NSString* funcName = [NSString stringWithUTF8String:functionName.c_str()]; + id function = [library_ newFunctionWithName:funcName]; + if (!function) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to get function '%s'", functionName.c_str()); + return {nil, nil}; + } + + NSError* error = nil; + id pipelineState = [device newComputePipelineStateWithFunction:function error:&error]; + if (!pipelineState || error) { + ET_LOG(Error, "ETMetalShaderLibrary: Failed to create pipeline state for '%s': %s", + functionName.c_str(), error ? 
[[error localizedDescription] UTF8String] : "unknown error"); + [function release]; + return {nil, nil}; + } + + [pipelineState retain]; + [function retain]; + pipelineStates_[functionName] = {pipelineState, function}; + + ET_LOG(Debug, "ETMetalShaderLibrary: Created pipeline state for function '%s'", functionName.c_str()); + return {pipelineState, function}; + } +} + +std::shared_ptr ETMetalShaderLibrary::getKernelFunction(const std::string& name) { + auto pipelineStatePair = getLibraryPipelineState(name); + if (!pipelineStatePair.first || !pipelineStatePair.second) { + ET_LOG(Error, "ETMetalShaderLibrary::getKernelFunction: Failed to get pipeline state for '%s'", name.c_str()); + return nullptr; + } + + return std::make_shared(pipelineStatePair.first, pipelineStatePair.second); +} + +// ======================= +// ETMetalKernelFunction Implementation +// ======================= + +ETMetalKernelFunction::ETMetalKernelFunction(id cps, id func) + : cps_(cps), func_(func), encoder_(nil) { + if (cps_) [cps_ retain]; + if (func_) [func_ retain]; +} + +ETMetalKernelFunction::~ETMetalKernelFunction() { + @autoreleasepool { + // Don't release encoder_ here - the stream owns it + // Only clean up our own references + if (cps_) { + [cps_ release]; + cps_ = nil; + } + if (func_) { + [func_ release]; + func_ = nil; + } + + encoder_ = nil; // Clear reference without releasing + } +} + +void ETMetalKernelFunction::startEncoding() { + @autoreleasepool { + // Don't retain/release the encoder - just get reference from stream + ETMetalStream* stream = getCurrentMetalStream(); + encoder_ = stream->commandEncoder(); // Use stream's managed encoder + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction: Failed to get encoder from stream"); + return; + } + + // Don't retain - stream owns the encoder + [encoder_ setComputePipelineState:cps_]; + + ET_LOG(Debug, "ETMetalKernelFunction: Started encoding with stream-managed encoder"); + } +} + +void ETMetalKernelFunction::setArg(unsigned idx, const executorch::runtime::etensor::Tensor& tensor) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + void* data_ptr = tensor.mutable_data_ptr(); + size_t totalSize = tensor.numel() * tensor.element_size(); + + auto it = ptr_to_mtl_buffer.find(data_ptr); + if (it != ptr_to_mtl_buffer.end()) { + // Use existing Metal buffer + id mtlBuffer = it->second; + [encoder_ setBuffer:mtlBuffer offset:0 atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set Metal buffer at index %u (size: %zu)", idx, totalSize); + } else { + // Handle CPU tensor data + if (totalSize <= 4096) { + // Use setBytes for small data (more efficient) + [encoder_ setBytes:data_ptr length:totalSize atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set CPU tensor via setBytes at index %u (size: %zu)", idx, totalSize); + } else { + // Create temporary buffer for large data (should be rare) + @autoreleasepool { + id device = get_metal_device(); + if (device) { + id tempBuffer = [device newBufferWithBytes:data_ptr + length:totalSize + options:MTLResourceStorageModeShared]; + if (tempBuffer) { + [encoder_ setBuffer:tempBuffer offset:0 atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set large CPU tensor via temporary buffer at index %u (size: %zu)", idx, totalSize); + } else { + ET_LOG(Error, "ETMetalKernelFunction::setArg: Failed to create temporary buffer for index %u", idx); + } + } else { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No Metal device available for index 
%u", idx); + } + } + } + } +} + +void ETMetalKernelFunction::setArg(unsigned idx, int64_t val) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder"); + return; + } + + [encoder_ setBytes:&val length:sizeof(int64_t) atIndex:idx]; + ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set int64_t value %lld at index %u", val, idx); +} + +void ETMetalKernelFunction::dispatchSingle(uint64_t length) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchSingle: No active encoder"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + uint64_t actualGroupSize = std::min(maxThreadsPerGroup, length); + + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchSingle: Dispatched with length %llu, group size %llu", length, actualGroupSize); + +} + +void ETMetalKernelFunction::dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchSingleWithGroupSize: No active encoder"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + uint64_t actualGroupSize = group_size > 0 ? std::min(group_size, maxThreadsPerGroup) : std::min(maxThreadsPerGroup, length); + + auto size = MTLSizeMake(length, 1, 1); + auto threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchSingleWithGroupSize: Dispatched with length %llu, group size %llu", length, actualGroupSize); + +} + +void ETMetalKernelFunction::dispatchArray(const uint64_t* length, size_t length_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArray: No active encoder"); + return; + } + + if (!length || length_size == 0) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArray: Invalid length array"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + + MTLSize size, threadGroupSize; + + if (length_size == 1) { + size = MTLSizeMake(length[0], 1, 1); + uint64_t actualGroupSize = std::min(maxThreadsPerGroup, length[0]); + threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + } else if (length_size == 2) { + size = MTLSizeMake(length[0], length[1], 1); + uint64_t groupX = std::min(static_cast(32), length[0]); + uint64_t groupY = maxThreadsPerGroup / groupX; + threadGroupSize = MTLSizeMake(groupX, groupY, 1); + } else { + size = MTLSizeMake(length[0], length[1], length_size > 2 ? 
length[2] : 1); + uint64_t groupX = std::min(static_cast(8), length[0]); + uint64_t groupY = std::min(static_cast(8), length[1]); + uint64_t groupZ = maxThreadsPerGroup / (groupX * groupY); + threadGroupSize = MTLSizeMake(groupX, groupY, groupZ); + } + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchArray: Dispatched %zuD with size [%lu, %lu, %lu], group [%lu, %lu, %lu]", + length_size, size.width, size.height, size.depth, + threadGroupSize.width, threadGroupSize.height, threadGroupSize.depth); + +} + +void ETMetalKernelFunction::dispatchArrayWithGroupSize(const uint64_t* length, size_t length_size, + const uint64_t* group_size, size_t group_size_size) { + if (!encoder_) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArrayWithGroupSize: No active encoder"); + return; + } + + if (!length || length_size == 0) { + ET_LOG(Error, "ETMetalKernelFunction::dispatchArrayWithGroupSize: Invalid length array"); + return; + } + + const auto maxThreadsPerGroup = static_cast([cps_ maxTotalThreadsPerThreadgroup]); + + MTLSize size, threadGroupSize; + + if (length_size == 1) { + size = MTLSizeMake(length[0], 1, 1); + uint64_t actualGroupSize = maxThreadsPerGroup; + if (group_size && group_size_size > 0) { + actualGroupSize = std::min(maxThreadsPerGroup, group_size[0]); + } + threadGroupSize = MTLSizeMake(actualGroupSize, 1, 1); + } else if (length_size == 2) { + size = MTLSizeMake(length[0], length[1], 1); + uint64_t groupX = std::min(static_cast(32), length[0]); + uint64_t groupY = maxThreadsPerGroup / groupX; + if (group_size && group_size_size >= 2) { + groupX = std::min(static_cast(group_size[0]), length[0]); + groupY = std::min(static_cast(group_size[1]), length[1]); + } + threadGroupSize = MTLSizeMake(groupX, groupY, 1); + } else { + size = MTLSizeMake(length[0], length[1], length_size > 2 ? length[2] : 1); + uint64_t groupX = std::min(static_cast(8), length[0]); + uint64_t groupY = std::min(static_cast(8), length[1]); + uint64_t groupZ = maxThreadsPerGroup / (groupX * groupY); + if (group_size && group_size_size >= 3) { + groupX = std::min(static_cast(group_size[0]), length[0]); + groupY = std::min(static_cast(group_size[1]), length[1]); + groupZ = std::min(static_cast(group_size[2]), length_size > 2 ? length[2] : 1); + } + threadGroupSize = MTLSizeMake(groupX, groupY, groupZ); + } + + [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; + ET_LOG(Debug, "ETMetalKernelFunction::dispatchArrayWithGroupSize: Dispatched %zuD with size [%lu, %lu, %lu], group [%lu, %lu, %lu]", + length_size, size.width, size.height, size.depth, + threadGroupSize.width, threadGroupSize.height, threadGroupSize.depth); + +} + +void ETMetalKernelFunction::runCommandBlock(std::function f) { + // Use dispatch_sync with the stream's serial queue for thread safety and synchronization + // This matches PyTorch's approach: dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ...) 
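// Note: dispatch_sync_with_rethrow is not defined in this hunk; it is assumed to mirror
// PyTorch's MPS helper of the same name, which runs the block synchronously on the serial
// queue and rethrows any C++ exception on the calling thread. A minimal sketch of such a
// helper (illustrative only, the actual signature in this codebase may differ):
//
//   static void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
//     __block std::exception_ptr eptr = nullptr;   // exception captured inside the block, if any
//     dispatch_sync(queue, ^{
//       try {
//         block();
//       } catch (...) {
//         eptr = std::current_exception();
//       }
//     });
//     if (eptr) {
//       std::rethrow_exception(eptr);              // surface the failure on the calling thread
//     }
//   }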
+ ETMetalStream* stream = getCurrentMetalStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + f(); + } + }); + + ET_LOG(Debug, "ETMetalKernelFunction::runCommandBlock: Executed command block with dispatch_sync"); +} + +// ======================= +// ETMetalStream Implementation +// ======================= + +ETMetalStream::ETMetalStream() + : device_(nil), commandQueue_(nil), commandBuffer_(nil), prevCommandBuffer_(nil), + commandEncoder_(nil), serialQueue_(nullptr), enableCommitAndContinue_(true) { + @autoreleasepool { + // Create device and command queue + device_ = MTLCreateSystemDefaultDevice(); + if (!device_) { + ET_LOG(Error, "ETMetalStream: Failed to create Metal device"); + return; + } + [device_ retain]; + + commandQueue_ = [device_ newCommandQueue]; + if (!commandQueue_) { + ET_LOG(Error, "ETMetalStream: Failed to create Metal command queue"); + return; + } + [commandQueue_ retain]; + + // Create serial queue for thread safety + serialQueue_ = dispatch_queue_create("metal gpu stream", nullptr); + + ET_LOG(Debug, "ETMetalStream: Created stream with device %p, queue %p", device_, commandQueue_); + } +} + +ETMetalStream::~ETMetalStream() { + @autoreleasepool { + // Synchronize before cleanup + synchronize(SyncType::COMMIT_AND_WAIT); + + // Clean up command encoder + if (commandEncoder_) { + [commandEncoder_ release]; + commandEncoder_ = nil; + } + + // Clean up command buffers + if (commandBuffer_) { + [commandBuffer_ release]; + commandBuffer_ = nil; + } + if (prevCommandBuffer_) { + [prevCommandBuffer_ release]; + prevCommandBuffer_ = nil; + } + + // Clean up command queue and device + if (commandQueue_) { + [commandQueue_ release]; + commandQueue_ = nil; + } + if (device_) { + [device_ release]; + device_ = nil; + } + + // Clean up serial queue + if (serialQueue_) { + dispatch_release(serialQueue_); + serialQueue_ = nullptr; + } + + ET_LOG(Debug, "ETMetalStream: Destroyed stream"); + } +} + +ETMetalStream* ETMetalStream::getDefaultStream() { + if (!defaultStream_) { + defaultStream_ = new ETMetalStream(); + } + return defaultStream_; +} + +// Lazy command buffer creation (use MPSCommandBuffer like PyTorch) +MPSCommandBuffer* ETMetalStream::commandBuffer() { + if (!commandBuffer_) { + if (!commandQueue_) { + ET_LOG(Error, "ETMetalStream::commandBuffer: No command queue available"); + return nil; + } + + commandBuffer_ = [MPSCommandBuffer commandBufferFromCommandQueue:commandQueue_]; + if (!commandBuffer_) { + ET_LOG(Error, "ETMetalStream::commandBuffer: Failed to create command buffer"); + return nil; + } + [commandBuffer_ retain]; + + ET_LOG(Debug, "ETMetalStream::commandBuffer: Created lazy command buffer %p", commandBuffer_); + } + + return commandBuffer_; +} + +// Lazy command encoder creation +id ETMetalStream::commandEncoder() { + if (!commandEncoder_) { + MPSCommandBuffer* cmdBuffer = commandBuffer(); + if (!cmdBuffer) { + ET_LOG(Error, "ETMetalStream::commandEncoder: Failed to get command buffer"); + return nil; + } + + commandEncoder_ = [cmdBuffer computeCommandEncoder]; + if (!commandEncoder_) { + ET_LOG(Error, "ETMetalStream::commandEncoder: Failed to create command encoder"); + return nil; + } + [commandEncoder_ retain]; + + ET_LOG(Debug, "ETMetalStream::commandEncoder: Created lazy command encoder %p", commandEncoder_); + } + + return commandEncoder_; +} + +// Synchronization with SyncType - matches PyTorch's approach (no dispatch_sync here) +void ETMetalStream::synchronize(SyncType syncType) { + endKernelCoalescing(); + + switch 
(syncType) { + case SyncType::NONE: + // Do nothing - no commit + break; + case SyncType::COMMIT: + commit(); + break; + case SyncType::COMMIT_AND_WAIT: + commitAndWait(); + break; + case SyncType::COMMIT_AND_CONTINUE: + if (enableCommitAndContinue_) { + commitAndContinue(); + } else { + ET_LOG(Error, "ETMetalStream::synchronize: CommitAndContinue requested but disabled"); + commit(); + } + break; + case SyncType::COMMIT_ADAPTIVE: + // Simple adaptive policy - could be enhanced with memory pressure detection + // TODO: Could add memory pressure detection like PyTorch does + commit(); + break; + } + + ET_LOG(Debug, "ETMetalStream::synchronize: Completed with SyncType %d", static_cast(syncType)); +} + +// Encoder coalescing management +void ETMetalStream::endKernelCoalescing() { + if (commandEncoder_) { + [commandEncoder_ endEncoding]; + [commandEncoder_ release]; + commandEncoder_ = nil; + ET_LOG(Debug, "ETMetalStream::endKernelCoalescing: Ended encoder coalescing"); + } +} + +// Commit methods +void ETMetalStream::commit() { + if (enableCommitAndContinue_ && commandBuffer_) { + // Use commit-and-continue for better performance + commitAndContinue(); + } else { + flush(); + } +} + +void ETMetalStream::commitAndWait() { + // Handle previous command buffer first + if (prevCommandBuffer_) { + [prevCommandBuffer_ waitUntilCompleted]; + [prevCommandBuffer_ release]; + prevCommandBuffer_ = nil; + } + + // Handle current command buffer + if (commandBuffer_) { + [commandBuffer_ commit]; + [commandBuffer_ waitUntilCompleted]; + [commandBuffer_ release]; + commandBuffer_ = nil; + } + + ET_LOG(Debug, "ETMetalStream::commitAndWait: Committed and waited for completion"); +} + +void ETMetalStream::commitAndContinue() { + if (!commandBuffer_) { + ET_LOG(Error, "ETMetalStream::commitAndContinue: No command buffer to commit"); + return; + } + + // Commit buffer and allow immediate reuse for better performance + [commandBuffer_ commit]; + ET_LOG(Debug, "ETMetalStream::commitAndContinue: Committed buffer %p with continue", commandBuffer_); + + // The buffer handles synchronization internally for commit-and-continue +} + +void ETMetalStream::flush() { + if (commandBuffer_) { + [commandBuffer_ commit]; + + if (!enableCommitAndContinue_) { + // Keep the command buffer for later waiting if commit-and-continue is disabled + prevCommandBuffer_ = commandBuffer_; + } else { + [commandBuffer_ release]; + } + commandBuffer_ = nil; + + ET_LOG(Debug, "ETMetalStream::flush: Flushed command buffer"); + } +} + +// Memory operations +void ETMetalStream::fill(id buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) { + if (length == 0) { + return; + } + + dispatch_sync(serialQueue_, ^{ + @autoreleasepool { + endKernelCoalescing(); + id blitEncoder = [commandBuffer() blitCommandEncoder]; + + [blitEncoder fillBuffer:buffer range:NSMakeRange(offset, length) value:value]; + [blitEncoder endEncoding]; + synchronize(syncType); + + ET_LOG(Debug, "ETMetalStream::fill: Filled buffer with value %u, length %zu, offset %zu", value, length, offset); + } + }); +} + +void ETMetalStream::copy(id srcBuffer, id dstBuffer, size_t length, + size_t srcOffset, size_t dstOffset, SyncType syncType) { + dispatch_sync(serialQueue_, ^{ + @autoreleasepool { + endKernelCoalescing(); + id blitEncoder = [commandBuffer() blitCommandEncoder]; + + // Handle large copies in chunks + constexpr size_t max_copy_size = 0x80000000; // 2GB + size_t bytes_copied = 0; + size_t bytes_remaining = length; + + while (bytes_remaining > 0) { + NSUInteger 
bytes_to_copy = std::min(max_copy_size, bytes_remaining); + [blitEncoder copyFromBuffer:srcBuffer + sourceOffset:(NSUInteger)srcOffset + bytes_copied + toBuffer:dstBuffer + destinationOffset:(NSUInteger)dstOffset + bytes_copied + size:bytes_to_copy]; + bytes_copied += bytes_to_copy; + bytes_remaining -= bytes_to_copy; + } + + [blitEncoder endEncoding]; + synchronize(syncType); + + ET_LOG(Debug, "ETMetalStream::copy: Copied %zu bytes from offset %zu to offset %zu", length, srcOffset, dstOffset); + } + }); +} + + +void ETMetalStream::synchronize() { + synchronize(SyncType::COMMIT_AND_WAIT); +} + +bool ETMetalStream::isEmpty() const { + return !commandBuffer_ && !commandEncoder_; +} + +void ETMetalStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) { + // Use dispatch_sync_with_rethrow exactly like PyTorch does for MPSGraph execution + dispatch_sync_with_rethrow(serialQueue_, ^() { + @autoreleasepool { + endKernelCoalescing(); + + [mpsGraph encodeToCommandBuffer:commandBuffer() + feeds:feeds + targetOperations:nil + resultsDictionary:results + executionDescriptor:nil]; + + //synchronize(syncType); + } + }); +} + +// ======================= +// Global Storage Management Functions +// ======================= + +void storeFunctionHandle(ETMetalKernelFunction* raw_function, std::shared_ptr function_shared_ptr) { + function_storage[raw_function] = function_shared_ptr; +} + +void storeLibraryHandle(ETMetalShaderLibrary* raw_library, std::unique_ptr library) { + library_storage[raw_library] = std::move(library); +} + +bool removeFunctionHandle(ETMetalKernelFunction* raw_function) { + auto it = function_storage.find(raw_function); + if (it != function_storage.end()) { + function_storage.erase(it); + return true; + } + return false; +} + +bool removeLibraryHandle(ETMetalShaderLibrary* raw_library) { + auto it = library_storage.find(raw_library); + if (it != library_storage.end()) { + library_storage.erase(it); + return true; + } + return false; +} + +// ======================= +// Global Stream Access Functions +// ======================= + +ETMetalStream* getCurrentMetalStream() { + if (!currentStream_) { + currentStream_ = ETMetalStream::getDefaultStream(); + } + return currentStream_; +} + +void setCurrentMetalStream(ETMetalStream* stream) { + currentStream_ = stream; +} + +// ======================= +// Metal Stream Synchronization Functions +// ======================= + +void synchronize_metal_stream() { + @autoreleasepool { + // Use the ETMetalStream for proper synchronization + ETMetalStream* stream = getCurrentMetalStream(); + stream->synchronize(SyncType::COMMIT_AND_WAIT); + + ET_LOG(Debug, "synchronize_metal_stream: Stream synchronized with COMMIT_AND_WAIT"); + } +} + +void synchronize_metal_stream_with_type(int sync_type) { + @autoreleasepool { + ETMetalStream* stream = getCurrentMetalStream(); + SyncType syncTypeEnum = static_cast(sync_type); + stream->synchronize(syncTypeEnum); + + ET_LOG(Debug, "synchronize_metal_stream_with_type: Stream synchronized with SyncType %d", sync_type); + } +} + +} // namespace metal +} // namespace backends +} // namespace executorch From dfa435ada3c8b2eacf0245eb55e3d946e894d75f Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:39 -0400 Subject: [PATCH 06/14] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/memory.cpp | 453 ++++++++++++++++++ backends/apple/metal/runtime/shims/memory.h | 73 +++ 2 files changed, 526 insertions(+) create mode 100644 
backends/apple/metal/runtime/shims/memory.cpp create mode 100644 backends/apple/metal/runtime/shims/memory.h diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp new file mode 100644 index 00000000000..2bda93e18a4 --- /dev/null +++ b/backends/apple/metal/runtime/shims/memory.cpp @@ -0,0 +1,453 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include // Ensure we have int64_t, int32_t definitions +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Import all from aoti namespace +using namespace executorch::backends::aoti; + +// Global storage for tensors and their metadata +std::unordered_set> tensors; +std::unordered_map is_tensor_own_memory; + +extern "C" { + +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size) { + ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: entered"); + + (void)device_type; + (void)opaque_metadata; + (void)layout; + (void)opaque_metadata_size; + + // Validate input parameters first + ET_CHECK_OR_RETURN_ERROR( + data != nullptr, + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null"); + + ET_CHECK_OR_RETURN_ERROR( + !(sizes_ptr == nullptr && ndim > 0), + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret_new_tensor != nullptr, + InvalidArgument, + "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null"); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + // Handle storage offset by adjusting the data pointer + void* adjusted_data = static_cast(data) + + (storage_offset * dtype_to_element_size(dtype)); + + ET_LOG( + Debug, + "aoti_torch_create_tensor_from_blob_v2: original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p", + data, + storage_offset, + dtype_to_element_size(dtype), + adjusted_data); + + // ETensor sizes + auto sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // ETensor strides + auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Log if the tensor is contiguous + if (is_contiguous_tensor(sizes, strides)) { + ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: contiguous tensor"); + } else { + ET_LOG( + Debug, "aoti_torch_create_tensor_from_blob_v2: non-contiguous tensor"); + } + + // ETensor creation + // Note: We're NOT copying the data, just wrapping it + auto tensor = executorch::extension::from_blob( + adjusted_data, sizes, strides, dtype_to_scalar_type(dtype)); + + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, InvalidArgument, "Failed to create tensor from blob"); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + *ret_new_tensor = tensor.get(); + is_tensor_own_memory[tensor.get()] = false; + + ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: successfull"); + return Error::Ok; +} + +AOTITorchError aoti_torch_empty_strided( + int64_t 
ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor) { + ET_LOG(Debug, "aoti_torch_empty_strided: entered"); + + // This requires us to reserve device memory and put it into a ETensor + void* ptr; + int64_t numel = 1; + for (int i = 0; i < ndim; i++) { + numel *= sizes_ptr[i]; + } + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + size_t element_size = dtype_to_element_size(dtype); + ET_CHECK_OR_RETURN_ERROR( + element_size != 0, + InvalidArgument, + "Invalid element size for dtype: %d", + dtype); + int64_t nbytes = numel * element_size; + + if (device_type == 2) { // Metal/MPS + ptr = metal_allocate_buffer(nbytes); + if (!ptr) { + ET_LOG(Error, "Failed to allocate %lld bytes on Metal device", nbytes); + return Error::MemoryAllocationFailed; + } + } else if (device_type == 0) { // cpu + // Ensure 16-byte alignment for CPU memory to match device requirements + int result = posix_memalign(&ptr, 16, nbytes); + ET_CHECK_OR_RETURN_ERROR( + result == 0, + MemoryAllocationFailed, + "Failed to allocate aligned CPU memory"); + ET_CHECK_OR_RETURN_ERROR( + ptr != nullptr, + MemoryAllocationFailed, + "Failed to call posix_memalign"); + ET_LOG(Debug, "Allocated %lld bytes on CPU", nbytes); + } else { + ET_CHECK_OR_RETURN_ERROR( + false, + NotImplemented, + "Need to implement empty_strided for non-CUDA non-CPU device type %d", + device_type); + } + + // ETensor sizes + auto sizes = convert_sizes_to_vector(ndim, sizes_ptr); + + // ETensor strides + auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Log if the tensor is contiguous + if (is_contiguous_tensor(sizes, strides)) { + ET_LOG(Debug, "aoti_torch_empty_strided: contiguous tensor"); + } else { + ET_LOG(Debug, "aoti_torch_empty_strided: non-contiguous tensor"); + } + + // ETensor creation + // Note: We're NOT copying the data, just wrapping it + executorch::aten::ScalarType scalar_type = dtype_to_scalar_type(dtype); + auto tensor = + executorch::extension::from_blob(ptr, sizes, strides, scalar_type); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + *ret_new_tensor = tensor.get(); + is_tensor_own_memory[tensor.get()] = true; + + ET_LOG(Debug, "aoti_torch_empty_strided: successfull"); + return Error::Ok; +} + +AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) { + ET_LOG(Debug, "aoti_torch_delete_tensor_object: entered"); + // Find tensor in the set + for (auto it = tensors.begin(); it != tensors.end(); ++it) { + if (it->get() == tensor) { + auto tensor_ptr = *it; + + // Check ownership before cleaning up + auto ownership_it = is_tensor_own_memory.find(tensor); + bool owns_memory = (ownership_it != is_tensor_own_memory.end()) + ? 
ownership_it->second + : false; + + // Clean up ownership metadata + is_tensor_own_memory.erase(tensor); + + if (owns_memory) { + // et tensor owns the memory; need to free it manually + void* data_ptr = tensor_ptr->mutable_data_ptr(); + + // Check if it's Metal GPU memory + if (metal_is_device_pointer(data_ptr)) { + // This is Metal GPU memory - the Metal helper will handle cleanup + // Metal buffers are automatically managed by ARC when the buffer is + // released + tensors.erase(it); + ET_LOG( + Debug, + "aoti_torch_delete_tensor_object: successfull (Metal GPU memory)"); + return Error::Ok; + } + + // This is CPU memory - free immediately + free(data_ptr); + } + // else: Don't free memory since the tensor doesn't own it + + // Remove from set (this will call the destructor if it's the last + // reference) + tensors.erase(it); + ET_LOG( + Debug, "aoti_torch_delete_tensor_object: successfull (CPU memory)"); + return Error::Ok; + } + } + ET_LOG(Error, "Didn't find tensor %p", tensor); + return Error::InvalidArgument; +} + +AOTITorchError aoti_torch_copy_( + AOTITensorHandle self, + AOTITensorHandle src, + int32_t non_blocking) { + ET_LOG(Debug, "aoti_torch_copy_: entered"); + + (void)non_blocking; + + // Check for null pointers first + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch_copy_ failed: self tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + src != nullptr, + InvalidArgument, + "aoti_torch_copy_ failed: src tensor is null"); + + // Get dtype information and validate compatibility + int32_t self_dtype, src_dtype; + aoti_torch_get_dtype(self, &self_dtype); + aoti_torch_get_dtype(src, &src_dtype); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(self_dtype)); + + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(src_dtype)); + + // Check dtype compatibility - both tensors must have the same dtype + ET_CHECK_OR_RETURN_ERROR( + self_dtype == src_dtype, + InvalidArgument, + "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes", + self_dtype, + src_dtype); + + // Check total number of elements compatibility (PyTorch copy_ behavior) + int64_t self_numel = self->numel(); + int64_t src_numel = src->numel(); + + ET_CHECK_OR_RETURN_ERROR( + self_numel == src_numel, + InvalidArgument, + "numel mismatch. 
self.numel()=%ld, src.numel()=%ld", + self_numel, + src_numel); + + // Get tensor metadata + int64_t* self_strides; + int64_t* src_strides; + aoti_torch_get_strides(self, &self_strides); + aoti_torch_get_strides(src, &src_strides); + + int64_t* self_sizes; + int64_t* src_sizes; + aoti_torch_get_sizes(self, &self_sizes); + aoti_torch_get_sizes(src, &src_sizes); + + // Determine device locations + bool srcIsDevice = false; + bool dstIsDevice = false; + + // Check if pointers are Metal device pointers + if (!srcIsDevice) { + srcIsDevice = metal_is_device_pointer(const_cast(src->data_ptr())); + } + if (!dstIsDevice) { + dstIsDevice = metal_is_device_pointer(self->mutable_data_ptr()); + } + + // Check if tensors have the same schema (sizes, strides, dtype) for fast path + // TODO: This should be improved to catch cases like (4, 1, 5) -> (4, 5) + bool same_schema = true; + for (int i = 0; i < self->dim(); i++) { + if (self_strides[i] != src_strides[i]) { + same_schema = false; + break; + } + } + + size_t total_bytes = src->nbytes(); + int64_t total_elements = self->numel(); + + if (same_schema) { + int result = metal_copy_memory( + self->mutable_data_ptr(), + src->data_ptr(), + total_bytes, + srcIsDevice, + dstIsDevice); + if (result != 0) { + ET_LOG(Error, "metal_copy_memory failed with status %d", result); + return Error::Internal; + } + } else { + ET_LOG(Error, "Layout conversion not supported"); + return Error::NotImplemented; + } + + ET_LOG(Debug, "aoti_torch_copy_: successfull"); + return Error::Ok; +} + +AOTITorchError aoti_torch__reinterpret_tensor( + AOTITensorHandle self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + AOTITensorHandle* ret_new_tensor) { + ET_LOG(Debug, "aoti_torch__reinterpret_tensor: entered"); + + // Validate input parameters first + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: self tensor is null"); + + ET_CHECK_OR_RETURN_ERROR( + !(sizes_ptr == nullptr && ndim > 0), + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: sizes_ptr is null"); + + ET_CHECK_OR_RETURN_ERROR( + ret_new_tensor != nullptr, + InvalidArgument, + "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null"); + + // Get the dtype from the source tensor + int32_t dtype = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(self, &dtype)); + + // Validate dtype using SupportedDTypes + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + int32_t device_type = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type)); + + int32_t device_index = 0; + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index)); + + // Get the base data pointer from the source tensor + void* base_data_ptr = self->mutable_data_ptr(); + ET_CHECK_OR_RETURN_ERROR( + base_data_ptr != nullptr, + InvalidArgument, + "Source tensor has null data pointer"); + + // Calculate new tensor size in elements for logging + int64_t new_numel = 1; + for (int64_t i = 0; i < ndim; i++) { + new_numel *= sizes_ptr[i]; + } + + ET_LOG( + Debug, + "aoti_torch__reinterpret_tensor: base_data_ptr=%p, new_numel=%lld, storage_offset=%lld", + base_data_ptr, + new_numel, + storage_offset); + + // Create a new tensor view that shares the same underlying storage + // This is the correct way to implement reinterpret_tensor - as a view, not a + // copy + AOTITorchError create_err = aoti_torch_create_tensor_from_blob_v2( + base_data_ptr, // Same underlying data pointer + ndim, // New 
dimensions + sizes_ptr, // New sizes + strides_ptr, // New strides + storage_offset, // Storage offset (will be handled properly now) + dtype, + device_type, + device_index, + ret_new_tensor, + 0, // layout (default) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_err != Error::Ok) { + ET_LOG(Error, "failed to create reinterpreted tensor view"); + return create_err; + } + + ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull"); + return Error::Ok; +} + +// Cleanup function for clearing global state +void cleanup_memory() { + is_tensor_own_memory.clear(); + if (!tensors.empty()) { + ET_LOG(Error, "Warning: tensors not empty during cleanup"); + } + + // Clean up Metal resources + metal_cleanup_resources(); +} + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/memory.h b/backends/apple/metal/runtime/shims/memory.h new file mode 100644 index 00000000000..47fb6352b50 --- /dev/null +++ b/backends/apple/metal/runtime/shims/memory.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +extern "C" { + +// Global storage declarations +extern std::unordered_map is_tensor_own_memory; +extern std::unordered_set> tensors; + +// Memory-related operations +AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void* data, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor, + int32_t layout, + const uint8_t* opaque_metadata, + int64_t opaque_metadata_size); + +AOTITorchError aoti_torch_empty_strided( + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int32_t dtype, + int32_t device_type, + int32_t device_index, + AOTITensorHandle* ret_new_tensor); + +AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor); + +AOTITorchError aoti_torch_copy_( + AOTITensorHandle self, + AOTITensorHandle src, + int32_t non_blocking); + +AOTITorchError aoti_torch__reinterpret_tensor( + AOTITensorHandle self, + int64_t ndim, + const int64_t* sizes_ptr, + const int64_t* strides_ptr, + int64_t storage_offset, + AOTITensorHandle* ret_new_tensor); + +void cleanup_memory(); + +} // extern "C" + +} // namespace metal +} // namespace backends +} // namespace executorch From 648ee077cce4f779c3a7353e43642b680c962f94 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 10 Oct 2025 17:01:44 -0400 Subject: [PATCH 07/14] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/shim_mps.h | 118 ++++ .../apple/metal/runtime/shims/shim_mps.mm | 540 ++++++++++++++++++ 2 files changed, 658 insertions(+) create mode 100644 backends/apple/metal/runtime/shims/shim_mps.h create mode 100644 backends/apple/metal/runtime/shims/shim_mps.mm diff --git a/backends/apple/metal/runtime/shims/shim_mps.h b/backends/apple/metal/runtime/shims/shim_mps.h new file mode 100644 index 00000000000..94611b016ae --- /dev/null +++ b/backends/apple/metal/runtime/shims/shim_mps.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace backends { +namespace metal { + +struct AOTIMetalKernelFunctionOpaque; +using AOTIMetalKernelFunctionHandle = AOTIMetalKernelFunctionOpaque*; + +struct AOTIMetalShaderLibraryOpaque; +using AOTIMetalShaderLibraryHandle = AOTIMetalShaderLibraryOpaque*; + +#ifdef __cplusplus +extern "C" { +#endif + +// MetalShaderLibrary functions +AOTITorchError aoti_torch_mps_create_shader_library( + const char* metal_shader_source, + AOTIMetalShaderLibraryHandle* library_handle); + +AOTITorchError aoti_torch_mps_delete_shader_library( + AOTIMetalShaderLibraryHandle library_handle); + +AOTITorchError aoti_torch_mps_get_kernel_function( + AOTIMetalShaderLibraryHandle library_handle, + const char* kernel_name, + AOTIMetalKernelFunctionHandle* function_handle); + +// MetalKernelFunction functions +AOTITorchError aoti_torch_mps_start_encoding( + AOTIMetalKernelFunctionHandle func); + +AOTITorchError aoti_torch_mps_set_arg_tensor( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + AOTITensorHandle tensor); + +AOTITorchError aoti_torch_mps_set_arg_int( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + int64_t val); + +// Pure C dispatch functions - single value versions +AOTITorchError aoti_torch_mps_dispatch_single( + AOTIMetalKernelFunctionHandle func, + uint64_t length); + +AOTITorchError aoti_torch_mps_dispatch_single_with_group_size( + AOTIMetalKernelFunctionHandle func, + uint64_t length, + uint64_t group_size); + +// Pure C dispatch functions - array versions +AOTITorchError aoti_torch_mps_dispatch_array( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size); + +AOTITorchError aoti_torch_mps_dispatch_array_with_group_size( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size, + const uint64_t* group_size, + size_t group_size_size); + +// Memory management functions +AOTITorchError aoti_torch_mps_malloc(void** buffer, size_t num_bytes); + +AOTITorchError aoti_torch_mps_free(void* ptr); + +AOTITorchError aoti_torch_mps_memcpy( + void* buffer, + size_t constant_offset, + size_t bytes_read, + size_t data_size, + uint8_t* constants_start); + +AOTITorchError aoti_torch_mps_copy_buffer( + void* src_buffer, + void* dst_buffer, + size_t data_size, + size_t src_offset, + size_t dst_offset); + +// C callback function type for command block execution +typedef void (*aoti_torch_mps_command_block_callback_t)( + AOTIMetalKernelFunctionHandle func, + void* user_data); + +// Shared callback function for std::function trampoline +void aoti_torch_mps_shared_callback( + AOTIMetalKernelFunctionHandle func, + void* user_data); + +// Pure C version using function pointer and user data for trampoline pattern +AOTITorchError aoti_torch_mps_run_command_block( + AOTIMetalKernelFunctionHandle func, + aoti_torch_mps_command_block_callback_t callback, + void* user_data); + +#ifdef __cplusplus +} // extern "C" +#endif + +} // namespace metal +} // namespace backends +} // namespace executorch diff --git a/backends/apple/metal/runtime/shims/shim_mps.mm b/backends/apple/metal/runtime/shims/shim_mps.mm new file mode 100644 index 00000000000..e5e7d8c0dc9 --- /dev/null +++ b/backends/apple/metal/runtime/shims/shim_mps.mm @@ -0,0 +1,540 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#import +#import +#import +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace metal { + +// Declare the global mapping from et_metal.mm +extern std::unordered_map> ptr_to_mtl_buffer; + +extern "C" { + +// MetalShaderLibrary functions +AOTITorchError aoti_torch_mps_create_shader_library( + const char* metal_shader_source, + AOTIMetalShaderLibraryHandle* library_handle) { + + if (!metal_shader_source || !library_handle) { + ET_LOG(Error, "aoti_torch_mps_create_shader_library: null arguments"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto library = std::make_unique(std::string(metal_shader_source)); + auto* raw_library = library.get(); + + // Store the unique_ptr to keep the object alive + storeLibraryHandle(raw_library, std::move(library)); + + // Return raw pointer to match existing API + *library_handle = reinterpret_cast(raw_library); + + ET_LOG(Debug, "aoti_torch_mps_create_shader_library: Created shader library %p", raw_library); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_create_shader_library exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_create_shader_library: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_delete_shader_library( + AOTIMetalShaderLibraryHandle library_handle) { + + if (!library_handle) { + ET_LOG(Error, "aoti_torch_mps_delete_shader_library: null library handle"); + return Error::InvalidArgument; + } + + try { + auto* library = reinterpret_cast(library_handle); + if (removeLibraryHandle(library)) { + ET_LOG(Debug, "aoti_torch_mps_delete_shader_library: Deleted shader library %p", library); + } else { + ET_LOG(Error, "aoti_torch_mps_delete_shader_library: Library not found in storage"); + return Error::InvalidArgument; + } + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_delete_shader_library exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_delete_shader_library: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_get_kernel_function( + AOTIMetalShaderLibraryHandle library_handle, + const char* kernel_name, + AOTIMetalKernelFunctionHandle* function_handle) { + + if (!library_handle || !kernel_name || !function_handle) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function: null arguments"); + return Error::InvalidArgument; + } + + try { + auto* library = reinterpret_cast(library_handle); + auto function_shared_ptr = library->getKernelFunction(std::string(kernel_name)); + if (!function_shared_ptr) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function: Failed to get kernel function '%s'", kernel_name); + return Error::Internal; + } + + auto* raw_function = function_shared_ptr.get(); + + // Store the shared_ptr to keep the object alive + storeFunctionHandle(raw_function, function_shared_ptr); + + // Return raw pointer to match existing API + *function_handle = reinterpret_cast(raw_function); + + ET_LOG(Debug, "aoti_torch_mps_get_kernel_function: Got kernel function '%s' -> %p", kernel_name, raw_function); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_get_kernel_function: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_start_encoding( + AOTIMetalKernelFunctionHandle func) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_start_encoding: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->startEncoding(); + + ET_LOG(Debug, "aoti_torch_mps_start_encoding: Started encoding for function %p", function); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_start_encoding exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_start_encoding: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_set_arg_tensor( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + AOTITensorHandle tensor) { + + if (!func || !tensor) { + ET_LOG(Error, "aoti_torch_mps_set_arg_tensor: null function handle or tensor"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto* function = reinterpret_cast(func); + auto* et_tensor = reinterpret_cast(tensor); + + function->setArg(idx, *et_tensor); + + ET_LOG(Debug, "aoti_torch_mps_set_arg_tensor: Set tensor argument at index %u", idx); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_set_arg_tensor exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_set_arg_tensor: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_set_arg_int( + AOTIMetalKernelFunctionHandle func, + unsigned idx, + int64_t val) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_set_arg_int: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->setArg(idx, val); + + ET_LOG(Debug, "aoti_torch_mps_set_arg_int: Set int64_t value %lld at index %u", val, idx); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_set_arg_int exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_set_arg_int: unknown exception"); + return Error::Internal; + } +} + +// Pure C dispatch functions - single value versions +AOTITorchError aoti_torch_mps_dispatch_single( + AOTIMetalKernelFunctionHandle func, + uint64_t length) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchSingle(length); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_single: Dispatched function %p with length %llu", function, length); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_dispatch_single_with_group_size( + AOTIMetalKernelFunctionHandle func, + uint64_t length, + uint64_t group_size) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single_with_group_size: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchSingleWithGroupSize(length, group_size); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_single_with_group_size: Dispatched function %p with length %llu, group size %llu", function, length, group_size); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single_with_group_size exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_dispatch_single_with_group_size: unknown exception"); + return Error::Internal; + } +} + +// Pure C dispatch functions - array versions +AOTITorchError aoti_torch_mps_dispatch_array( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchArray(length, length_size); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_array: Dispatched function %p with %zu dimensions", function, length_size); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_dispatch_array_with_group_size( + AOTIMetalKernelFunctionHandle func, + const uint64_t* length, + size_t length_size, + const uint64_t* group_size, + size_t group_size_size) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size: null function handle"); + return Error::InvalidArgument; + } + + try { + auto* function = reinterpret_cast(func); + function->dispatchArrayWithGroupSize(length, length_size, group_size, group_size_size); + + ET_LOG(Debug, "aoti_torch_mps_dispatch_array_with_group_size: Dispatched function %p with %zu dimensions", function, length_size); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_dispatch_array_with_group_size: unknown exception"); + return Error::Internal; + } +} + +AOTITorchError aoti_torch_mps_malloc(void** buffer, size_t num_bytes) { + if (num_bytes == 0) { + *buffer = nullptr; + return Error::Ok; + } + + if (!buffer) { + ET_LOG(Error, "aoti_torch_mps_malloc: null buffer pointer"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_malloc: Failed to get Metal device"); + return Error::Internal; + } + + id metal_buffer = [device newBufferWithLength:num_bytes + options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared]; + if (!metal_buffer) { + ET_LOG(Error, "aoti_torch_mps_malloc: Failed to allocate Metal buffer of size %zu", num_bytes); + return Error::Internal; + } + + // FIX: Return contents pointer, not buffer object + void* contents_ptr = [metal_buffer contents]; + ptr_to_mtl_buffer[contents_ptr] = metal_buffer; // Map contents to buffer + *buffer = contents_ptr; // Return contents pointer + + ET_LOG(Debug, "aoti_torch_mps_malloc: Allocated Metal buffer %p with contents %p of size %zu", + metal_buffer, contents_ptr, num_bytes); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_malloc exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_malloc: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_free(void* ptr) { + if (!ptr) { + return Error::Ok; // Nothing to free + } + + @autoreleasepool { + try { + // FIX: ptr is now the contents pointer, not the buffer object + // Look up the buffer from the mapping and clean up + auto it = ptr_to_mtl_buffer.find(ptr); + if (it != ptr_to_mtl_buffer.end()) { + id metal_buffer = it->second; + [metal_buffer release]; + ptr_to_mtl_buffer.erase(it); + ET_LOG(Debug, "aoti_torch_mps_free: Freed Metal buffer for contents %p", ptr); + } else { + ET_LOG(Error, "aoti_torch_mps_free: Buffer not found for contents pointer %p", ptr); + return Error::InvalidArgument; + } + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_free exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_free: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_memcpy( + void* buffer, + size_t constant_offset, + size_t bytes_read, + size_t data_size, + uint8_t* constants_start) { + + if (!buffer || !constants_start) { + ET_LOG(Error, "aoti_torch_mps_memcpy: null buffer or constants_start"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // FIX: buffer is now the contents pointer, not the buffer object + auto buffer_pointer = static_cast(buffer); + + memcpy(buffer_pointer + constant_offset, constants_start + bytes_read, data_size); + + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_memcpy: Failed to get Metal device"); + return Error::Internal; + } + id subBuffer = [device newBufferWithBytesNoCopy:buffer_pointer + constant_offset + length:data_size + options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared + deallocator:nil]; + + if (constant_offset != 0) { + ptr_to_mtl_buffer[buffer_pointer + constant_offset] = subBuffer; // Map contents to buffer + } + + ET_LOG(Debug, "aoti_torch_mps_memcpy: Copied %zu bytes from offset %zu to buffer offset %zu", + data_size, bytes_read, constant_offset); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_memcpy exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_memcpy: unknown exception"); + return Error::Internal; + } + } +} + +AOTITorchError aoti_torch_mps_copy_buffer( + void* src_buffer, + void* dst_buffer, + size_t data_size, + size_t src_offset, + size_t dst_offset) { + + if (!src_buffer || !dst_buffer) { + ET_LOG(Error, "aoti_torch_mps_copy_buffer: null buffer"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + auto src_mtl_buffer = (id)src_buffer; + auto dst_mtl_buffer = (id)dst_buffer; + + uint8_t* src_contents = static_cast([src_mtl_buffer contents]); + uint8_t* dst_contents = static_cast([dst_mtl_buffer contents]); + + if (!src_contents || !dst_contents) { + ET_LOG(Error, "aoti_torch_mps_copy_buffer: Failed to get buffer contents"); + return Error::Internal; + } + + memcpy(dst_contents + dst_offset, src_contents + src_offset, data_size); + + ET_LOG(Debug, "aoti_torch_mps_copy_buffer: Copied %zu bytes from src+%zu to dst+%zu", + data_size, src_offset, dst_offset); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_copy_buffer exception: %s", e.what()); + return Error::Internal; + } catch (...) 
{ + ET_LOG(Error, "aoti_torch_mps_copy_buffer: unknown exception"); + return Error::Internal; + } + } +} + +// Shared callback function for std::function trampoline +void aoti_torch_mps_shared_callback( + AOTIMetalKernelFunctionHandle func, + void* user_data) { + ET_LOG(Debug, "aoti_torch_mps_shared_callback: Called with func=%p, user_data=%p", func, user_data); + + auto* function_wrapper = static_cast*>(user_data); + if (function_wrapper) { + ET_LOG(Debug, "aoti_torch_mps_shared_callback: Calling function wrapper"); + (*function_wrapper)(func); + ET_LOG(Debug, "aoti_torch_mps_shared_callback: Function wrapper completed"); + } else { + ET_LOG(Error, "aoti_torch_mps_shared_callback: null function wrapper"); + } +} + +// Pure C version using function pointer and user data for trampoline pattern +AOTITorchError aoti_torch_mps_run_command_block( + AOTIMetalKernelFunctionHandle func, + aoti_torch_mps_command_block_callback_t callback, + void* user_data) { + + if (!func) { + ET_LOG(Error, "aoti_torch_mps_run_command_block: null function handle"); + return Error::InvalidArgument; + } + + if (!callback) { + ET_LOG(Error, "aoti_torch_mps_run_command_block: null callback"); + return Error::InvalidArgument; + } + + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Starting command block for function %p, callback %p, user_data %p", + func, callback, user_data); + + try { + auto* function = reinterpret_cast(func); + function->runCommandBlock([callback, func, user_data]() { + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Inside lambda, calling callback"); + callback(func, user_data); + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Callback completed"); + }); + + ET_LOG(Debug, "aoti_torch_mps_run_command_block: Executed command block for function %p", function); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_run_command_block exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_run_command_block: unknown exception"); + return Error::Internal; + } +} + +} // extern "C" + + +} // namespace metal +} // namespace backends +} // namespace executorch From ca5f1e52300560ba3dab33ed247afdb0ea36a30a Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Sat, 11 Oct 2025 15:47:38 -0400 Subject: [PATCH 08/14] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/utils.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp index 484158e9027..bc8c0483e9d 100644 --- a/backends/apple/metal/runtime/shims/utils.cpp +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -65,13 +65,12 @@ std::vector convert_strides_to_vector( std::vector strides(ndim); if (strides_ptr != nullptr) { - // Use provided strides. it is ok if provided strides here is not contiguous - // strides since it will be used internally in CUDA delegate. + // Use provided strides. for (int64_t i = 0; i < ndim; i++) { strides[i] = static_cast(strides_ptr[i]); } } else { - // Calculate strides from sizes using ExecutorTorch's algorithm + // Calculate strides from sizes. 
if (ndim > 0) { strides[ndim - 1] = static_cast( 1); // Last dimension has stride 1 From 81c45885a294719d96cef8d6484a4a36f647734f Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 13 Oct 2025 18:47:10 -0400 Subject: [PATCH 09/14] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/memory.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp index 2bda93e18a4..4b8b6cda4e0 100644 --- a/backends/apple/metal/runtime/shims/memory.cpp +++ b/backends/apple/metal/runtime/shims/memory.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include // Ensure we have int64_t, int32_t definitions #include #include @@ -144,7 +143,8 @@ AOTITorchError aoti_torch_empty_strided( dtype); int64_t nbytes = numel * element_size; - if (device_type == 2) { // Metal/MPS + int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13 + if (device_type == mps_device_type) { ptr = metal_allocate_buffer(nbytes); if (!ptr) { ET_LOG(Error, "Failed to allocate %lld bytes on Metal device", nbytes); From 422e4baf66b007ec529f78e44d4691ffd9e530ed Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 13 Oct 2025 19:37:57 -0400 Subject: [PATCH 10/14] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/memory.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp index 4b8b6cda4e0..83250f308bb 100644 --- a/backends/apple/metal/runtime/shims/memory.cpp +++ b/backends/apple/metal/runtime/shims/memory.cpp @@ -143,7 +143,7 @@ AOTITorchError aoti_torch_empty_strided( dtype); int64_t nbytes = numel * element_size; - int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13 + int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13 if (device_type == mps_device_type) { ptr = metal_allocate_buffer(nbytes); if (!ptr) { From f46adc5149a168fec69d18bfebe7fcd39be37ce5 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 14 Oct 2025 21:36:07 -0400 Subject: [PATCH 11/14] Update [ghstack-poisoned] --- backends/cuda/runtime/shims/memory.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp index 6fe315ba8ee..fe8ccf07281 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -27,6 +27,8 @@ using executorch::backends::aoti::aoti_torch_get_device_index; using executorch::backends::aoti::aoti_torch_get_dtype; using executorch::backends::aoti::aoti_torch_get_sizes; using executorch::backends::aoti::aoti_torch_get_strides; +using executorch::backends::aoti::convert_sizes_to_vector; +using executorch::backends::aoti::convert_strides_to_vector; using executorch::backends::aoti::dtype_to_element_size; using executorch::backends::aoti::dtype_to_scalar_type; using executorch::backends::aoti::validate_storage_offset; From 71f87b691d65e222957b714e40e7d2e657075cb8 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 15 Oct 2025 15:32:04 -0400 Subject: [PATCH 12/14] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 107 ++++++++++++++++++ .../apple/metal/runtime/shims/et_metal.mm | 22 +++- 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h index c18ad513a3a..a1c8c684131 100644 --- 
a/backends/apple/metal/runtime/shims/et_metal.h +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -77,6 +77,35 @@ enum class SyncType { // ======================= // ETMetalShaderLibrary - ExecuTorch Metal shader library management // ======================= + +/** + * @class ETMetalShaderLibrary + * @brief Manages Metal shader library compilation and kernel function retrieval. + * + * This class provides a high-level interface for compiling Metal shading language + * source code into a Metal library and creating compute pipeline states for + * kernel functions. It handles the creation and caching of Metal compute pipeline + * states and functions, which should be reused across multiple kernel dispatches. + * + * The class automatically compiles the provided shader source code upon construction + * and maintains an internal cache of compute pipeline states for different kernel + * functions to avoid redundant compilation. + * + * Example usage: + * @code + * std::string shaderSource = R"( + * #include + * using namespace metal; + * kernel void my_kernel(device float* data [[buffer(0)]], + * uint tid [[thread_position_in_grid]]) { + * data[tid] = data[tid] * 2.0; + * } + * )"; + * + * ETMetalShaderLibrary library(shaderSource); + * auto kernelFunction = library.getKernelFunction("my_kernel"); + * @endcode + */ class ETMetalShaderLibrary { public: ETMetalShaderLibrary(const std::string& source); @@ -103,6 +132,45 @@ class ETMetalShaderLibrary { // ======================= // ETMetalKernelFunction - ExecuTorch Metal kernel function execution // ======================= + +/** + * @class ETMetalKernelFunction + * @brief Represents a Metal compute kernel function ready for execution. + * + * This class encapsulates a Metal compute pipeline state and function, providing + * a high-level interface for setting kernel arguments and dispatching compute + * work to the GPU. It handles the encoding of compute commands and manages the + * interaction with Metal's compute command encoder. + * + * The class supports different dispatch patterns: + * - Single-dimension dispatch for linear workloads + * - Multi-dimensional dispatch for grid-based workloads + * - Custom thread group sizes for performance optimization + * + * Kernel arguments can be set using tensors (which will be mapped to Metal buffers) + * or scalar values. The class handles the encoding of these arguments + * into the compute command encoder. + * + * Example usage: + * @code + * // Get kernel function from library + * auto kernelFunction = library.getKernelFunction("vector_add"); + * + * // Start encoding commands + * kernelFunction->startEncoding(); + * + * // Set tensor arguments + * kernelFunction->setArg(0, inputTensorA); + * kernelFunction->setArg(1, inputTensorB); + * kernelFunction->setArg(2, outputTensor); + * + * // Set scalar argument + * kernelFunction->setArg(3, static_cast(numElements)); + * + * // Dispatch for linear workload + * kernelFunction->dispatchSingle(numElements); + * @endcode + */ class ETMetalKernelFunction { public: ETMetalKernelFunction(MTLComputePipelineState_t cps, MTLFunction_t func); @@ -132,6 +200,45 @@ class ETMetalKernelFunction { // ======================= // ETMetalStream - Metal command buffer and synchronization management // ======================= + +/** + * @class ETMetalStream + * @brief Manages Metal compute command streams and provides GPU synchronization. 
+ * + * This class serves as the central management hub for Metal GPU operations, providing + * a stream-based abstraction similar to CUDA streams. It handles command buffer lifecycle, + * compute command encoder management, and various synchronization patterns required for + * efficient GPU computation. + * + * Key features: + * - Lazy command buffer and encoder creation for optimal resource usage + * - Thread-safe operations using serial dispatch queues + * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, COMMIT_AND_CONTINUE) + * - Kernel coalescing to batch multiple operations efficiently + * - MPSGraph integration for high-level neural network operations + * - Memory operations (copy, fill) with GPU acceleration via blit encoders + * + * The stream follows PyTorch's MPS stream design patterns, providing similar semantics + * for command buffer management and synchronization. + * + * Example usage: + * @code + * // Get current stream (typically the default stream) + * ETMetalStream* stream = getCurrentMetalStream(); + * + * // Execute kernel operations (handled automatically) + * auto kernelFunction = library.getKernelFunction("my_kernel"); + * kernelFunction->startEncoding(); + * kernelFunction->setArg(0, inputTensor); + * kernelFunction->dispatchSingle(numElements); + * + * // Synchronize to ensure completion + * stream->synchronize(SyncType::COMMIT_AND_WAIT); + * + * // Copy between GPU buffers using blit encoder + * stream->copy(srcBuffer, dstBuffer, numBytes, 0, 0, SyncType::COMMIT); + * @endcode + */ class ETMetalStream { public: ETMetalStream(); diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index 5afcf761d56..f76146ab783 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -743,6 +743,26 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev void ETMetalStream::copy(id srcBuffer, id dstBuffer, size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType) { + + if (length == 0) { + return; + + // Check that offsets are within buffer bounds before copying + if (!srcBuffer || !dstBuffer) { + ET_LOG(Error, "ETMetalStream::copy: Source or destination buffer is nil"); + return; + } + NSUInteger srcBufferLength = [srcBuffer length]; + NSUInteger dstBufferLength = [dstBuffer length]; + if (srcOffset + length > srcBufferLength) { + ET_LOG(Error, "ETMetalStream::copy: Source offset (%zu) + length (%zu) exceeds source buffer size (%zu)", srcOffset, length, srcBufferLength); + return; + } + if (dstOffset + length > dstBufferLength) { + ET_LOG(Error, "ETMetalStream::copy: Destination offset (%zu) + length (%zu) exceeds destination buffer size (%zu)", dstOffset, length, dstBufferLength); + return; + } + dispatch_sync(serialQueue_, ^{ @autoreleasepool { endKernelCoalescing(); @@ -792,8 +812,6 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev targetOperations:nil resultsDictionary:results executionDescriptor:nil]; - - //synchronize(syncType); } }); } From 95a70247fded8a432c69a08d64e9d891a2c8a2f4 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 15 Oct 2025 15:40:46 -0400 Subject: [PATCH 13/14] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.h | 50 ++++++++++--------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h index a1c8c684131..75f79e5139c 100644 
--- a/backends/apple/metal/runtime/shims/et_metal.h +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -80,16 +80,18 @@ enum class SyncType { /** * @class ETMetalShaderLibrary - * @brief Manages Metal shader library compilation and kernel function retrieval. + * @brief Manages Metal shader library compilation and kernel function + * retrieval. * - * This class provides a high-level interface for compiling Metal shading language - * source code into a Metal library and creating compute pipeline states for - * kernel functions. It handles the creation and caching of Metal compute pipeline - * states and functions, which should be reused across multiple kernel dispatches. + * This class provides a high-level interface for compiling Metal shading + * language source code into a Metal library and creating compute pipeline + * states for kernel functions. It handles the creation and caching of Metal + * compute pipeline states and functions, which should be reused across multiple + * kernel dispatches. * - * The class automatically compiles the provided shader source code upon construction - * and maintains an internal cache of compute pipeline states for different kernel - * functions to avoid redundant compilation. + * The class automatically compiles the provided shader source code upon + * construction and maintains an internal cache of compute pipeline states for + * different kernel functions to avoid redundant compilation. * * Example usage: * @code @@ -137,18 +139,18 @@ class ETMetalShaderLibrary { * @class ETMetalKernelFunction * @brief Represents a Metal compute kernel function ready for execution. * - * This class encapsulates a Metal compute pipeline state and function, providing - * a high-level interface for setting kernel arguments and dispatching compute - * work to the GPU. It handles the encoding of compute commands and manages the - * interaction with Metal's compute command encoder. + * This class encapsulates a Metal compute pipeline state and function, + * providing a high-level interface for setting kernel arguments and dispatching + * compute work to the GPU. It handles the encoding of compute commands and + * manages the interaction with Metal's compute command encoder. * * The class supports different dispatch patterns: * - Single-dimension dispatch for linear workloads * - Multi-dimensional dispatch for grid-based workloads * - Custom thread group sizes for performance optimization * - * Kernel arguments can be set using tensors (which will be mapped to Metal buffers) - * or scalar values. The class handles the encoding of these arguments + * Kernel arguments can be set using tensors (which will be mapped to Metal + * buffers) or scalar values. The class handles the encoding of these arguments * into the compute command encoder. * * Example usage: @@ -203,23 +205,25 @@ class ETMetalKernelFunction { /** * @class ETMetalStream - * @brief Manages Metal compute command streams and provides GPU synchronization. + * @brief Manages Metal compute command streams and provides GPU + * synchronization. * - * This class serves as the central management hub for Metal GPU operations, providing - * a stream-based abstraction similar to CUDA streams. It handles command buffer lifecycle, - * compute command encoder management, and various synchronization patterns required for - * efficient GPU computation. + * This class serves as the central management hub for Metal GPU operations, + * providing a stream-based abstraction similar to CUDA streams. 
It handles + * command buffer lifecycle, compute command encoder management, and various + * synchronization patterns required for efficient GPU computation. * * Key features: * - Lazy command buffer and encoder creation for optimal resource usage * - Thread-safe operations using serial dispatch queues - * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, COMMIT_AND_CONTINUE) + * - Multiple synchronization modes (COMMIT, COMMIT_AND_WAIT, + * COMMIT_AND_CONTINUE, etc.) * - Kernel coalescing to batch multiple operations efficiently - * - MPSGraph integration for high-level neural network operations + * - MPSGraph integration for executing fall back operations (mm, conv, sdpa) * - Memory operations (copy, fill) with GPU acceleration via blit encoders * - * The stream follows PyTorch's MPS stream design patterns, providing similar semantics - * for command buffer management and synchronization. + * The stream follows PyTorch's MPS stream design patterns, providing similar + * semantics for command buffer management and synchronization. * * Example usage: * @code From d37e7efb19229e57e5dd314c5c06345bb68f4c59 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 15 Oct 2025 15:51:31 -0400 Subject: [PATCH 14/14] Update [ghstack-poisoned] --- backends/apple/metal/runtime/shims/et_metal.mm | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index f76146ab783..fdca0a28cf3 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -746,6 +746,7 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev if (length == 0) { return; + } // Check that offsets are within buffer bounds before copying if (!srcBuffer || !dstBuffer) {