Commit d3b5ad8

Implement pickle support for Vocab objects (#303)
1 parent 1f611ae commit d3b5ad8

5 files changed: +58 / -10 lines

bindings/python/README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -264,7 +264,7 @@ vocab.__getitem__(token: str) -> int  # Implements: vocab["hello"]
 # If a tokenizer is not set, the text is split on spaces.
 vocab.add_from_text(text: str, tokenizer: Optional[pyonmttok.Tokenizer] = None) -> None
 vocab.add_from_file(path: str, tokenizer: Optional[pyonmttok.Tokenizer] = None) -> None
-vocab.add_token(token: str) -> None
+vocab.add_token(token: str, count: int = 1) -> None

 vocab.resize(maximum_size: int = 0, minimum_frequency: int = 1) -> None
```
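As a quick illustration of the documented change, here is a minimal usage sketch of the new `count` keyword (assuming an installed `pyonmttok` and that the no-argument constructor behaves like `Vocab(special_tokens=[])`; the token and counts are made up):

```python
import pyonmttok

vocab = pyonmttok.Vocab()
vocab.add_token("hello")           # as before: increments the counter by 1
vocab.add_token("hello", count=4)  # new: bump the counter by 4 in one call
assert vocab.counters[vocab["hello"]] == 5  # vocab["hello"] looks up the token id
```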

bindings/python/pyonmttok/Python.cc

Lines changed: 26 additions & 1 deletion

```diff
@@ -754,7 +754,7 @@ PYBIND11_MODULE(_ext, m)
     .def_property_readonly("ids_to_tokens", &onmt::Vocab::ids_to_tokens)
     .def_property_readonly("counters", &onmt::Vocab::counters)

-    .def("add_token", &onmt::Vocab::add_token, py::arg("token"))
+    .def("add_token", &onmt::Vocab::add_token, py::arg("token"), py::arg("count")=1)

     .def("add_from_text",
          [](onmt::Vocab& vocab,
@@ -792,5 +792,30 @@ PYBIND11_MODULE(_ext, m)
          [](const onmt::Vocab& vocab, const py::object& dict) {
            return onmt::Vocab(vocab);
          })
+
+    .def(py::pickle(
+        [](const onmt::Vocab& vocab) {
+          return py::make_tuple(
+              /*version=*/1,
+              vocab.ids_to_tokens(),
+              vocab.counters(),
+              vocab.get_default_id());
+        },
+        [](py::tuple t) {
+          if (t.size() != 4 || t[0].cast<unsigned int>() != 1)
+            throw std::runtime_error("Invalid pickle data");
+
+          auto tokens = t[1].cast<std::vector<std::string>>();
+          auto counters = t[2].cast<std::vector<size_t>>();
+          auto default_id = t[3].cast<size_t>();
+
+          onmt::Vocab vocab;
+          vocab.set_default_id(default_id);
+
+          for (size_t i = 0; i < tokens.size(); ++i)
+            vocab.add_token(std::move(tokens[i]), counters[i]);
+
+          return vocab;
+        }));
     ;
 }
```
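For readers less familiar with pybind11, `py::pickle(get, set)` registers the two lambdas as `__getstate__` and `__setstate__`, so the standard `pickle` module handles `Vocab` with no extra glue. The state is the tuple `(version, ids_to_tokens, counters, default_id)`, and the leading version tag lets future releases reject or migrate stale pickles. A sketch of the resulting round trip (values illustrative):

```python
import pickle

import pyonmttok

vocab = pyonmttok.Vocab(special_tokens=["<s>"])
vocab.add_token("hello", count=3)

# pickle.dumps() calls the __getstate__ lambda; pickle.loads() feeds the
# tuple back into the __setstate__ lambda, which rebuilds a fresh Vocab.
restored = pickle.loads(pickle.dumps(vocab))

assert restored.ids_to_tokens == vocab.ids_to_tokens
assert restored.counters == vocab.counters
assert restored.default_id == vocab.default_id
```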

bindings/python/test/test.py

Lines changed: 22 additions & 3 deletions

```diff
@@ -530,6 +530,9 @@ def test_token_pickle():
     assert token == token2


+_MAX_COUNTER = 18446744073709551615
+
+
 def test_vocab():
     special_tokens = ["<blank>", "<s>", "</s>"]
     vocab = pyonmttok.Vocab(special_tokens=special_tokens)
@@ -557,9 +560,9 @@ def test_vocab():
     }

     assert vocab.counters == [
-        18446744073709551615,
-        18446744073709551615,
-        18446744073709551615,
+        _MAX_COUNTER,
+        _MAX_COUNTER,
+        _MAX_COUNTER,
         2,
         1,
     ]
@@ -628,3 +631,19 @@ def test_vocab_default_id(tokens, default_id, expected_default_id):
     vocab.default_id = default_id
     assert vocab.default_id == expected_default_id
     assert vocab.lookup_token("oov") == expected_default_id
+
+
+def test_vocab_pickle():
+    vocab = pyonmttok.build_vocab_from_tokens(
+        ["a", "b", "a", "a", "c", "c"], special_tokens=["z"]
+    )
+    vocab.default_id = 0
+
+    data = pickle.dumps(vocab)
+    vocab_clone = pickle.loads(data)
+
+    assert vocab_clone is not vocab
+    assert len(vocab) == 4
+    assert vocab.ids_to_tokens == ["z", "a", "b", "c"]
+    assert vocab.default_id == 0
+    assert vocab.counters == [_MAX_COUNTER, 3, 1, 2]
```
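A note on the constant the test factors out: 18446744073709551615 is 2**64 - 1, the largest value of the `size_t` used for frequencies (SIZE_MAX on 64-bit platforms). Special tokens are apparently pinned at this cap so that `resize()` never drops them. A one-line check of the arithmetic:

```python
# 18446744073709551615 is the 64-bit unsigned maximum, i.e. 2**64 - 1.
assert 18446744073709551615 == 2**64 - 1 == (1 << 64) - 1
```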

include/onmt/Vocab.h

Lines changed: 1 addition & 1 deletion

```diff
@@ -61,7 +61,7 @@ namespace onmt
       return _frequencies;
     }

-    void add_token(std::string token);
+    void add_token(std::string token, size_t count = 1);
     void add_from_text(const std::string& text, const Tokenizer* tokenizer = nullptr);
     void add_from_stream(std::istream& is, const Tokenizer* tokenizer = nullptr);
     void resize(size_t maximum_size = 0, size_t minimum_frequency = 1);
```
src/Vocab.cc

Lines changed: 8 additions & 4 deletions

```diff
@@ -18,7 +18,7 @@ namespace onmt
       frequency = maximum_frequency;
   }

-  void Vocab::add_token(std::string token)
+  void Vocab::add_token(std::string token, size_t count)
   {
     const size_t id = _ids_to_tokens.size();
     const auto pair = _tokens_to_ids.emplace(std::move(token), id);
@@ -28,11 +28,15 @@ namespace onmt
     if (inserted)
     {
       _ids_to_tokens.emplace_back(entry.first);
-      _frequencies.emplace_back(1);
+      _frequencies.emplace_back(count);
     }
-    else if (_frequencies[entry.second] < maximum_frequency)
+    else if (_frequencies[entry.second] <= maximum_frequency - count)
     {
-      _frequencies[entry.second]++;
+      _frequencies[entry.second] += count;
+    }
+    else
+    {
+      _frequencies[entry.second] = maximum_frequency;
     }
   }
```
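The rewritten condition is the standard overflow-safe form of a saturating add: testing `frequency + count <= maximum_frequency` directly could wrap around in unsigned arithmetic, so the subtraction is moved to the other side of the comparison and the counter clamps at the cap. A small Python model of the same logic (the function name is mine, not from the C++ source):

```python
MAXIMUM_FREQUENCY = 2**64 - 1  # mirrors the size_t cap used in Vocab.cc

def saturating_add(frequency: int, count: int) -> int:
    # Mirrors the new branch in Vocab::add_token: add while the sum cannot
    # exceed the cap, otherwise clamp to the cap instead of wrapping around.
    if frequency <= MAXIMUM_FREQUENCY - count:
        return frequency + count
    return MAXIMUM_FREQUENCY

assert saturating_add(5, 3) == 8
assert saturating_add(MAXIMUM_FREQUENCY - 1, 10) == MAXIMUM_FREQUENCY
```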
