Skip to content

Commit 11fd4f2

Browse files
lucylqfacebook-github-bot
authored andcommitted
TensorLayout updates (pytorch#7870)
Summary: Addressing comments from D67048723 Differential Revision: D68535453
1 parent 5cbfcdc commit 11fd4f2

File tree

5 files changed

+93
-35
lines changed

5 files changed

+93
-35
lines changed

runtime/core/targets.bzl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def define_common_targets():
4343
"freeable_buffer.h",
4444
"result.h",
4545
"span.h",
46-
"tensor_layout.h",
4746
],
4847
visibility = [
4948
"//executorch/...",
@@ -132,3 +131,14 @@ def define_common_targets():
132131
"//executorch/...",
133132
],
134133
)
134+
135+
runtime.cxx_library(
136+
name = "tensor_layout",
137+
srcs = ["tensor_layout.cpp"],
138+
exported_headers = ["tensor_layout.h"],
139+
exported_deps = [
140+
":core",
141+
"//executorch/runtime/core/exec_aten:lib",
142+
],
143+
visibility = ["//executorch/..."],
144+
)

runtime/core/tensor_layout.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
10+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
11+
#include <executorch/runtime/core/span.h>
12+
#include <executorch/runtime/core/tensor_layout.h>
13+
14+
namespace executorch {
15+
namespace runtime {
16+
17+
namespace {
18+
Result<size_t> calculate_nbytes(
19+
const Span<const int32_t>& sizes,
20+
const exec_aten::ScalarType& scalar_type) {
21+
ssize_t n = 1;
22+
for (ssize_t i = 0; i < sizes.size(); i++) {
23+
if (sizes[i] < 0) {
24+
return Error::InvalidArgument;
25+
}
26+
n *= sizes[i];
27+
}
28+
// Use the full namespace to disambiguate from c10::elementSize.
29+
return n * executorch::runtime::elementSize(scalar_type);
30+
}
31+
} // namespace
32+
33+
Result<TensorLayout> TensorLayout::create(
34+
Span<const int32_t> sizes,
35+
Span<const uint8_t> dim_order,
36+
executorch::aten::ScalarType scalar_type) {
37+
auto nbytes = calculate_nbytes(sizes, scalar_type);
38+
if (!nbytes.ok()) {
39+
return nbytes.error();
40+
}
41+
return TensorLayout(sizes, dim_order, scalar_type, nbytes.get());
42+
}
43+
} // namespace runtime
44+
} // namespace executorch

runtime/core/tensor_layout.h

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,55 +10,55 @@
1010

1111
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1212
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
13+
#include <executorch/runtime/core/result.h>
1314
#include <executorch/runtime/core/span.h>
1415

1516
namespace executorch {
1617
namespace runtime {
1718

18-
namespace {
19-
size_t calculate_nbytes(
20-
const Span<const int32_t>& sizes,
21-
const exec_aten::ScalarType& scalar_type) {
22-
ssize_t n = 1;
23-
for (ssize_t i = 0; i < sizes.size(); i++) {
24-
ET_CHECK(sizes[i] >= 0);
25-
n *= sizes[i];
26-
}
27-
// Use the full namespace to disambiguate from c10::elementSize.
28-
return n * executorch::runtime::elementSize(scalar_type);
29-
}
30-
} // namespace
31-
3219
/**
33-
* Metadata describing the layout of external tensors (tensors that are not
34-
stored in the PTE file).
35-
*
36-
* The NamedDataMap used to create the TensorLayout must outlive the
37-
TensorLayout.
20+
* Describes the layout of a tensor.
3821
*/
39-
class TensorLayout {
22+
class ET_EXPERIMENTAL TensorLayout final {
4023
public:
4124
TensorLayout(
42-
executorch::aten::ScalarType scalar_type,
4325
Span<const int32_t> sizes,
44-
Span<const uint8_t> dim_order)
26+
Span<const uint8_t> dim_order,
27+
executorch::aten::ScalarType scalar_type,
28+
size_t nbytes)
4529
: sizes_(sizes),
4630
dim_order_(dim_order),
4731
scalar_type_(scalar_type),
48-
nbytes_(calculate_nbytes(sizes_, scalar_type_)) {}
32+
nbytes_(nbytes) {}
4933

50-
TensorLayout(const TensorLayout&) = default;
51-
TensorLayout(TensorLayout&&) = default;
52-
TensorLayout& operator=(const TensorLayout&) = default;
53-
TensorLayout& operator=(TensorLayout&& other) = default;
54-
~TensorLayout() = default;
34+
/** Creates a TensorLayout from the given parameters.
35+
*
36+
* @param[in] sizes The sizes of the tensor. Note: the span passed here must
37+
* outlive the TensorLayout and all copies of it.
38+
* @param[in] dim_order The dim order of the tensor. Note: the span passed
39+
* here must outlive the TensorLayout and all copies of it.
40+
* @param[in] scalar_type The scalar type of the tensor.
41+
* @return A Result containing the TensorLayout on success, or an error.
42+
*/
43+
static executorch::runtime::Result<TensorLayout> create(
44+
Span<const int32_t> sizes,
45+
Span<const uint8_t> dim_order,
46+
executorch::aten::ScalarType scalar_type);
5547

56-
/// Returns the sizes of the tensor.
48+
/**
49+
* Returns the sizes of the tensor.
50+
*
51+
* NOTE: The TensorLayout must outlive the spans returned here.
52+
*/
5753
Span<const int32_t> sizes() const {
5854
return sizes_;
5955
}
6056

61-
/// Returns the dim order of the tensor.
57+
/**
58+
* Returns the dim order of the tensor.
59+
*
60+
* NOTE: The TensorLayout must outlive the spans returned here.
61+
*/
6262
Span<const uint8_t> dim_order() const {
6363
return dim_order_;
6464
}

runtime/core/test/targets.bzl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ def define_common_targets():
1919
name = "tensor_layout_test",
2020
srcs = ["tensor_layout_test.cpp"],
2121
deps = [
22-
"//executorch/runtime/core:core",
23-
"//executorch/runtime/core/exec_aten:lib",
22+
"//executorch/runtime/core:tensor_layout",
2423
],
2524
)
2625

runtime/core/test/tensor_layout_test.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
*/
88

99
#include <executorch/runtime/core/exec_aten/exec_aten.h>
10+
#include <executorch/runtime/core/result.h>
1011
#include <executorch/runtime/core/tensor_layout.h>
1112

1213
#include <gtest/gtest.h>
1314

1415
using namespace ::testing;
1516
using executorch::aten::ScalarType;
17+
using executorch::runtime::Result;
1618
using executorch::runtime::Span;
1719
using executorch::runtime::TensorLayout;
1820

@@ -23,9 +25,12 @@ TEST(TestTensorLayout, Ctor) {
2325
Span<const int32_t> sizes_span = {sizes, sizes + 2};
2426
Span<const uint8_t> dim_order_span = {dim_order, dim_order + 2};
2527

26-
TensorLayout layout =
27-
TensorLayout(ScalarType::Float, sizes_span, dim_order_span);
28+
Result<TensorLayout> layout_res =
29+
TensorLayout::create(sizes_span, dim_order_span, ScalarType::Float);
2830

31+
EXPECT_TRUE(layout_res.ok());
32+
33+
TensorLayout layout = layout_res.get();
2934
EXPECT_EQ(layout.scalar_type(), ScalarType::Float);
3035

3136
EXPECT_EQ(layout.sizes().size(), sizes_span.size());

0 commit comments

Comments
 (0)