This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 1704ab8

Davis Liang authored and leezu committed
Move BERTTokenizer to Cython and add caching support (#921)
1 parent 184a000 · commit 1704ab8

File tree: 8 files changed, +82 −33 lines changed


env/cpu/py3-master.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ channels:
 dependencies:
 - python=3.6
 - pip=18.1
+- cython
 - perl
 - pylint=2.3.1
 - flake8

env/cpu/py3.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ channels:
 dependencies:
 - python=3.6
 - pip=18.1
+- cython
 - perl
 - pylint=2.3.1
 - flake8

env/docker/py3.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ channels:
 dependencies:
 - python=3.6
 - pip=18.1
+- cython
 - perl
 - pylint=1.9.2
 - flake8

env/gpu/py3-master.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ channels:
 dependencies:
 - python=3.6
 - pip=18.1
+- cython
 - perl
 - pylint=2.3.1
 - flake8

env/gpu/py3.yml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ channels:
 dependencies:
 - python=3.6
 - pip=18.1
+- cython
 - perl
 - pylint=2.3.1
 - flake8

setup.py

Lines changed: 10 additions & 1 deletion
@@ -4,7 +4,7 @@
 import re
 import shutil
 import sys
-from setuptools import setup, find_packages
+from setuptools import setup, find_packages, Extension


 def read(*names, **kwargs):
@@ -30,6 +30,7 @@ def find_version(*file_paths):

 requirements = [
     'numpy',
+    'cython'
 ]

 setup(
@@ -53,6 +54,11 @@ def find_version(*file_paths):
     package_dir={"": "src"},
     zip_safe=True,
     include_package_data=True,
+    setup_requires=[
+        # Setuptools 18.0 properly handles Cython extensions.
+        'setuptools>=18.0',
+        'cython',
+    ],
     install_requires=requirements,
     extras_require={
         'extras': [
@@ -82,4 +88,7 @@ def find_version(*file_paths):
             'flaky',
         ],
     },
+    ext_modules=[
+        Extension('gluonnlp.data.wordpiece', sources=['src/gluonnlp/data/wordpiece.pyx']),
+    ],
 )
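
Declaring the .pyx file as an Extension, together with setup_requires pinning setuptools>=18.0 and cython, lets the build compile the Cython source automatically. For reference, a minimal sketch of the same idea that calls cythonize explicitly (illustrative only: the module name and source path are taken from the diff above, everything else is an assumption, not the project's actual setup.py):

    # sketch_setup.py -- illustrative stand-alone example
    from setuptools import Extension, setup
    from Cython.Build import cythonize  # requires cython at build time

    setup(
        name='wordpiece-demo',  # hypothetical package name
        ext_modules=cythonize(
            [Extension('gluonnlp.data.wordpiece',
                       sources=['src/gluonnlp/data/wordpiece.pyx'])],
            language_level=3,  # generate Python-3 semantics in the emitted C
        ),
    )

Either way, building the package (for example with python setup.py build_ext --inplace during development) produces the compiled gluonnlp.data.wordpiece module imported by transforms.py below.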

src/gluonnlp/data/transforms.py

Lines changed: 25 additions & 32 deletions
@@ -30,6 +30,7 @@
 ]

 import errno
+import functools
 import io
 import os
 import time
@@ -42,7 +43,9 @@
 import numpy as np

 from ..base import get_home_dir
+from ..vocab.vocab import Vocab
 from .utils import _extract_archive
+from .wordpiece import tokenize as wordpiece_tokenize


 class ClipSequence:
@@ -790,14 +793,17 @@ class BERTTokenizer:

     Parameters
     ----------
-    vocab : gluonnlp.Vocab or None, default None
+    vocab
         Vocabulary for the corpus.
-    lower : bool, default True
+    lower
         whether the text strips accents and converts to lower case.
         If you use the BERT pre-training model,
         lower is set to False when using the cased model,
         otherwise it is set to True.
-    max_input_chars_per_word : int, default 200
+    max_input_chars_per_word
+    lru_cache_size
+        Maximum size of a least-recently-used cache to speed up tokenization.
+        Use size of 2**20 for example.

     Examples
     --------
@@ -812,10 +818,14 @@ class BERTTokenizer:

     _special_prefix = '##'

-    def __init__(self, vocab, lower=True, max_input_chars_per_word=200):
+    def __init__(self, vocab: Vocab, lower: bool = True, max_input_chars_per_word: int = 200,
+                 lru_cache_size: Optional[int] = None):
         self.vocab = vocab
         self.max_input_chars_per_word = max_input_chars_per_word
         self.basic_tokenizer = BERTBasicTokenizer(lower=lower)
+        if lru_cache_size:
+            self._word_to_wordpiece_optimized = functools.lru_cache(maxsize=lru_cache_size)(
+                self._word_to_wordpiece_optimized)

     def __call__(self, sample):
         """
@@ -841,6 +851,10 @@ def _tokenizer(self, text):

         return split_tokens

+    def _word_to_wordpiece_optimized(self, text):  # pylint: disable=method-hidden
+        return wordpiece_tokenize(text, self.vocab, self.vocab.unknown_token,
+                                  self.max_input_chars_per_word)
+
     def _tokenize_wordpiece(self, text):
         """Tokenizes a piece of text into its word pieces.

@@ -861,35 +875,14 @@ def _tokenize_wordpiece(self, text):
         ret : A list of wordpiece tokens.
         """

+        # case where text is a single token
+        whitespace_tokenized_tokens = self.basic_tokenizer._whitespace_tokenize(text)
+        if len(whitespace_tokenized_tokens) == 1:
+            return self._word_to_wordpiece_optimized(whitespace_tokenized_tokens[0])
+
         output_tokens = []
-        for token in self.basic_tokenizer._whitespace_tokenize(text):
-            chars = list(token)
-            if len(chars) > self.max_input_chars_per_word:
-                output_tokens.append(self.vocab.unknown_token)
-                continue
-            is_bad = False
-            start = 0
-            sub_tokens = []
-            while start < len(chars):
-                end = len(chars)
-                cur_substr = None
-                while start < end:
-                    substr = ''.join(chars[start:end])
-                    if start > 0:
-                        substr = self._special_prefix + substr
-                    if substr in self.vocab:
-                        cur_substr = substr
-                        break
-                    end -= 1
-                if cur_substr is None:
-                    is_bad = True
-                    break
-                sub_tokens.append(cur_substr)
-                start = end
-            if is_bad:
-                output_tokens.append(self.vocab.unknown_token)
-            else:
-                output_tokens.extend(sub_tokens)
+        for token in whitespace_tokenized_tokens:
+            output_tokens.extend(self._word_to_wordpiece_optimized(token))
         return output_tokens

     def convert_tokens_to_ids(self, tokens):
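
The caching works by rebinding the instance attribute _word_to_wordpiece_optimized to an lru_cache-wrapped copy of the bound method, so repeated words skip the wordpiece search entirely (hence the pylint method-hidden suppression). A self-contained sketch of that pattern, using a toy tokenizer class rather than the real BERTTokenizer:

    import functools

    class CachedTokenizer:
        """Toy illustration of wrapping a bound method with functools.lru_cache."""

        def __init__(self, lru_cache_size=None):
            if lru_cache_size:
                # Rebind on the instance; the class-level method stays intact,
                # which is why pylint flags method-hidden in the real code.
                self._word_to_pieces = functools.lru_cache(maxsize=lru_cache_size)(
                    self._word_to_pieces)

        def _word_to_pieces(self, word):
            # Stand-in for the Cython wordpiece_tokenize call.
            return [word[:1], '##' + word[1:]] if len(word) > 1 else [word]

        def __call__(self, text):
            out = []
            for token in text.split():
                out.extend(self._word_to_pieces(token))
            return out

    tok = CachedTokenizer(lru_cache_size=2**20)
    assert tok('hello hello world') == ['h', '##ello', 'h', '##ello', 'w', '##orld']

Because the wrapped callable is already bound, the cache key is just the word string itself, so only hashable arguments ever reach the cache.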

src/gluonnlp/data/wordpiece.pyx

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+from typing import Dict, List, Tuple
+
+import cython
+
+__all__ = ['tokenize']
+
+
+def tokenize(text: str, vocab: Dict[str, int], unknown_token: str, max_input_chars_per_word: cython.int = 200):
+    """
+    Cython implementation of single token tokenization. Average latency
+    decreases to 95ms (from 144ms using original Python code).
+    """
+    output_tokens: List[str] = []
+    token_size: cython.int = len(text)
+    if token_size > max_input_chars_per_word:
+        output_tokens.append(unknown_token)
+        return output_tokens
+    is_bad: cython.int = 0
+    start: cython.int = 0
+    sub_tokens: List[str] = []
+    while start < token_size:
+        end: cython.int = token_size
+        cur_substr: str = None
+        while start < end:
+            substr = text[start:end]
+            if start > 0:
+                substr = '##' + substr
+            if substr in vocab:
+                cur_substr = substr
+                break
+            end -= 1
+        if cur_substr is None:
+            is_bad = 1
+            break
+        sub_tokens.append(cur_substr)
+        start = end
+    if is_bad == 1:
+        output_tokens.append(unknown_token)
+    else:
+        output_tokens.extend(sub_tokens)
+
+    return output_tokens
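
Once compiled, the function can be exercised directly. A quick sketch, assuming the extension has been built and using a made-up toy vocabulary (the real code passes a gluonnlp Vocab, which also supports the `in` membership check used above):

    from gluonnlp.data.wordpiece import tokenize  # available after build_ext

    toy_vocab = {'jack': 0, '##son': 1, '##ville': 2}  # illustrative only
    print(tokenize('jacksonville', toy_vocab, '[UNK]'))  # ['jack', '##son', '##ville']
    print(tokenize('zzz', toy_vocab, '[UNK]'))           # ['[UNK]'] -- no matching pieces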
