Skip to content

Commit bf72446

Browse files
committed
Update on "Replace pytree_assert with production pytree_check. Remove pytree_unreachable"
When handling untrusted input, it's not appropriate to use debug-only checks; we should be checking in prod as these are not programmer errors. pytree_unreachable was similarly being used for input validation. Differential Revision: [D68166301](https://our.internmc.facebook.com/intern/diff/D68166301/) [ghstack-poisoned]
2 parents 5f5692f + 1f56465 commit bf72446

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

extension/pytree/pybindings.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ class PyTree {
145145
} else if (py::isinstance<py::int_>(key)) {
146146
s.key(i) = py::cast<int32_t>(key);
147147
} else {
148-
throw std::runtime_err("invalid key in pytree dict; must be int or string");
148+
throw std::runtime_err(
149+
"invalid key in pytree dict; must be int or string");
149150
}
150151

151152
flatten_internal(dict[key], leaves, s[i]);
@@ -224,7 +225,8 @@ class PyTree {
224225
case Key::Kind::Str:
225226
return py::cast(key.as_str()).release();
226227
default:
227-
throw std::runtime_error("invalid key in pytree dict; must be int or string");
228+
throw std::runtime_error(
229+
"invalid key in pytree dict; must be int or string");
228230
}
229231
}();
230232
dict[py_key] = unflatten_internal(spec[i], leaves_it);

extension/pytree/test/test_pytree.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,28 @@
1111
#include <gtest/gtest.h>
1212
#include <string>
1313

14+
using ::executorch::extension::pytree::arr;
1415
using ::executorch::extension::pytree::ContainerHandle;
1516
using ::executorch::extension::pytree::Key;
1617
using ::executorch::extension::pytree::Kind;
1718
using ::executorch::extension::pytree::unflatten;
1819

1920
using Leaf = int32_t;
2021

22+
TEST(PyTreeTest, ArrBasic) {
23+
arr<int> x(5);
24+
ASSERT_EQ(x.size(), 5);
25+
EXPECT_THROW(x.at(5), std::out_of_range);
26+
for (int ii = 0; ii < x.size(); ++ii) {
27+
x[ii] = 2 * ii;
28+
}
29+
int idx = 0;
30+
for (const auto item : x) {
31+
EXPECT_EQ(item, 2 * idx);
32+
++idx;
33+
}
34+
}
35+
2136
TEST(PyTreeTest, List) {
2237
Leaf items[2] = {11, 12};
2338
std::string spec = "L2#1#1($,$)";

0 commit comments

Comments
 (0)