Commit d6f0bc9

Update
[ghstack-poisoned]
1 parent 1a22c5e commit d6f0bc9

File tree: 5 files changed (+272, −0 lines changed)

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/apple/metal/runtime/shims/tensor_attribute.h>
#include <executorch/backends/apple/metal/runtime/shims/utils.h>
#include <iostream>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Metal-specific device type constant
__attribute__((__visibility__("default"))) int32_t
aoti_torch_device_type_mps() {
  // Let's use 2 for MPS
  return 2;
}

// Override aoti_torch_get_device_type to return MPS device type
AOTITorchError aoti_torch_get_device_type(
    AOTITensorHandle tensor,
    int32_t* ret_device_type) {
  *ret_device_type = aoti_torch_device_type_mps();
  return Error::Ok;
}

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
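
A minimal caller-side sketch (not part of this commit; the helper name is made up for illustration), assuming the shims above are linked in:

// Hypothetical helper (illustrative only): returns true if the shim above
// reports the tensor as living on the MPS (Metal) device.
static bool reports_mps_device(AOTITensorHandle tensor) {
  int32_t device_type = -1;
  AOTITorchError err = aoti_torch_get_device_type(tensor, &device_type);
  return err == Error::Ok && device_type == aoti_torch_device_type_mps();
}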
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/apple/metal/runtime/shims/types.h>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Metal-specific device type function
int32_t aoti_torch_device_type_mps();

// Override aoti_torch_get_device_type to return MPS device type
AOTITorchError aoti_torch_get_device_type(
    AOTITensorHandle tensor,
    int32_t* ret_device_type);

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/error.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace metal {

// Common using declarations for ExecutorTorch types
using executorch::runtime::Error;
using executorch::runtime::etensor::Tensor;

extern "C" {

// Common AOTI type aliases
// Note: AOTITensorHandle is aliased to Tensor* for ExecutorTorch compatibility
using AOTITensorHandle = Tensor*;
using AOTIRuntimeError = Error;
using AOTITorchError = Error;

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch
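
A brief sketch (illustrative only; the shim name below is hypothetical) of how these aliases are meant to be used: handles are raw Tensor pointers and errors are plain ExecutorTorch Error codes.

// Hypothetical shim signature, for illustration of the type aliases only.
AOTITorchError aoti_torch_example_shim(AOTITensorHandle tensor) {
  if (tensor == nullptr) {
    return Error::InvalidArgument;
  }
  return Error::Ok;
}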
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/apple/metal/runtime/shims/utils.h>
#include <executorch/runtime/platform/log.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Helper function to check if a dtype is supported in Metal backend
bool is_dtype_supported_in_et_metal(int32_t dtype) {
  switch (dtype) {
    case static_cast<int32_t>(SupportedDTypes::INT64):
    case static_cast<int32_t>(SupportedDTypes::FLOAT32):
    case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
      return true;
    default:
      return false;
  }
}

// Metal-specific dtype validation utility function
AOTITorchError validate_dtype(int32_t dtype) {
  if (is_dtype_supported_in_et_metal(dtype)) {
    return Error::Ok;
  }

  ET_LOG(
      Error,
      "Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)",
      dtype,
      static_cast<int32_t>(SupportedDTypes::INT64),
      static_cast<int32_t>(SupportedDTypes::FLOAT32),
      static_cast<int32_t>(SupportedDTypes::BFLOAT16));
  return Error::InvalidArgument;
}

} // extern "C"

// Utility function to convert sizes pointer to vector
std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
    int64_t ndim,
    const int64_t* sizes_ptr) {
  std::vector<executorch::aten::SizesType> sizes(ndim);
  for (int i = 0; i < ndim; i++) {
    sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
  }
  return sizes;
}

// Utility function to convert strides pointer to vector or calculate from sizes
std::vector<executorch::aten::StridesType> convert_strides_to_vector(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr) {
  std::vector<executorch::aten::StridesType> strides(ndim);

  if (strides_ptr != nullptr) {
    // Use the provided strides. It is OK if they are not contiguous strides,
    // since they are only used internally by the Metal delegate.
    for (int64_t i = 0; i < ndim; i++) {
      strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
    }
  } else {
    // Calculate strides from sizes using ExecutorTorch's algorithm
    if (ndim > 0) {
      strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
          1); // Last dimension has stride 1
      for (int64_t i = ndim - 2; i >= 0; i--) {
        if (sizes_ptr[i + 1] == 0) {
          strides[i] = strides[i + 1]; // Copy stride when size is 0
        } else {
          strides[i] = static_cast<executorch::aten::StridesType>(
              static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
        }
      }
    }
  }
  return strides;
}

} // namespace metal
} // namespace backends
} // namespace executorch
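
A small worked example (not part of the commit) of the stride fallback when no strides are provided: for sizes {2, 3, 4} the computed contiguous strides are {12, 4, 1}.

// Hypothetical caller, illustrating the contiguous-stride fallback.
const int64_t sizes[] = {2, 3, 4};
auto strides = convert_strides_to_vector(/*ndim=*/3, sizes, /*strides_ptr=*/nullptr);
// strides == {12, 4, 1}: the last dimension gets stride 1, and each outer
// stride is the running product of the inner sizes.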
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/apple/metal/runtime/shims/types.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace metal {

// Enum for supported data types in et-metal backend
enum class SupportedDTypes : int32_t {
  // UINT8 = 0,   // PyTorch's uint8 dtype code
  // INT8 = 1,    // PyTorch's int8 dtype code
  // INT16 = 2,   // PyTorch's int16 dtype code
  // INT32 = 3,   // PyTorch's int32 dtype code
  INT64 = 4, // PyTorch's int64 dtype code
  // FLOAT16 = 5, // PyTorch's float16 dtype code
  FLOAT32 = 6, // PyTorch's float32 dtype code
  // FLOAT64 = 7, // PyTorch's float64 dtype code
  // BOOL = 11,   // PyTorch's bool dtype code
  BFLOAT16 = 15 // PyTorch's bfloat16 dtype code
};

extern "C" {

// Helper function to check if a dtype is supported in Metal backend
bool is_dtype_supported_in_et_metal(int32_t dtype);

// Metal-specific dtype validation utility function
AOTITorchError validate_dtype(int32_t dtype);

} // extern "C"

// Utility function to convert sizes pointer to vector
std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
    int64_t ndim,
    const int64_t* sizes_ptr);

// Utility function to convert strides pointer to vector or calculate from sizes
std::vector<executorch::aten::StridesType> convert_strides_to_vector(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr);

// Check if tensor is in contiguous memory format (NCHW for 4D tensors)
// Contiguous format means strides decrease from left to right:
// For NCHW: strides = [C*H*W, H*W, W, 1]
inline bool is_contiguous_tensor(
    std::vector<executorch::aten::SizesType> sizes,
    std::vector<executorch::aten::StridesType> strides) {
  int64_t ndim = static_cast<int64_t>(strides.size());
  int64_t expected_stride = 1;
  for (int64_t i = ndim - 1; i >= 0; i--) {
    if (strides[i] != expected_stride) {
      return false;
    }
    expected_stride *= sizes[i];
  }
  return true;
}

} // namespace metal
} // namespace backends
} // namespace executorch
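
A short sketch (illustrative, not part of the commit), assuming the PyTorch dtype codes listed in SupportedDTypes, combining the dtype and layout checks:

// Hypothetical caller (illustrative only).
std::vector<executorch::aten::SizesType> sizes = {2, 3, 4};
std::vector<executorch::aten::StridesType> strides = {12, 4, 1};
bool ok_dtype =
    validate_dtype(static_cast<int32_t>(SupportedDTypes::FLOAT32)) == Error::Ok; // true
bool contiguous = is_contiguous_tensor(sizes, strides); // true: strides match {3*4, 4, 1}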
