Skip to content

Commit 1cfceb1

Browse files
committed
[slimtensor] Introduce Device and ScalarType headers for SlimTensor minimal support
Pull Request resolved: #16382 This diff introduces the foundational c10 core headers for SlimTensor, a lightweight tensor implementation used by torchnative, to cuda backend runtime and further it will be used by all aoti-driven backends like MPS. We add: - DeviceType.h - Device type enum (CPU only for now) - Device.h - Device class representing compute device location - ScalarType.h - Scalar type enum with elementSize() helper (Float only for now) These headers are modeled after PyTorch's c10 but simplified for our needs. The enum values are kept compatible with PyTorch for serialization compatibility. This is the first step in migrating SlimTensor to replace ETensor as the internal tensor representation in CUDA backend. Future diffs will add Storage, SlimTensor class, and additional dtypes/devices incrementally. ghstack-source-id: 332731269 @exported-using-ghexport Differential Revision: [D89747061](https://our.internmc.facebook.com/intern/diff/D89747061/)
1 parent 41491c0 commit 1cfceb1

File tree

9 files changed

+549
-0
lines changed

9 files changed

+549
-0
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
#include <functional>
13+
#include <string>
14+
15+
#include <executorch/backends/aoti/slim/c10/core/DeviceType.h>
16+
#include <executorch/runtime/platform/assert.h>
17+
18+
namespace executorch::backends::aoti::slim::c10 {
19+
20+
/// An index representing a specific device; e.g., the 1 in GPU 1.
21+
/// A DeviceIndex is not independently meaningful without knowing
22+
/// the DeviceType it is associated; try to use Device rather than
23+
/// DeviceIndex directly.
24+
using DeviceIndex = int8_t;
25+
26+
/// Represents a compute device on which a tensor is located.
27+
/// A device is uniquely identified by a type (e.g., CPU) and a device index.
28+
struct Device final {
29+
/// Constructs a new Device from a DeviceType and an optional device index.
30+
/// @param type The type of device.
31+
/// @param index The device index. For CPU, this should be -1 or 0.
32+
/* implicit */
33+
explicit Device(DeviceType type, DeviceIndex index = -1)
34+
: type_(type), index_(index) {
35+
validate();
36+
}
37+
38+
/// Constructs a Device from a string description.
39+
/// The string must be "cpu" or "cpu:0".
40+
/* implicit */ Device(const std::string& device_string)
41+
: Device(DeviceType::CPU) {
42+
ET_CHECK_MSG(!device_string.empty(), "Device string must not be empty");
43+
44+
if (device_string == "cpu" || device_string == "CPU") {
45+
type_ = DeviceType::CPU;
46+
index_ = -1;
47+
} else if (device_string == "cpu:0" || device_string == "CPU:0") {
48+
type_ = DeviceType::CPU;
49+
index_ = static_cast<DeviceIndex>(device_string.back() - '0');
50+
} else {
51+
ET_CHECK_MSG(
52+
false,
53+
"Invalid device string: %s. Currently only 'cpu' is supported.",
54+
device_string.c_str());
55+
}
56+
validate();
57+
}
58+
59+
/// Returns true if the type and index of this Device matches that of other.
60+
bool operator==(const Device& other) const noexcept {
61+
return this->type_ == other.type_ && this->index_ == other.index_;
62+
}
63+
64+
/// Returns true if the type or index of this Device differs from that of
65+
/// other.
66+
bool operator!=(const Device& other) const noexcept {
67+
return !(*this == other);
68+
}
69+
70+
/// Sets the device index.
71+
void set_index(DeviceIndex index) {
72+
index_ = index;
73+
}
74+
75+
/// Returns the type of device this is.
76+
DeviceType type() const noexcept {
77+
return type_;
78+
}
79+
80+
/// Returns the device index.
81+
DeviceIndex index() const noexcept {
82+
return index_;
83+
}
84+
85+
/// Returns true if the device has a non-default index.
86+
bool has_index() const noexcept {
87+
return index_ != -1;
88+
}
89+
90+
/// Returns true if the device is of CPU type.
91+
bool is_cpu() const noexcept {
92+
return type_ == DeviceType::CPU;
93+
}
94+
95+
/// Returns a string representation of the device (e.g., "cpu" or "cpu:0").
96+
std::string str() const {
97+
std::string str = DeviceTypeName(type(), /* lower_case */ true);
98+
if (has_index()) {
99+
str.push_back(':');
100+
str.append(std::to_string(index()));
101+
}
102+
return str;
103+
}
104+
105+
private:
106+
DeviceType type_;
107+
DeviceIndex index_ = -1;
108+
109+
void validate() {
110+
ET_DCHECK_MSG(
111+
index_ >= -1,
112+
"Device index must be -1 or non-negative, got %d",
113+
static_cast<int>(index_));
114+
ET_DCHECK_MSG(
115+
!is_cpu() || index_ <= 0,
116+
"CPU device index must be -1 or zero, got %d",
117+
static_cast<int>(index_));
118+
}
119+
};
120+
121+
inline std::ostream& operator<<(std::ostream& stream, const Device& device) {
122+
stream << device.str();
123+
return stream;
124+
}
125+
126+
} // namespace executorch::backends::aoti::slim::c10
127+
128+
namespace std {
129+
template <>
130+
struct hash<executorch::backends::aoti::slim::c10::Device> {
131+
size_t operator()(
132+
executorch::backends::aoti::slim::c10::Device d) const noexcept {
133+
static_assert(
134+
sizeof(executorch::backends::aoti::slim::c10::DeviceType) == 1,
135+
"DeviceType is not 8-bit");
136+
static_assert(
137+
sizeof(executorch::backends::aoti::slim::c10::DeviceIndex) == 1,
138+
"DeviceIndex is not 8-bit");
139+
uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
140+
<< 16 |
141+
static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
142+
return std::hash<uint32_t>{}(bits);
143+
}
144+
};
145+
} // namespace std
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
#include <ostream>
13+
#include <string>
14+
15+
#include <executorch/runtime/platform/assert.h>
16+
17+
namespace executorch::backends::aoti::slim::c10 {
18+
19+
/// Enum representing the type of device.
20+
enum class DeviceType : int8_t {
21+
CPU = 0,
22+
COMPILE_TIME_MAX_DEVICE_TYPES = 1,
23+
};
24+
25+
constexpr DeviceType kCPU = DeviceType::CPU;
26+
27+
/// Maximum number of device types at compile time.
28+
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
29+
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
30+
31+
/// Returns the name of the device type as a string.
32+
/// @param d The device type.
33+
/// @param lower_case If true, returns the name in lower case.
34+
/// @return The name of the device type.
35+
inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) {
36+
switch (d) {
37+
case DeviceType::CPU:
38+
return lower_case ? "cpu" : "CPU";
39+
default:
40+
ET_CHECK_MSG(false, "Unknown device type: %d", static_cast<int>(d));
41+
}
42+
}
43+
44+
/// Checks if the device type is valid.
45+
/// @param d The device type to check.
46+
/// @return true if the device type is valid, false otherwise.
47+
inline bool isValidDeviceType(DeviceType d) {
48+
return d == DeviceType::CPU;
49+
}
50+
51+
inline std::ostream& operator<<(std::ostream& stream, DeviceType type) {
52+
stream << DeviceTypeName(type, /* lower_case */ true);
53+
return stream;
54+
}
55+
56+
} // namespace executorch::backends::aoti::slim::c10
57+
58+
namespace std {
59+
template <>
60+
struct hash<executorch::backends::aoti::slim::c10::DeviceType> {
61+
std::size_t operator()(
62+
executorch::backends::aoti::slim::c10::DeviceType k) const {
63+
return std::hash<int>()(static_cast<int>(k));
64+
}
65+
};
66+
} // namespace std
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstddef>
12+
#include <cstdint>
13+
#include <ostream>
14+
15+
#include <executorch/runtime/platform/assert.h>
16+
17+
namespace executorch::backends::aoti::slim::c10 {
18+
19+
/// Enum representing the scalar type (dtype) of tensor elements.
20+
/// Note: Enum values must match PyTorch's c10::ScalarType for compatibility.
21+
enum class ScalarType : int8_t {
22+
// Byte = 0,
23+
// Char = 1,
24+
// Short = 2,
25+
// Int = 3,
26+
// Long = 4,
27+
Float = 6,
28+
// Bool = 11,
29+
// BFloat16 = 15,
30+
Undefined = -1,
31+
NumOptions = 7,
32+
};
33+
34+
/// Constant for Float scalar type.
35+
constexpr ScalarType kFloat = ScalarType::Float;
36+
37+
/// Returns the size in bytes of a single element of the given scalar type.
38+
/// @param t The scalar type.
39+
/// @return The size in bytes of a single element.
40+
inline size_t elementSize(ScalarType t) {
41+
switch (t) {
42+
case ScalarType::Float:
43+
return sizeof(float);
44+
default:
45+
ET_CHECK_MSG(false, "Unknown ScalarType: %d", static_cast<int>(t));
46+
}
47+
}
48+
49+
/// Returns the name of the scalar type as a string.
50+
/// @param t The scalar type.
51+
/// @return The name of the scalar type.
52+
inline const char* toString(ScalarType t) {
53+
switch (t) {
54+
case ScalarType::Float:
55+
return "Float";
56+
case ScalarType::Undefined:
57+
return "Undefined";
58+
default:
59+
return "UNKNOWN_SCALAR";
60+
}
61+
}
62+
63+
/// Checks if the scalar type is a floating point type.
64+
/// @param t The scalar type to check.
65+
/// @return true if the scalar type is floating point, false otherwise.
66+
inline bool isFloatingType(ScalarType t) {
67+
return t == ScalarType::Float;
68+
}
69+
70+
/// Checks if the scalar type is an integral type (including bool).
71+
/// @param t The scalar type to check.
72+
/// @param includeBool Whether to consider Bool as integral.
73+
/// @return true if the scalar type is integral, false otherwise.
74+
inline bool isIntegralType(ScalarType t, bool /*includeBool*/) {
75+
(void)t;
76+
return false;
77+
}
78+
79+
inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) {
80+
return stream << toString(scalar_type);
81+
}
82+
83+
} // namespace executorch::backends::aoti::slim::c10
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
load("targets.bzl", "define_common_targets")
2+
3+
define_common_targets()
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
"""Define targets for SlimTensor c10 core module."""
5+
6+
# Header-only library for DeviceType
7+
runtime.cxx_library(
8+
name = "device_type",
9+
headers = [
10+
"DeviceType.h",
11+
],
12+
visibility = ["@EXECUTORCH_CLIENTS"],
13+
exported_deps = [
14+
"//executorch/runtime/platform:platform",
15+
],
16+
)
17+
18+
# Header-only library for Device
19+
runtime.cxx_library(
20+
name = "device",
21+
headers = [
22+
"Device.h",
23+
],
24+
visibility = ["@EXECUTORCH_CLIENTS"],
25+
exported_deps = [
26+
":device_type",
27+
"//executorch/runtime/platform:platform",
28+
],
29+
)
30+
31+
# Header-only library for ScalarType
32+
runtime.cxx_library(
33+
name = "scalar_type",
34+
headers = [
35+
"ScalarType.h",
36+
],
37+
visibility = ["@EXECUTORCH_CLIENTS"],
38+
exported_deps = [
39+
"//executorch/runtime/platform:platform",
40+
],
41+
)
42+
43+
# Combined c10 core library
44+
runtime.cxx_library(
45+
name = "core",
46+
visibility = ["@EXECUTORCH_CLIENTS"],
47+
exported_deps = [
48+
":device",
49+
":device_type",
50+
":scalar_type",
51+
],
52+
)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
load("targets.bzl", "define_common_targets")
2+
3+
define_common_targets()
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
"""Define test targets for SlimTensor c10 core module."""
5+
6+
runtime.cxx_test(
7+
name = "test_device",
8+
srcs = [
9+
"test_device.cpp",
10+
],
11+
deps = [
12+
"//executorch/backends/aoti/slim/c10/core:device",
13+
"//executorch/backends/aoti/slim/c10/core:device_type",
14+
],
15+
)
16+
17+
runtime.cxx_test(
18+
name = "test_scalar_type",
19+
srcs = [
20+
"test_scalar_type.cpp",
21+
],
22+
deps = [
23+
"//executorch/backends/aoti/slim/c10/core:scalar_type",
24+
],
25+
)

0 commit comments

Comments
 (0)