diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h
index 78373513439..b24fcaac864 100644
--- a/backends/aoti/utils.h
+++ b/backends/aoti/utils.h
@@ -100,6 +100,64 @@ inline bool is_tensor_contiguous(

 } // extern "C"

+// Utility function to convert a sizes pointer to a vector
+inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr) {
+  std::vector<executorch::aten::SizesType> sizes(ndim);
+  for (int i = 0; i < ndim; i++) {
+    sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
+  }
+  return sizes;
+}
+
+// Utility function to convert a strides pointer to a vector, or to calculate
+// contiguous strides from sizes when no strides are provided
+inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr) {
+  std::vector<executorch::aten::StridesType> strides(ndim);
+
+  if (strides_ptr != nullptr) {
+    // Use provided strides.
+    for (int64_t i = 0; i < ndim; i++) {
+      strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
+    }
+  } else {
+    // Calculate strides from sizes.
+    if (ndim > 0) {
+      strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
+          1); // Last dimension has stride 1
+      for (int64_t i = ndim - 2; i >= 0; i--) {
+        if (sizes_ptr[i + 1] == 0) {
+          strides[i] = strides[i + 1]; // Copy stride when size is 0
+        } else {
+          strides[i] = static_cast<executorch::aten::StridesType>(
+              static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
+        }
+      }
+    }
+  }
+  return strides;
+}
+
+// Check if a tensor is in contiguous memory format (NCHW for 4D tensors).
+// Contiguous format means strides decrease from left to right;
+// for NCHW: strides = [C*H*W, H*W, W, 1]
+inline bool is_contiguous_tensor(
+    const std::vector<executorch::aten::SizesType>& sizes,
+    const std::vector<executorch::aten::StridesType>& strides) {
+  int64_t ndim = static_cast<int64_t>(strides.size());
+  int64_t expected_stride = 1;
+  for (int64_t i = ndim - 1; i >= 0; i--) {
+    if (strides[i] != expected_stride) {
+      return false;
+    }
+    expected_stride *= sizes[i];
+  }
+  return true;
+}
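+
+// Worked example (illustrative): for ndim = 3 and sizes {2, 3, 4} with
+// strides_ptr == nullptr, convert_strides_to_vector computes {12, 4, 1};
+// is_contiguous_tensor accepts that sizes/strides pair, while permuted
+// strides such as {4, 12, 1} or {12, 1, 3} are rejected.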
+
 } // namespace aoti
 } // namespace backends
 } // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/tensor_attribute.cpp b/backends/apple/metal/runtime/shims/tensor_attribute.cpp
new file mode 100644
index 00000000000..34e0329fdc9
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/tensor_attribute.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/apple/metal/runtime/shims/tensor_attribute.h>
+#include <executorch/backends/apple/metal/runtime/shims/types.h>
+#include <cstdint>
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+extern "C" {
+
+// Metal-specific device type constant
+__attribute__((__visibility__("default"))) int32_t
+aoti_torch_device_type_mps() {
+  return 13; // Consistent with c10/core/DeviceType.h
+}
+
+// Override aoti_torch_get_device_type to return the MPS device type
+AOTITorchError aoti_torch_get_device_type(
+    AOTITensorHandle tensor,
+    int32_t* ret_device_type) {
+  *ret_device_type = aoti_torch_device_type_mps();
+  return Error::Ok;
+}
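+
+// Illustrative call (hypothetical handle; any tensor reports MPS here):
+//   int32_t dev = -1;
+//   aoti_torch_get_device_type(handle, &dev); // dev == 13 (c10 MPS)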
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/tensor_attribute.h b/backends/apple/metal/runtime/shims/tensor_attribute.h
new file mode 100644
index 00000000000..8d2a3dde361
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/tensor_attribute.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/apple/metal/runtime/shims/types.h>
+#include <cstdint>
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+extern "C" {
+
+// Metal-specific device type function
+int32_t aoti_torch_device_type_mps();
+
+// Override aoti_torch_get_device_type to return the MPS device type
+AOTITorchError aoti_torch_get_device_type(
+    AOTITensorHandle tensor,
+    int32_t* ret_device_type);
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/types.h b/backends/apple/metal/runtime/shims/types.h
new file mode 100644
index 00000000000..07d377d7499
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/types.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/portable_type/tensor.h>
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+// Common using declarations for ExecuTorch types
+using executorch::runtime::Error;
+using executorch::runtime::etensor::Tensor;
+
+extern "C" {
+
+// Common AOTI type aliases
+// Note: AOTITensorHandle is aliased to Tensor* for ExecuTorch compatibility
+using AOTITensorHandle = Tensor*;
+using AOTIRuntimeError = Error;
+using AOTITorchError = Error;
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp
new file mode 100644
index 00000000000..061360a4e28
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/utils.cpp
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/apple/metal/runtime/shims/utils.h>
+#include <executorch/runtime/platform/log.h>
+#include <cstdint>
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+extern "C" {
+
+// Helper function to check whether a dtype is supported in the Metal backend
+bool is_dtype_supported_in_et_metal(int32_t dtype) {
+  switch (dtype) {
+    case static_cast<int32_t>(SupportedDTypes::INT64):
+    case static_cast<int32_t>(SupportedDTypes::FLOAT32):
+    case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Metal-specific dtype validation utility function
+AOTITorchError validate_dtype(int32_t dtype) {
+  if (is_dtype_supported_in_et_metal(dtype)) {
+    return Error::Ok;
+  }
+
+  ET_LOG(
+      Error,
+      "Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)",
+      dtype,
+      static_cast<int32_t>(SupportedDTypes::INT64),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16));
+  return Error::InvalidArgument;
+}
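+
+// Usage sketch (illustrative): callers can gate shim entry points on the
+// dtype code before touching Metal buffers, e.g.
+//   validate_dtype(6);  // float32 -> Error::Ok
+//   validate_dtype(7);  // float64 -> logs and returns Error::InvalidArgument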
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/apple/metal/runtime/shims/utils.h b/backends/apple/metal/runtime/shims/utils.h
new file mode 100644
index 00000000000..974832fa365
--- /dev/null
+++ b/backends/apple/metal/runtime/shims/utils.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/apple/metal/runtime/shims/types.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/log.h>
+#include <cstdint>
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+// Enum for supported data types in the et-metal backend
+enum class SupportedDTypes : int32_t {
+  // UINT8 = 0,    // PyTorch's uint8 dtype code
+  // INT8 = 1,     // PyTorch's int8 dtype code
+  // INT16 = 2,    // PyTorch's int16 dtype code
+  // INT32 = 3,    // PyTorch's int32 dtype code
+  INT64 = 4,       // PyTorch's int64 dtype code
+  // FLOAT16 = 5,  // PyTorch's float16 dtype code
+  FLOAT32 = 6,     // PyTorch's float32 dtype code
+  // FLOAT64 = 7,  // PyTorch's float64 dtype code
+  // BOOL = 11,    // PyTorch's bool dtype code
+  BFLOAT16 = 15    // PyTorch's bfloat16 dtype code
+};
+
+extern "C" {
+
+// Helper function to check whether a dtype is supported in the Metal backend
+bool is_dtype_supported_in_et_metal(int32_t dtype);
+
+// Metal-specific dtype validation utility function
+AOTITorchError validate_dtype(int32_t dtype);
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index 6fe315ba8ee..fe8ccf07281 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -27,6 +27,8 @@ using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
 using executorch::backends::aoti::aoti_torch_get_sizes;
 using executorch::backends::aoti::aoti_torch_get_strides;
+using executorch::backends::aoti::convert_sizes_to_vector;
+using executorch::backends::aoti::convert_strides_to_vector;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
 using executorch::backends::aoti::validate_storage_offset;
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp
index e18bf142b5c..1cefca99c2a 100644
--- a/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp
@@ -8,6 +8,7 @@
 #include
 #include
+#include <executorch/backends/aoti/utils.h>
 #include
 #include
 #include
diff --git a/backends/cuda/runtime/utils.h b/backends/cuda/runtime/utils.h
index 22ed81bffd9..04c1a43721a 100644
--- a/backends/cuda/runtime/utils.h
+++ b/backends/cuda/runtime/utils.h
@@ -71,48 +71,6 @@ enum class SupportedDevices : int32_t {
   CUDA = 1, // CUDA device
 };

-// Utility function to convert a sizes pointer to a vector
-inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
-    int64_t ndim,
-    const int64_t* sizes_ptr) {
-  std::vector<executorch::aten::SizesType> sizes(ndim);
-  for (int i = 0; i < ndim; i++) {
-    sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
-  }
-  return sizes;
-}
-
-// Utility function to convert a strides pointer to a vector, or to calculate
-// strides from sizes when no strides are provided
-inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
-    int64_t ndim,
-    const int64_t* sizes_ptr,
-    const int64_t* strides_ptr) {
-  std::vector<executorch::aten::StridesType> strides(ndim);
-
-  if (strides_ptr != nullptr) {
-    // Use provided strides. It is ok if the provided strides are not
-    // contiguous strides, since they are only used internally in the CUDA
-    // delegate.
-    for (int64_t i = 0; i < ndim; i++) {
-      strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
-    }
-  } else {
-    // Calculate strides from sizes using ExecuTorch's algorithm
-    if (ndim > 0) {
-      strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
-          1); // Last dimension has stride 1
-      for (int64_t i = ndim - 2; i >= 0; i--) {
-        if (sizes_ptr[i + 1] == 0) {
-          strides[i] = strides[i + 1]; // Copy stride when size is 0
-        } else {
-          strides[i] = static_cast<executorch::aten::StridesType>(
-              static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
-        }
-      }
-    }
-  }
-  return strides;
-}
-
 extern "C" {
 using executorch::runtime::Error;
 // Common AOTI type aliases