Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions grain/_src/python/experimental/index_shuffle/index_shuffle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,18 @@ limitations under the License.
// https://eprint.iacr.org/2013/404
// and following recommendations in
// https://nsacyber.github.io/simon-speck/implementations/ImplementationGuide1.1.pdf.
// However we use a single fixed key size and support arbtitary block sizes.
// Further we fixed the number of rounds in the Feistel structuro to be always
// However we use a single fixed key size and support arbitrary block sizes.
// Further we fixed the number of rounds in the Feistel structure to be always
// 4. This reduces the computational cost and still gives good shuffle behavior.
//
// Warning: Given the modifications descripted above this implementation should
// not be used for application that require cryptograhic secure RNGs.
// Warning: Given the modifications description above this implementation should
// not be used for application that require cryptographic secure RNGs.

#include "grain/_src/python/experimental/index_shuffle/index_shuffle.h"

#include <assert.h>

#include <algorithm>
#include <array>
#include <bitset>
#include <cmath>
#include <cstdint>
Expand Down Expand Up @@ -120,6 +119,8 @@ uint64_t index_shuffle(const uint64_t index, const uint64_t max_index,
assert(block_size > 0 && block_size % 2 == 0 && block_size <= 64);
// At least 4 rounds and number of rounds must be even.
assert(rounds >= 4 && rounds % 2 == 0);
// Assert the index is bounded by [0, max_index].
assert(index >= 0 && index <= max_index);
#define HANDLE_BLOCK_SIZE(B) \
case B: \
return impl::index_shuffle<B>(index, max_index, seed, rounds);
Expand Down
1 change: 1 addition & 0 deletions grain/_src/python/experimental/index_shuffle/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pybind_extension(
srcs = ["index_shuffle_module.cc"],
deps = [
"//grain/_src/python/experimental/index_shuffle",
"@abseil-cpp//absl/strings",
],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include <pybind11/pybind11.h>

#include <cstdint>

#include "absl/strings/str_cat.h"
#include "grain/_src/python/experimental/index_shuffle/index_shuffle.h"

namespace py = pybind11;
Expand All @@ -9,7 +12,21 @@ PYBIND11_MODULE(index_shuffle_module, m) {
"Returns the position of `index` in a permutation of [0, ..., "
"max_index].";
m.doc() = kDoc;
m.def("index_shuffle", &::grain::random::index_shuffle, kDoc,
py::arg("index"), py::arg("max_index"), py::arg("seed"),
py::arg("rounds"));
m.def(
"index_shuffle",
[](int64_t index, int64_t max_index, uint32_t seed, uint32_t rounds) {
if (rounds < 4 || rounds % 2 != 0) {
throw py::value_error(absl::StrCat(
"rounds must be an even integer >= 4, but got rounds = ",
rounds));
}
if (index < 0 || index > max_index) {
throw py::value_error(absl::StrCat(
"index must be in [0, max_index], but got index = ", index,
" and max_index = ", max_index));
}
return grain::random::index_shuffle(index, max_index, seed, rounds);
},
kDoc, py::arg("index"), py::arg("max_index"), py::arg("seed"),
py::arg("rounds"));
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@ def test_index_shuffle_single_record(self):
0, index_shuffle.index_shuffle(index=0, max_index=0, seed=0, rounds=4)
)

def test_index_shuffle_invalid_rounds(self):
regex = r'rounds must be an even integer >= 4'
with self.assertRaisesRegex(ValueError, regex):
index_shuffle.index_shuffle(index=0, max_index=8, seed=33, rounds=2)
with self.assertRaisesRegex(ValueError, regex):
index_shuffle.index_shuffle(index=0, max_index=8, seed=76, rounds=5)

def test_index_shuffle_invalid_index(self):
regex = r'index must be in \[0, max_index\]'
with self.assertRaisesRegex(ValueError, regex):
index_shuffle.index_shuffle(index=-1, max_index=8, seed=33, rounds=4)
with self.assertRaisesRegex(ValueError, regex):
index_shuffle.index_shuffle(index=9, max_index=8, seed=76, rounds=4)


if __name__ == '__main__':
absltest.main()