
Commit e73a365

Add Metal backend type definitions and utilities
Implement foundational types and utilities for the Metal backend, including:
- AOTI type aliases (AOTITensorHandle, AOTIRuntimeError, AOTITorchError)
- Device type handling functions
- Tensor storage size queries
- Tensor attribute utilities

ghstack-source-id: 7bfa3ae
ghstack-comment-id: 3392299883
Pull-Request: pytorch#15019
1 parent f06ef08 commit e73a365

File tree

8 files changed, +260 −42 lines


backends/aoti/utils.h

Lines changed: 58 additions & 0 deletions
@@ -94,6 +94,64 @@ inline bool is_tensor_contiguous(
 
 } // extern "C"
 
+// Utility function to convert sizes pointer to vector
+inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr) {
+  std::vector<executorch::aten::SizesType> sizes(ndim);
+  for (int i = 0; i < ndim; i++) {
+    sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
+  }
+  return sizes;
+}
+
+// Utility function to convert strides pointer to vector or calculate from sizes
+inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr) {
+  std::vector<executorch::aten::StridesType> strides(ndim);
+
+  if (strides_ptr != nullptr) {
+    // Use provided strides.
+    for (int64_t i = 0; i < ndim; i++) {
+      strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
+    }
+  } else {
+    // Calculate strides from sizes.
+    if (ndim > 0) {
+      strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
+          1); // Last dimension has stride 1
+      for (int64_t i = ndim - 2; i >= 0; i--) {
+        if (sizes_ptr[i + 1] == 0) {
+          strides[i] = strides[i + 1]; // Copy stride when size is 0
+        } else {
+          strides[i] = static_cast<executorch::aten::StridesType>(
+              static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
+        }
+      }
+    }
+  }
+  return strides;
+}
+
+// Check if tensor is in contiguous memory format (NCHW for 4D tensors).
+// Contiguous format means strides decrease from left to right;
+// for NCHW: strides = [C*H*W, H*W, W, 1]
+inline bool is_contiguous_tensor(
+    std::vector<executorch::aten::SizesType>& sizes,
+    std::vector<executorch::aten::StridesType>& strides) {
+  int64_t ndim = static_cast<int64_t>(strides.size());
+  int64_t expected_stride = 1;
+  for (int64_t i = ndim - 1; i >= 0; i--) {
+    if (strides[i] != expected_stride) {
+      return false;
+    }
+    expected_stride *= sizes[i];
+  }
+  return true;
+}
+
 } // namespace aoti
 } // namespace backends
 } // namespace executorch
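A minimal usage sketch of the three helpers added above (not part of this commit). It assumes the backends/aoti/utils.h header path from this diff; the shape and the name sizes_raw are illustrative only.

#include <executorch/backends/aoti/utils.h>

#include <cstdint>

int main() {
  using namespace executorch::backends::aoti;

  // Hypothetical 4-D (NCHW-like) shape with no caller-provided strides.
  const int64_t sizes_raw[] = {2, 3, 4, 5};
  auto sizes = convert_sizes_to_vector(4, sizes_raw);

  // nullptr strides -> contiguous strides computed from sizes: [60, 20, 5, 1].
  auto strides = convert_strides_to_vector(4, sizes_raw, /*strides_ptr=*/nullptr);

  // The contiguity check walks right-to-left, accumulating the expected stride.
  return is_contiguous_tensor(sizes, strides) ? 0 : 1;
}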
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/apple/metal/runtime/shims/tensor_attribute.h>
+#include <executorch/backends/apple/metal/runtime/shims/utils.h>
+#include <iostream>
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+extern "C" {
+
+// Metal-specific device type constant
+__attribute__((__visibility__("default"))) int32_t
+aoti_torch_device_type_mps() {
+  return 13; // Consistent with c10/core/DeviceType.h
+}
+
+// Override aoti_torch_get_device_type to return MPS device type
+AOTITorchError aoti_torch_get_device_type(
+    AOTITensorHandle tensor,
+    int32_t* ret_device_type) {
+  *ret_device_type = aoti_torch_device_type_mps();
+  return Error::Ok;
+}
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
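A hedged caller sketch (not part of this commit) showing how the two shims above can be exercised. The helper name tensor_reports_mps is hypothetical; the include path is the one introduced in this stack.

#include <executorch/backends/apple/metal/runtime/shims/tensor_attribute.h>

#include <cstdint>

namespace metal_shims = executorch::backends::metal;

// Hypothetical helper: true when the shim reports the MPS device type (13).
bool tensor_reports_mps(metal_shims::AOTITensorHandle tensor) {
  int32_t device_type = -1;
  metal_shims::AOTITorchError err =
      metal_shims::aoti_torch_get_device_type(tensor, &device_type);
  return err == executorch::runtime::Error::Ok &&
      device_type == metal_shims::aoti_torch_device_type_mps();
}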
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/aoti/common_shims.h>
+#include <executorch/backends/apple/metal/runtime/shims/types.h>
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+extern "C" {
+
+// Metal-specific device type function
+int32_t aoti_torch_device_type_mps();
+
+// Override aoti_torch_get_device_type to return MPS device type
+AOTITorchError aoti_torch_get_device_type(
+    AOTITensorHandle tensor,
+    int32_t* ret_device_type);
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/error.h>
+#include <cstdint>
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+// Common using declarations for ExecutorTorch types
+using executorch::runtime::Error;
+using executorch::runtime::etensor::Tensor;
+
+extern "C" {
+
+// Common AOTI type aliases
+// Note: AOTITensorHandle is aliased to Tensor* for ExecutorTorch compatibility
+using AOTITensorHandle = Tensor*;
+using AOTIRuntimeError = Error;
+using AOTITorchError = Error;
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
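A sketch (not part of this commit) of how these aliases are meant to be used when declaring further shim entry points. The function name aoti_torch_example_get_dim is hypothetical; the only assumption beyond this diff is that the ExecuTorch Tensor type exposes dim().

#include <executorch/backends/apple/metal/runtime/shims/types.h>

#include <cstdint>

namespace executorch {
namespace backends {
namespace metal {

extern "C" {

// Hypothetical shim: reports the number of dimensions of a tensor handle,
// using the AOTITensorHandle / AOTITorchError aliases defined above.
AOTITorchError aoti_torch_example_get_dim(
    AOTITensorHandle tensor,
    int64_t* ret_dim) {
  if (tensor == nullptr || ret_dim == nullptr) {
    return Error::InvalidArgument;
  }
  *ret_dim = static_cast<int64_t>(tensor->dim());
  return Error::Ok;
}

} // extern "C"

} // namespace metal
} // namespace backends
} // namespace executorch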
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/apple/metal/runtime/shims/utils.h>
+#include <executorch/runtime/platform/log.h>
+#include <cstdint>
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+extern "C" {
+
+// Helper function to check if a dtype is supported in Metal backend
+bool is_dtype_supported_in_et_metal(int32_t dtype) {
+  switch (dtype) {
+    case static_cast<int32_t>(SupportedDTypes::INT64):
+    case static_cast<int32_t>(SupportedDTypes::FLOAT32):
+    case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Metal-specific dtype validation utility function
+AOTITorchError validate_dtype(int32_t dtype) {
+  if (is_dtype_supported_in_et_metal(dtype)) {
+    return Error::Ok;
+  }
+
+  ET_LOG(
+      Error,
+      "Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)",
+      dtype,
+      static_cast<int32_t>(SupportedDTypes::INT64),
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16));
+  return Error::InvalidArgument;
+}
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
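A minimal caller sketch for the two validation helpers defined above (not part of this commit). The dtype codes 6 (float32) and 7 (float64) follow the PyTorch scalar-type numbering quoted in the accompanying header.

#include <executorch/backends/apple/metal/runtime/shims/utils.h>

#include <cstdint>

int main() {
  using executorch::backends::metal::is_dtype_supported_in_et_metal;
  using executorch::backends::metal::validate_dtype;
  using executorch::runtime::Error;

  // float32 (code 6) is supported; float64 (code 7) is rejected and logged.
  bool ok_f32 = is_dtype_supported_in_et_metal(6);
  Error err_f64 = validate_dtype(7);

  return (ok_f32 && err_f64 == Error::InvalidArgument) ? 0 : 1;
}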
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/aoti/utils.h>
+#include <executorch/backends/apple/metal/runtime/shims/types.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <cstdint>
+
+namespace executorch {
+namespace backends {
+namespace metal {
+
+// Enum for supported data types in et-metal backend
+enum class SupportedDTypes : int32_t {
+  // UINT8 = 0,    // PyTorch's uint8 dtype code
+  // INT8 = 1,     // PyTorch's int8 dtype code
+  // INT16 = 2,    // PyTorch's int16 dtype code
+  // INT32 = 3,    // PyTorch's int32 dtype code
+  INT64 = 4,       // PyTorch's int64 dtype code
+  // FLOAT16 = 5,  // PyTorch's float16 dtype code
+  FLOAT32 = 6,     // PyTorch's float32 dtype code
+  // FLOAT64 = 7,  // PyTorch's float64 dtype code
+  // BOOL = 11,    // PyTorch's bool dtype code
+  BFLOAT16 = 15    // PyTorch's bfloat16 dtype code
+};
+
+extern "C" {
+
+// Helper function to check if a dtype is supported in Metal backend
+bool is_dtype_supported_in_et_metal(int32_t dtype);
+
+// Metal-specific dtype validation utility function
+AOTITorchError validate_dtype(int32_t dtype);
+
+} // extern "C"
+
+} // namespace metal
+} // namespace backends
+} // namespace executorch
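A small compile-time cross-check sketch (not part of this commit) that pins the enum values above to the c10::ScalarType codes they claim to mirror; the assert messages are illustrative.

#include <executorch/backends/apple/metal/runtime/shims/utils.h>

#include <cstdint>

using executorch::backends::metal::SupportedDTypes;

// Long = 4, Float = 6, BFloat16 = 15 in c10/core/ScalarType.h.
static_assert(static_cast<int32_t>(SupportedDTypes::INT64) == 4, "int64 code mismatch");
static_assert(static_cast<int32_t>(SupportedDTypes::FLOAT32) == 6, "float32 code mismatch");
static_assert(static_cast<int32_t>(SupportedDTypes::BFLOAT16) == 15, "bfloat16 code mismatch");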

backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 
 #include <cuda_runtime.h>
 #include <executorch/backends/aoti/common_shims.h>
+#include <executorch/backends/aoti/utils.h>
 #include <executorch/backends/cuda/runtime/shims/memory.h>
 #include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
 #include <executorch/backends/cuda/runtime/utils.h>

backends/cuda/runtime/utils.h

Lines changed: 0 additions & 42 deletions
@@ -49,48 +49,6 @@ enum class SupportedDevices : int32_t {
   CUDA = 1, // CUDA device
 };
 
-// Utility function to convert sizes pointer to vector
-inline std::vector<executorch::aten::SizesType> convert_sizes_to_vector(
-    int64_t ndim,
-    const int64_t* sizes_ptr) {
-  std::vector<executorch::aten::SizesType> sizes(ndim);
-  for (int i = 0; i < ndim; i++) {
-    sizes[i] = static_cast<executorch::aten::SizesType>(sizes_ptr[i]);
-  }
-  return sizes;
-}
-
-// Utility function to convert strides pointer to vector or calculate from sizes
-inline std::vector<executorch::aten::StridesType> convert_strides_to_vector(
-    int64_t ndim,
-    const int64_t* sizes_ptr,
-    const int64_t* strides_ptr) {
-  std::vector<executorch::aten::StridesType> strides(ndim);
-
-  if (strides_ptr != nullptr) {
-    // Use provided strides. it is ok if provided strides here is not contiguous
-    // strides since it will be used internally in CUDA delegate.
-    for (int64_t i = 0; i < ndim; i++) {
-      strides[i] = static_cast<executorch::aten::StridesType>(strides_ptr[i]);
-    }
-  } else {
-    // Calculate strides from sizes using ExecutorTorch's algorithm
-    if (ndim > 0) {
-      strides[ndim - 1] = static_cast<executorch::aten::StridesType>(
-          1); // Last dimension has stride 1
-      for (int64_t i = ndim - 2; i >= 0; i--) {
-        if (sizes_ptr[i + 1] == 0) {
-          strides[i] = strides[i + 1]; // Copy stride when size is 0
-        } else {
-          strides[i] = static_cast<executorch::aten::StridesType>(
-              static_cast<int64_t>(strides[i + 1]) * sizes_ptr[i + 1]);
-        }
-      }
-    }
-  }
-  return strides;
-}
-
 extern "C" {
 using executorch::runtime::Error;
 // Common AOTI type aliases
