Skip to content

Commit 5942e4a

Browse files
authored
Allow EValue to be constructed with a smart pointer implicitly.
Differential Revision: D61783902 Pull Request resolved: #4902
1 parent da2142b commit 5942e4a

File tree

2 files changed

+130
-19
lines changed

2 files changed

+130
-19
lines changed

runtime/core/evalue.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,23 @@ struct EValue {
238238
new (&payload.as_tensor) exec_aten::Tensor(t);
239239
}
240240

241+
// Template constructor that allows construction from types that can be
242+
// dereferenced to produce a type that EValue can be implicitly constructed
243+
// from.
244+
template <typename T>
245+
/*implicit*/ EValue(
246+
T&& value,
247+
typename std::enable_if<std::is_convertible<
248+
decltype(*std::forward<T>(value)),
249+
EValue>::value>::type* = 0) {
250+
ET_CHECK_MSG(value != nullptr, "Pointer is null.");
251+
*this = EValue(*std::forward<T>(value));
252+
}
253+
254+
// Delete constructor for raw pointers to ensure they cannot be used.
255+
template <typename T>
256+
explicit EValue(T* value) = delete;
257+
241258
bool isTensor() const {
242259
return tag == Tag::Tensor;
243260
}

runtime/core/test/evalue_test.cpp

Lines changed: 113 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,67 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/runtime/core/evalue.h>
10+
911
#include <gtest/gtest.h>
1012

11-
#include <executorch/runtime/core/evalue.h>
12-
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1313
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14+
#include <executorch/runtime/platform/runtime.h>
1415
#include <executorch/test/utils/DeathTest.h>
1516

1617
using namespace ::testing;
18+
19+
namespace torch {
20+
namespace executor {
21+
1722
using exec_aten::ScalarType;
1823
using executorch::runtime::BoxedEvalueList;
1924
using executorch::runtime::EValue;
2025
using executorch::runtime::Tag;
2126
using executorch::runtime::testing::TensorFactory;
2227

23-
TEST(TestEValue, CopyTrivialType) {
28+
class EValueTest : public ::testing::Test {
29+
protected:
30+
void SetUp() override {
31+
// Since these tests cause ET_LOG to be called, the PAL must be initialized
32+
// first.
33+
runtime_init();
34+
}
35+
};
36+
37+
// An utility class used in tests to simulate objects that manage Tensors.
38+
// The overloaded operator*() is used to return the underlying Tensor, mimicking
39+
// behavior of smart pointers.
40+
class TensorWrapper {
41+
public:
42+
explicit TensorWrapper(exec_aten::Tensor tensor)
43+
: tensor_(std::make_unique<exec_aten::Tensor>(std::move(tensor))) {}
44+
45+
exec_aten::Tensor& operator*() {
46+
return *tensor_;
47+
}
48+
49+
const exec_aten::Tensor& operator*() const {
50+
return *tensor_;
51+
}
52+
53+
operator bool() const {
54+
return static_cast<bool>(tensor_);
55+
}
56+
57+
bool operator==(std::nullptr_t) const {
58+
return tensor_ == nullptr;
59+
}
60+
61+
bool operator!=(std::nullptr_t) const {
62+
return tensor_ != nullptr;
63+
}
64+
65+
private:
66+
std::unique_ptr<exec_aten::Tensor> tensor_;
67+
};
68+
69+
TEST_F(EValueTest, CopyTrivialType) {
2470
EValue a;
2571
EValue b(true);
2672
EXPECT_TRUE(a.isNone());
@@ -30,7 +76,7 @@ TEST(TestEValue, CopyTrivialType) {
3076
EXPECT_EQ(b.to<bool>(), true);
3177
}
3278

33-
TEST(TestEValue, CopyTensor) {
79+
TEST_F(EValueTest, CopyTensor) {
3480
TensorFactory<ScalarType::Float> tf;
3581
EValue a(tf.ones({3, 2}));
3682
EValue b(tf.ones({1}));
@@ -39,7 +85,7 @@ TEST(TestEValue, CopyTensor) {
3985
EXPECT_EQ(a.toTensor().dim(), 1);
4086
}
4187

42-
TEST(TestEValue, TypeMismatchFatals) {
88+
TEST_F(EValueTest, TypeMismatchFatals) {
4389
ET_EXPECT_DEATH(
4490
{
4591
auto e = EValue(true);
@@ -48,12 +94,12 @@ TEST(TestEValue, TypeMismatchFatals) {
4894
"");
4995
}
5096

51-
TEST(TestEValue, NoneByDefault) {
97+
TEST_F(EValueTest, NoneByDefault) {
5298
EValue e;
5399
EXPECT_TRUE(e.isNone());
54100
}
55101

56-
TEST(TestEValue, ToOptionalInt) {
102+
TEST_F(EValueTest, ToOptionalInt) {
57103
EValue e((int64_t)5);
58104
EXPECT_TRUE(e.isInt());
59105
EXPECT_FALSE(e.isNone());
@@ -63,15 +109,15 @@ TEST(TestEValue, ToOptionalInt) {
63109
EXPECT_EQ(o.value(), 5);
64110
}
65111

66-
TEST(TestEValue, NoneToOptionalInt) {
112+
TEST_F(EValueTest, NoneToOptionalInt) {
67113
EValue e;
68114
EXPECT_TRUE(e.isNone());
69115

70116
exec_aten::optional<int64_t> o = e.toOptional<int64_t>();
71117
EXPECT_FALSE(o.has_value());
72118
}
73119

74-
TEST(TestEValue, ToOptionalScalar) {
120+
TEST_F(EValueTest, ToOptionalScalar) {
75121
exec_aten::Scalar s((double)3.141);
76122
EValue e(s);
77123
EXPECT_TRUE(e.isScalar());
@@ -83,7 +129,7 @@ TEST(TestEValue, ToOptionalScalar) {
83129
EXPECT_EQ(o.value().to<double>(), 3.141);
84130
}
85131

86-
TEST(TESTEValue, ScalarToType) {
132+
TEST_F(EValueTest, ScalarToType) {
87133
exec_aten::Scalar s_d((double)3.141);
88134
EXPECT_EQ(s_d.to<double>(), 3.141);
89135
exec_aten::Scalar s_i((int64_t)3);
@@ -92,23 +138,23 @@ TEST(TESTEValue, ScalarToType) {
92138
EXPECT_EQ(s_b.to<bool>(), true);
93139
}
94140

95-
TEST(TestEValue, NoneToOptionalScalar) {
141+
TEST_F(EValueTest, NoneToOptionalScalar) {
96142
EValue e;
97143
EXPECT_TRUE(e.isNone());
98144

99145
exec_aten::optional<exec_aten::Scalar> o = e.toOptional<exec_aten::Scalar>();
100146
EXPECT_FALSE(o.has_value());
101147
}
102148

103-
TEST(TestEValue, NoneToOptionalTensor) {
149+
TEST_F(EValueTest, NoneToOptionalTensor) {
104150
EValue e;
105151
EXPECT_TRUE(e.isNone());
106152

107153
exec_aten::optional<exec_aten::Tensor> o = e.toOptional<exec_aten::Tensor>();
108154
EXPECT_FALSE(o.has_value());
109155
}
110156

111-
TEST(TestEValue, ToScalarType) {
157+
TEST_F(EValueTest, ToScalarType) {
112158
EValue e((int64_t)4);
113159
auto o = e.toScalarType();
114160
EXPECT_EQ(o, exec_aten::ScalarType::Long);
@@ -118,7 +164,7 @@ TEST(TestEValue, ToScalarType) {
118164
EXPECT_EQ(o2.value(), exec_aten::ScalarType::Long);
119165
}
120166

121-
TEST(TestEValue, toString) {
167+
TEST_F(EValueTest, toString) {
122168
const EValue e("foo", 3);
123169
EXPECT_TRUE(e.isString());
124170
EXPECT_FALSE(e.isNone());
@@ -127,28 +173,28 @@ TEST(TestEValue, toString) {
127173
EXPECT_EQ(x, "foo");
128174
}
129175

130-
TEST(TestEValue, MemoryFormat) {
176+
TEST_F(EValueTest, MemoryFormat) {
131177
const EValue e((int64_t)0);
132178
EXPECT_TRUE(e.isInt());
133179
const exec_aten::MemoryFormat m = e.to<exec_aten::MemoryFormat>();
134180
EXPECT_EQ(m, exec_aten::MemoryFormat::Contiguous);
135181
}
136182

137-
TEST(TestEValue, Layout) {
183+
TEST_F(EValueTest, Layout) {
138184
const EValue e((int64_t)0);
139185
EXPECT_TRUE(e.isInt());
140186
const exec_aten::Layout l = e.to<exec_aten::Layout>();
141187
EXPECT_EQ(l, exec_aten::Layout::Strided);
142188
}
143189

144-
TEST(TestEValue, Device) {
190+
TEST_F(EValueTest, Device) {
145191
const EValue e((int64_t)0);
146192
EXPECT_TRUE(e.isInt());
147193
const exec_aten::Device d = e.to<exec_aten::Device>();
148194
EXPECT_TRUE(d.is_cpu());
149195
}
150196

151-
TEST(TestEValue, BoxedEvalueList) {
197+
TEST_F(EValueTest, BoxedEvalueList) {
152198
// create fake values table to point to
153199
EValue values[3] = {
154200
EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)};
@@ -164,7 +210,7 @@ TEST(TestEValue, BoxedEvalueList) {
164210
EXPECT_EQ(unwrapped[2], 3);
165211
}
166212

167-
TEST(TestEValue, toOptionalTensorList) {
213+
TEST_F(EValueTest, toOptionalTensorList) {
168214
// create list, empty evalue ctor gets tag::None
169215
EValue values[2] = {EValue(), EValue()};
170216
EValue* values_p[2] = {&values[0], &values[1]};
@@ -185,3 +231,51 @@ TEST(TestEValue, toOptionalTensorList) {
185231
EXPECT_FALSE(x[0].has_value());
186232
EXPECT_FALSE(x[1].has_value());
187233
}
234+
235+
TEST_F(EValueTest, ConstructFromUniquePtr) {
236+
TensorFactory<ScalarType::Float> tf;
237+
auto tensor_ptr = std::make_unique<exec_aten::Tensor>(tf.ones({2, 3}));
238+
239+
EValue evalue(std::move(tensor_ptr));
240+
241+
EXPECT_TRUE(evalue.isTensor());
242+
EXPECT_EQ(evalue.toTensor().dim(), 2);
243+
EXPECT_EQ(evalue.toTensor().numel(), 6);
244+
245+
EValue evalue2(std::make_unique<exec_aten::Tensor>(tf.ones({4, 5})));
246+
247+
EXPECT_TRUE(evalue2.isTensor());
248+
EXPECT_EQ(evalue2.toTensor().dim(), 2);
249+
EXPECT_EQ(evalue2.toTensor().numel(), 20);
250+
}
251+
252+
TEST_F(EValueTest, ConstructFromSharedPtr) {
253+
TensorFactory<ScalarType::Float> tf;
254+
auto tensor_ptr = std::make_shared<exec_aten::Tensor>(tf.ones({4, 5}));
255+
256+
EValue evalue(tensor_ptr);
257+
258+
EXPECT_TRUE(evalue.isTensor());
259+
EXPECT_EQ(evalue.toTensor().dim(), 2);
260+
EXPECT_EQ(evalue.toTensor().numel(), 20);
261+
}
262+
263+
TEST_F(EValueTest, ConstructFromTensorWrapper) {
264+
TensorFactory<ScalarType::Float> tf;
265+
TensorWrapper tensor_wrapper(tf.ones({4, 5}));
266+
267+
EValue evalue(tensor_wrapper);
268+
269+
EXPECT_TRUE(evalue.isTensor());
270+
EXPECT_EQ(evalue.toTensor().dim(), 2);
271+
EXPECT_EQ(evalue.toTensor().numel(), 20);
272+
}
273+
274+
TEST_F(EValueTest, ConstructFromNullPtrAborts) {
275+
std::unique_ptr<exec_aten::Tensor> null_ptr;
276+
277+
ET_EXPECT_DEATH({ EValue evalue(null_ptr); }, "");
278+
}
279+
280+
} // namespace executor
281+
} // namespace torch

0 commit comments

Comments
 (0)