Skip to content

Commit 5c64044

Browse files
lucylqfacebook-github-bot
authored andcommitted
TensorLayout updates (#7870)
Summary: Addressing comments from D67048723 Differential Revision: D68535453
1 parent 73dce90 commit 5c64044

File tree

4 files changed

+71
-28
lines changed

4 files changed

+71
-28
lines changed

runtime/core/targets.bzl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def define_common_targets():
4343
"freeable_buffer.h",
4444
"result.h",
4545
"span.h",
46-
"tensor_layout.h",
4746
],
4847
visibility = [
4948
"//executorch/...",
@@ -132,3 +131,14 @@ def define_common_targets():
132131
"//executorch/...",
133132
],
134133
)
134+
135+
runtime.cxx_library(
136+
name = "tensor_layout",
137+
srcs = ["tensor_layout.cpp"],
138+
exported_headers = ["tensor_layout.h"],
139+
exported_deps = [
140+
":core",
141+
"//executorch/runtime/core/exec_aten:lib",
142+
],
143+
visibility = ["//executorch/..."],
144+
)

runtime/core/tensor_layout.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
10+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
11+
#include <executorch/runtime/core/span.h>
12+
#include <executorch/runtime/core/tensor_layout.h>
13+
14+
namespace executorch {
15+
namespace runtime {
16+
17+
namespace {
18+
size_t calculate_nbytes(
19+
const Span<const int32_t>& sizes,
20+
const exec_aten::ScalarType& scalar_type) {
21+
ssize_t n = 1;
22+
for (ssize_t i = 0; i < sizes.size(); i++) {
23+
ET_CHECK(sizes[i] >= 0);
24+
n *= sizes[i];
25+
}
26+
// Use the full namespace to disambiguate from c10::elementSize.
27+
return n * executorch::runtime::elementSize(scalar_type);
28+
}
29+
} // namespace
30+
31+
/**
32+
* NOTE: the spans (sizes, dim_order) passed here must outlive the
33+
* TensorLayout and all copies of it.
34+
*/
35+
TensorLayout::TensorLayout(
36+
executorch::aten::ScalarType scalar_type,
37+
Span<const int32_t> sizes,
38+
Span<const uint8_t> dim_order)
39+
: sizes_(sizes),
40+
dim_order_(dim_order),
41+
scalar_type_(scalar_type),
42+
nbytes_(calculate_nbytes(sizes_, scalar_type_)) {}
43+
44+
} // namespace runtime
45+
} // namespace executorch

runtime/core/tensor_layout.h

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,38 +15,19 @@
1515
namespace executorch {
1616
namespace runtime {
1717

18-
namespace {
19-
size_t calculate_nbytes(
20-
const Span<const int32_t>& sizes,
21-
const exec_aten::ScalarType& scalar_type) {
22-
ssize_t n = 1;
23-
for (ssize_t i = 0; i < sizes.size(); i++) {
24-
ET_CHECK(sizes[i] >= 0);
25-
n *= sizes[i];
26-
}
27-
// Use the full namespace to disambiguate from c10::elementSize.
28-
return n * executorch::runtime::elementSize(scalar_type);
29-
}
30-
} // namespace
31-
3218
/**
3319
* Metadata describing the layout of external tensors (tensors that are not
3420
stored in the PTE file).
3521
*
36-
* The NamedDataMap used to create the TensorLayout must outlive the
37-
TensorLayout.
22+
* NOTE: the underlying data structure/s backing the TensorLayout must outlive
23+
* the TensorLayout and all copies of it.
3824
*/
39-
class TensorLayout {
25+
class ET_EXPERIMENTAL TensorLayout final {
4026
public:
4127
TensorLayout(
4228
executorch::aten::ScalarType scalar_type,
4329
Span<const int32_t> sizes,
44-
Span<const uint8_t> dim_order)
45-
: sizes_(sizes),
46-
dim_order_(dim_order),
47-
scalar_type_(scalar_type),
48-
nbytes_(calculate_nbytes(sizes_, scalar_type_)) {}
49-
30+
Span<const uint8_t> dim_order);
5031
TensorLayout(const TensorLayout&) = default;
5132
TensorLayout(TensorLayout&&) = default;
5233
TensorLayout& operator=(const TensorLayout&) = default;
@@ -74,10 +55,18 @@ class TensorLayout {
7455
}
7556

7657
private:
77-
/// The sizes of the tensor.
58+
/**
59+
* The sizes of the tensor.
60+
*
61+
* NOTE: The TensorLayout must outlive the spans returned here.
62+
*/
7863
Span<const int32_t> sizes_;
7964

80-
/// The dim order of the tensor.
65+
/**
66+
* The dim order of the tensor.
67+
*
68+
* NOTE: The TensorLayout must outlive the spans returned here.
69+
*/
8170
Span<const uint8_t> dim_order_;
8271

8372
/// The scalar type of the tensor.

runtime/core/test/targets.bzl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ def define_common_targets():
1919
name = "tensor_layout_test",
2020
srcs = ["tensor_layout_test.cpp"],
2121
deps = [
22-
"//executorch/runtime/core:core",
23-
"//executorch/runtime/core/exec_aten:lib",
22+
"//executorch/runtime/core:tensor_layout",
2423
],
2524
)
2625

0 commit comments

Comments
 (0)