Skip to content

Commit 378f0db

Browse files
Grain Teamcopybara-github
authored andcommitted
# [grain] Add checks for parameters passed to grain.experimental.index_shuffle that will throw an exception that can by caught by the Python interpreter
Currently when an invalid `rounds` parameter is passed, an assert fails which crashes the interpreter. If index > max_index, an infinite loop may result. PiperOrigin-RevId: 838873254
1 parent 31b9760 commit 378f0db

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

grain/_src/python/experimental/index_shuffle/index_shuffle.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,18 @@ limitations under the License.
2222
// https://eprint.iacr.org/2013/404
2323
// and following recommendations in
2424
// https://nsacyber.github.io/simon-speck/implementations/ImplementationGuide1.1.pdf.
25-
// However we use a single fixed key size and support arbtitary block sizes.
26-
// Further we fixed the number of rounds in the Feistel structuro to be always
25+
// However we use a single fixed key size and support arbitrary block sizes.
26+
// Further we fixed the number of rounds in the Feistel structure to be always
2727
// 4. This reduces the computational cost and still gives good shuffle behavior.
2828
//
29-
// Warning: Given the modifications descripted above this implementation should
30-
// not be used for application that require cryptograhic secure RNGs.
29+
// Warning: Given the modifications description above this implementation should
30+
// not be used for application that require cryptographic secure RNGs.
3131

3232
#include "grain/_src/python/experimental/index_shuffle/index_shuffle.h"
3333

3434
#include <assert.h>
3535

3636
#include <algorithm>
37-
#include <array>
3837
#include <bitset>
3938
#include <cmath>
4039
#include <cstdint>
@@ -120,6 +119,8 @@ uint64_t index_shuffle(const uint64_t index, const uint64_t max_index,
120119
assert(block_size > 0 && block_size % 2 == 0 && block_size <= 64);
121120
// At least 4 rounds and number of rounds must be even.
122121
assert(rounds >= 4 && rounds % 2 == 0);
122+
// Assert the index is bounded by [0, max_index].
123+
assert(index >= 0 && index <= max_index);
123124
#define HANDLE_BLOCK_SIZE(B) \
124125
case B: \
125126
return impl::index_shuffle<B>(index, max_index, seed, rounds);
Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include <pybind11/pybind11.h>
22

3+
#include <cstdint>
4+
35
#include "grain/_src/python/experimental/index_shuffle/index_shuffle.h"
46

57
namespace py = pybind11;
@@ -9,7 +11,17 @@ PYBIND11_MODULE(index_shuffle_module, m) {
911
"Returns the position of `index` in a permutation of [0, ..., "
1012
"max_index].";
1113
m.doc() = kDoc;
12-
m.def("index_shuffle", &::grain::random::index_shuffle, kDoc,
13-
py::arg("index"), py::arg("max_index"), py::arg("seed"),
14-
py::arg("rounds"));
14+
m.def(
15+
"index_shuffle",
16+
[](uint64_t index, uint64_t max_index, uint32_t seed, uint32_t rounds) {
17+
if (rounds < 4 || rounds % 2 != 0) {
18+
throw py::value_error("rounds must be an even integer >= 4");
19+
}
20+
if (index < 0 || index > max_index) {
21+
throw py::value_error("index must be in [0, max_index]");
22+
}
23+
return grain::random::index_shuffle(index, max_index, seed, rounds);
24+
},
25+
kDoc, py::arg("index"), py::arg("max_index"), py::arg("seed"),
26+
py::arg("rounds"));
1527
}

0 commit comments

Comments
 (0)