@@ -530,6 +530,9 @@ def test_token_pickle():
530530 assert token == token2
531531
532532
533+ _MAX_COUNTER = 18446744073709551615
534+
535+
533536def test_vocab ():
534537 special_tokens = ["<blank>" , "<s>" , "</s>" ]
535538 vocab = pyonmttok .Vocab (special_tokens = special_tokens )
@@ -557,9 +560,9 @@ def test_vocab():
557560 }
558561
559562 assert vocab .counters == [
560- 18446744073709551615 ,
561- 18446744073709551615 ,
562- 18446744073709551615 ,
563+ _MAX_COUNTER ,
564+ _MAX_COUNTER ,
565+ _MAX_COUNTER ,
563566 2 ,
564567 1 ,
565568 ]
@@ -628,3 +631,19 @@ def test_vocab_default_id(tokens, default_id, expected_default_id):
628631 vocab .default_id = default_id
629632 assert vocab .default_id == expected_default_id
630633 assert vocab .lookup_token ("oov" ) == expected_default_id
634+
635+
def test_vocab_pickle():
    """Check that a vocabulary keeps its state across a pickle round trip.

    The vocabulary is built from a small token list with one special token
    and a non-default ``default_id``; the unpickled copy must be a distinct
    object exposing the same size, tokens, default id and counters.
    """
    vocab = pyonmttok.build_vocab_from_tokens(
        ["a", "b", "a", "a", "c", "c"], special_tokens=["z"]
    )
    vocab.default_id = 0

    data = pickle.dumps(vocab)
    vocab_clone = pickle.loads(data)

    # Inspect the clone rather than the source object: assertions made on
    # ``vocab`` would still pass if nothing had actually been serialized.
    assert vocab_clone is not vocab
    assert len(vocab_clone) == 4
    assert vocab_clone.ids_to_tokens == ["z", "a", "b", "c"]
    assert vocab_clone.default_id == 0
    assert vocab_clone.counters == [_MAX_COUNTER, 3, 1, 2]
0 commit comments