Skip to content

Commit ecb38df

Browse files
authored
[slimtensor] Introduce Device and ScalarType headers for SlimTensor minimal support
Differential Revision: D89747061 Pull Request resolved: #16382
1 parent b56eaa5 commit ecb38df

File tree

9 files changed

+549
-0
lines changed

9 files changed

+549
-0
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
#include <functional>
13+
#include <string>
14+
15+
#include <executorch/backends/aoti/slim/c10/core/DeviceType.h>
16+
#include <executorch/runtime/platform/assert.h>
17+
18+
namespace executorch::backends::aoti::slim::c10 {
19+
20+
/// An index representing a specific device; e.g., the 1 in GPU 1.
21+
/// A DeviceIndex is not independently meaningful without knowing
22+
/// the DeviceType it is associated; try to use Device rather than
23+
/// DeviceIndex directly.
24+
using DeviceIndex = int8_t;
25+
26+
/// Represents a compute device on which a tensor is located.
27+
/// A device is uniquely identified by a type (e.g., CPU) and a device index.
28+
struct Device final {
29+
/// Constructs a new Device from a DeviceType and an optional device index.
30+
/// @param type The type of device.
31+
/// @param index The device index. For CPU, this should be -1 or 0.
32+
/* implicit */
33+
explicit Device(DeviceType type, DeviceIndex index = -1)
34+
: type_(type), index_(index) {
35+
validate();
36+
}
37+
38+
/// Constructs a Device from a string description.
39+
/// The string must be "cpu" or "cpu:0".
40+
/* implicit */ Device(const std::string& device_string)
41+
: Device(DeviceType::CPU) {
42+
ET_CHECK_MSG(!device_string.empty(), "Device string must not be empty");
43+
44+
if (device_string == "cpu" || device_string == "CPU") {
45+
type_ = DeviceType::CPU;
46+
index_ = -1;
47+
} else if (device_string == "cpu:0" || device_string == "CPU:0") {
48+
type_ = DeviceType::CPU;
49+
index_ = static_cast<DeviceIndex>(device_string.back() - '0');
50+
} else {
51+
ET_CHECK_MSG(
52+
false,
53+
"Invalid device string: %s. Currently only 'cpu' is supported.",
54+
device_string.c_str());
55+
}
56+
validate();
57+
}
58+
59+
/// Returns true if the type and index of this Device matches that of other.
60+
bool operator==(const Device& other) const noexcept {
61+
return this->type_ == other.type_ && this->index_ == other.index_;
62+
}
63+
64+
/// Returns true if the type or index of this Device differs from that of
65+
/// other.
66+
bool operator!=(const Device& other) const noexcept {
67+
return !(*this == other);
68+
}
69+
70+
/// Sets the device index.
71+
void set_index(DeviceIndex index) {
72+
index_ = index;
73+
}
74+
75+
/// Returns the type of device this is.
76+
DeviceType type() const noexcept {
77+
return type_;
78+
}
79+
80+
/// Returns the device index.
81+
DeviceIndex index() const noexcept {
82+
return index_;
83+
}
84+
85+
/// Returns true if the device has a non-default index.
86+
bool has_index() const noexcept {
87+
return index_ != -1;
88+
}
89+
90+
/// Returns true if the device is of CPU type.
91+
bool is_cpu() const noexcept {
92+
return type_ == DeviceType::CPU;
93+
}
94+
95+
/// Returns a string representation of the device (e.g., "cpu" or "cpu:0").
96+
std::string str() const {
97+
std::string str = DeviceTypeName(type(), /* lower_case */ true);
98+
if (has_index()) {
99+
str.push_back(':');
100+
str.append(std::to_string(index()));
101+
}
102+
return str;
103+
}
104+
105+
private:
106+
DeviceType type_;
107+
DeviceIndex index_ = -1;
108+
109+
void validate() {
110+
ET_DCHECK_MSG(
111+
index_ >= -1,
112+
"Device index must be -1 or non-negative, got %d",
113+
static_cast<int>(index_));
114+
ET_DCHECK_MSG(
115+
!is_cpu() || index_ <= 0,
116+
"CPU device index must be -1 or zero, got %d",
117+
static_cast<int>(index_));
118+
}
119+
};
120+
121+
inline std::ostream& operator<<(std::ostream& stream, const Device& device) {
122+
stream << device.str();
123+
return stream;
124+
}
125+
126+
} // namespace executorch::backends::aoti::slim::c10
127+
128+
namespace std {
129+
template <>
130+
struct hash<executorch::backends::aoti::slim::c10::Device> {
131+
size_t operator()(
132+
executorch::backends::aoti::slim::c10::Device d) const noexcept {
133+
static_assert(
134+
sizeof(executorch::backends::aoti::slim::c10::DeviceType) == 1,
135+
"DeviceType is not 8-bit");
136+
static_assert(
137+
sizeof(executorch::backends::aoti::slim::c10::DeviceIndex) == 1,
138+
"DeviceIndex is not 8-bit");
139+
uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
140+
<< 16 |
141+
static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
142+
return std::hash<uint32_t>{}(bits);
143+
}
144+
};
145+
} // namespace std
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
#include <ostream>
13+
#include <string>
14+
15+
#include <executorch/runtime/platform/assert.h>
16+
17+
namespace executorch::backends::aoti::slim::c10 {
18+
19+
/// Enum representing the type of device.
20+
enum class DeviceType : int8_t {
21+
CPU = 0,
22+
COMPILE_TIME_MAX_DEVICE_TYPES = 1,
23+
};
24+
25+
constexpr DeviceType kCPU = DeviceType::CPU;
26+
27+
/// Maximum number of device types at compile time.
28+
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
29+
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
30+
31+
/// Returns the name of the device type as a string.
32+
/// @param d The device type.
33+
/// @param lower_case If true, returns the name in lower case.
34+
/// @return The name of the device type.
35+
inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) {
36+
switch (d) {
37+
case DeviceType::CPU:
38+
return lower_case ? "cpu" : "CPU";
39+
default:
40+
ET_CHECK_MSG(false, "Unknown device type: %d", static_cast<int>(d));
41+
}
42+
}
43+
44+
/// Checks if the device type is valid.
45+
/// @param d The device type to check.
46+
/// @return true if the device type is valid, false otherwise.
47+
inline bool isValidDeviceType(DeviceType d) {
48+
return d == DeviceType::CPU;
49+
}
50+
51+
inline std::ostream& operator<<(std::ostream& stream, DeviceType type) {
52+
stream << DeviceTypeName(type, /* lower_case */ true);
53+
return stream;
54+
}
55+
56+
} // namespace executorch::backends::aoti::slim::c10
57+
58+
namespace std {
59+
template <>
60+
struct hash<executorch::backends::aoti::slim::c10::DeviceType> {
61+
std::size_t operator()(
62+
executorch::backends::aoti::slim::c10::DeviceType k) const {
63+
return std::hash<int>()(static_cast<int>(k));
64+
}
65+
};
66+
} // namespace std
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstddef>
12+
#include <cstdint>
13+
#include <ostream>
14+
15+
#include <executorch/runtime/platform/assert.h>
16+
17+
namespace executorch::backends::aoti::slim::c10 {
18+
19+
/// Enum representing the scalar type (dtype) of tensor elements.
20+
/// Note: Enum values must match PyTorch's c10::ScalarType for compatibility.
21+
enum class ScalarType : int8_t {
22+
// Byte = 0,
23+
// Char = 1,
24+
// Short = 2,
25+
// Int = 3,
26+
// Long = 4,
27+
Float = 6,
28+
// Bool = 11,
29+
// BFloat16 = 15,
30+
Undefined = -1,
31+
NumOptions = 7,
32+
};
33+
34+
/// Constant for Float scalar type.
35+
constexpr ScalarType kFloat = ScalarType::Float;
36+
37+
/// Returns the size in bytes of a single element of the given scalar type.
38+
/// @param t The scalar type.
39+
/// @return The size in bytes of a single element.
40+
inline size_t elementSize(ScalarType t) {
41+
switch (t) {
42+
case ScalarType::Float:
43+
return sizeof(float);
44+
default:
45+
ET_CHECK_MSG(false, "Unknown ScalarType: %d", static_cast<int>(t));
46+
}
47+
}
48+
49+
/// Returns the name of the scalar type as a string.
50+
/// @param t The scalar type.
51+
/// @return The name of the scalar type.
52+
inline const char* toString(ScalarType t) {
53+
switch (t) {
54+
case ScalarType::Float:
55+
return "Float";
56+
case ScalarType::Undefined:
57+
return "Undefined";
58+
default:
59+
return "UNKNOWN_SCALAR";
60+
}
61+
}
62+
63+
/// Checks if the scalar type is a floating point type.
64+
/// @param t The scalar type to check.
65+
/// @return true if the scalar type is floating point, false otherwise.
66+
inline bool isFloatingType(ScalarType t) {
67+
return t == ScalarType::Float;
68+
}
69+
70+
/// Checks if the scalar type is an integral type (including bool).
71+
/// @param t The scalar type to check.
72+
/// @param includeBool Whether to consider Bool as integral.
73+
/// @return true if the scalar type is integral, false otherwise.
74+
inline bool isIntegralType(ScalarType t, bool /*includeBool*/) {
75+
(void)t;
76+
return false;
77+
}
78+
79+
inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) {
80+
return stream << toString(scalar_type);
81+
}
82+
83+
} // namespace executorch::backends::aoti::slim::c10
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
load("targets.bzl", "define_common_targets")
2+
3+
define_common_targets()
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
"""Define targets for SlimTensor c10 core module."""
5+
6+
# Header-only library for DeviceType
7+
runtime.cxx_library(
8+
name = "device_type",
9+
headers = [
10+
"DeviceType.h",
11+
],
12+
visibility = ["@EXECUTORCH_CLIENTS"],
13+
exported_deps = [
14+
"//executorch/runtime/platform:platform",
15+
],
16+
)
17+
18+
# Header-only library for Device
19+
runtime.cxx_library(
20+
name = "device",
21+
headers = [
22+
"Device.h",
23+
],
24+
visibility = ["@EXECUTORCH_CLIENTS"],
25+
exported_deps = [
26+
":device_type",
27+
"//executorch/runtime/platform:platform",
28+
],
29+
)
30+
31+
# Header-only library for ScalarType
32+
runtime.cxx_library(
33+
name = "scalar_type",
34+
headers = [
35+
"ScalarType.h",
36+
],
37+
visibility = ["@EXECUTORCH_CLIENTS"],
38+
exported_deps = [
39+
"//executorch/runtime/platform:platform",
40+
],
41+
)
42+
43+
# Combined c10 core library
44+
runtime.cxx_library(
45+
name = "core",
46+
visibility = ["@EXECUTORCH_CLIENTS"],
47+
exported_deps = [
48+
":device",
49+
":device_type",
50+
":scalar_type",
51+
],
52+
)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
load("targets.bzl", "define_common_targets")
2+
3+
define_common_targets()
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
"""Define test targets for SlimTensor c10 core module."""
5+
6+
runtime.cxx_test(
7+
name = "test_device",
8+
srcs = [
9+
"test_device.cpp",
10+
],
11+
deps = [
12+
"//executorch/backends/aoti/slim/c10/core:device",
13+
"//executorch/backends/aoti/slim/c10/core:device_type",
14+
],
15+
)
16+
17+
runtime.cxx_test(
18+
name = "test_scalar_type",
19+
srcs = [
20+
"test_scalar_type.cpp",
21+
],
22+
deps = [
23+
"//executorch/backends/aoti/slim/c10/core:scalar_type",
24+
],
25+
)

0 commit comments

Comments
 (0)