Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions backends/aoti/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,64 @@ inline bool is_tensor_contiguous(

} // extern "C"

// Utility function to convert sizes pointer to vector
inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
int64_t ndim,
const int64_t* sizes_ptr) {
std::vector<executorch::aten::SizesType> sizes(ndim);
for (int i = 0; i < ndim; i++) {
sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
}
return sizes;
}

// Utility function to convert a strides pointer to a vector, or to derive
// contiguous (row-major) strides from the sizes when no strides are given.
//
// @param ndim Number of dimensions.
// @param sizes_ptr Dimension sizes; only read when `strides_ptr` is null.
// @param strides_ptr Optional explicit strides; may be null.
// @return Strides narrowed to executorch::aten::StridesType.
inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr) {
  std::vector<executorch::aten::StridesType> strides(ndim);

  if (strides_ptr != nullptr) {
    // Caller supplied explicit strides; narrow each entry as-is.
    for (int64_t dim = 0; dim < ndim; ++dim) {
      strides[dim] =
          static_cast<executorch::aten::StridesType>(strides_ptr[dim]);
    }
    return strides;
  }

  // Derive contiguous strides from the sizes, walking outward from the
  // innermost dimension (which always has stride 1). A dimension of size 0
  // contributes no factor, so the inner stride is carried over unchanged.
  if (ndim > 0) {
    strides[ndim - 1] = static_cast<executorch::aten::StridesType>(1);
    for (int64_t dim = ndim - 2; dim >= 0; --dim) {
      strides[dim] = (sizes_ptr[dim + 1] == 0)
          ? strides[dim + 1]
          : static_cast<executorch::aten::StridesType>(
                static_cast<int64_t>(strides[dim + 1]) * sizes_ptr[dim + 1]);
    }
  }
  return strides;
}

// Check if tensor is in contiguous memory format (NCHW for 4D tensors)
// Contiguous format means strides decrease from left to right:
// For NCHW: strides = [C*H*W, H*W, W, 1]
inline bool is_contiguous_tensor(
std::vector<executorch::aten::SizesType>& sizes,
std::vector<executorch::aten::StridesType>& strides) {
int64_t ndim = static_cast<int64_t>(strides.size());
int64_t expected_stride = 1;
for (int64_t i = ndim - 1; i >= 0; i--) {
if (strides[i] != expected_stride) {
return false;
}
expected_stride *= sizes[i];
}
return true;
}

} // namespace aoti
} // namespace backends
} // namespace executorch
37 changes: 37 additions & 0 deletions backends/apple/metal/runtime/shims/tensor_attribute.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/apple/metal/runtime/shims/tensor_attribute.h>
#include <executorch/backends/apple/metal/runtime/shims/utils.h>
#include <iostream>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Metal-specific device type constant.
// Returns the integer code identifying the MPS device.
__attribute__((__visibility__("default"))) int32_t
aoti_torch_device_type_mps() {
  // Consistent with c10/core/DeviceType.h.
  constexpr int32_t kMpsDeviceType = 13;
  return kMpsDeviceType;
}

// Override aoti_torch_get_device_type to return MPS device type.
//
// `tensor` is not inspected: this shim reports a fixed device type.
//
// @param tensor Tensor handle (unused).
// @param ret_device_type Out-parameter receiving the MPS device type code.
// @return Error::Ok on success; Error::InvalidArgument if the out-pointer
//         is null (the original dereferenced it unconditionally).
AOTITorchError aoti_torch_get_device_type(
    AOTITensorHandle tensor,
    int32_t* ret_device_type) {
  (void)tensor; // Device type is constant for this backend.
  if (ret_device_type == nullptr) {
    return Error::InvalidArgument;
  }
  *ret_device_type = aoti_torch_device_type_mps();
  return Error::Ok;
}

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
32 changes: 32 additions & 0 deletions backends/apple/metal/runtime/shims/tensor_attribute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/apple/metal/runtime/shims/types.h>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Metal-specific device type function: returns the integer device-type code
// for MPS (value defined in the corresponding .cpp).
int32_t aoti_torch_device_type_mps();

// Override aoti_torch_get_device_type to return MPS device type.
// @param tensor Tensor handle whose device type is queried.
// @param ret_device_type Out-parameter receiving the device type code.
AOTITorchError aoti_torch_get_device_type(
    AOTITensorHandle tensor,
    int32_t* ret_device_type);

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
35 changes: 35 additions & 0 deletions backends/apple/metal/runtime/shims/types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/error.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace metal {

// Common using declarations for ExecuTorch types.
using executorch::runtime::Error;
using executorch::runtime::etensor::Tensor;

extern "C" {

// Common AOTI type aliases.
// Note: AOTITensorHandle is aliased to Tensor* for ExecuTorch compatibility.
// AOTIRuntimeError is the return type of AOTInductorModelContainerRun and
// related container entry points, while AOTITorchError is returned by the
// aoti_torch_* shims; both currently map onto the same Error enum.
using AOTITensorHandle = Tensor*;
using AOTIRuntimeError = Error;
using AOTITorchError = Error;
Comment on lines +28 to +29
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to distinguish both errors? Seems like AOTIRuntimeError is not used

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, AOTInductorModelContainerRun and family return AOTIRuntimeError while aoti_torch_empty_strided and family return AOTITorchError


} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
51 changes: 51 additions & 0 deletions backends/apple/metal/runtime/shims/utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/apple/metal/runtime/shims/utils.h>
#include <executorch/runtime/platform/log.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Helper function to check if a dtype is supported in the Metal backend.
// `dtype` is compared against the dtype codes listed in SupportedDTypes.
bool is_dtype_supported_in_et_metal(int32_t dtype) {
  return dtype == static_cast<int32_t>(SupportedDTypes::INT64) ||
      dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32) ||
      dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16);
}

// Metal-specific dtype validation utility function.
// Logs the offending dtype and the supported set before rejecting it.
AOTITorchError validate_dtype(int32_t dtype) {
  if (!is_dtype_supported_in_et_metal(dtype)) {
    ET_LOG(
        Error,
        "Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)",
        dtype,
        static_cast<int32_t>(SupportedDTypes::INT64),
        static_cast<int32_t>(SupportedDTypes::FLOAT32),
        static_cast<int32_t>(SupportedDTypes::BFLOAT16));
    return Error::InvalidArgument;
  }
  return Error::Ok;
}

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
46 changes: 46 additions & 0 deletions backends/apple/metal/runtime/shims/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/apple/metal/runtime/shims/types.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace metal {

// Enum for supported data types in the et-metal backend.
// Values mirror PyTorch's dtype codes so they can be compared directly
// against the int32_t dtype arguments passed through the AOTI shims.
// Entries left commented out are dtype codes not currently supported.
enum class SupportedDTypes : int32_t {
  // UINT8 = 0,   // PyTorch's uint8 dtype code
  // INT8 = 1,    // PyTorch's int8 dtype code
  // INT16 = 2,   // PyTorch's int16 dtype code
  // INT32 = 3,   // PyTorch's int32 dtype code
  INT64 = 4, // PyTorch's int64 dtype code
  // FLOAT16 = 5, // PyTorch's float16 dtype code
  FLOAT32 = 6, // PyTorch's float32 dtype code
  // FLOAT64 = 7, // PyTorch's float64 dtype code
  // BOOL = 11,   // PyTorch's bool dtype code
  BFLOAT16 = 15 // PyTorch's bfloat16 dtype code
};

extern "C" {

// Helper function to check if a dtype is supported in Metal backend
bool is_dtype_supported_in_et_metal(int32_t dtype);

// Metal-specific dtype validation utility function
AOTITorchError validate_dtype(int32_t dtype);

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
2 changes: 2 additions & 0 deletions backends/cuda/runtime/shims/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ using executorch::backends::aoti::aoti_torch_get_device_index;
using executorch::backends::aoti::aoti_torch_get_dtype;
using executorch::backends::aoti::aoti_torch_get_sizes;
using executorch::backends::aoti::aoti_torch_get_strides;
using executorch::backends::aoti::convert_sizes_to_vector;
using executorch::backends::aoti::convert_strides_to_vector;
using executorch::backends::aoti::dtype_to_element_size;
using executorch::backends::aoti::dtype_to_scalar_type;
using executorch::backends::aoti::validate_storage_offset;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <cuda_runtime.h>
#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
#include <executorch/backends/cuda/runtime/utils.h>
Expand Down
42 changes: 0 additions & 42 deletions backends/cuda/runtime/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,48 +71,6 @@ enum class SupportedDevices : int32_t {
CUDA = 1, // CUDA device
};

// Utility function to convert sizes pointer to vector
inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
int64_t ndim,
const int64_t* sizes_ptr) {
std::vector<executorch::aten::SizesType> sizes(ndim);
for (int i = 0; i < ndim; i++) {
sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
}
return sizes;
}

// Utility function to convert strides pointer to vector or calculate from sizes
inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr) {
std::vector<executorch::aten::StridesType> strides(ndim);

if (strides_ptr != nullptr) {
// Use provided strides. it is ok if provided strides here is not contiguous
// strides since it will be used internally in CUDA delegate.
for (int64_t i = 0; i < ndim; i++) {
strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
}
} else {
// Calculate strides from sizes using ExecutorTorch's algorithm
if (ndim > 0) {
strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
1); // Last dimension has stride 1
for (int64_t i = ndim - 2; i >= 0; i--) {
if (sizes_ptr[i + 1] == 0) {
strides[i] = strides[i + 1]; // Copy stride when size is 0
} else {
strides[i] = static_cast<executorch::aten::StridesType>(
static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
}
}
}
}
return strides;
}

extern "C" {
using executorch::runtime::Error;
// Common AOTI type aliases
Expand Down
Loading