Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 6fc3b12

Browse files
committed
gpt2 tokenizer
1 parent 014db90 commit 6fc3b12

File tree

2 files changed

+385
-0
lines changed

2 files changed

+385
-0
lines changed
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
package co.huggingface.android_transformers.gpt2.tokenization
2+
3+
internal val byteEncoder: Map<Int, String> by lazy {
4+
hashMapOf<Int, String>().apply {
5+
put(33, "!")
6+
put(34, "\"")
7+
put(35, "#")
8+
put(36, "$")
9+
put(37, "%")
10+
put(38, "&")
11+
put(39, "'")
12+
put(40, "(")
13+
put(41, ")")
14+
put(42, "*")
15+
put(43, "+")
16+
put(44, ",")
17+
put(45, "-")
18+
put(46, ".")
19+
put(47, "/")
20+
put(48, "0")
21+
put(49, "1")
22+
put(50, "2")
23+
put(51, "3")
24+
put(52, "4")
25+
put(53, "5")
26+
put(54, "6")
27+
put(55, "7")
28+
put(56, "8")
29+
put(57, "9")
30+
put(58, ":")
31+
put(59, ";")
32+
put(60, "<")
33+
put(61, "=")
34+
put(62, ">")
35+
put(63, "?")
36+
put(64, "@")
37+
put(65, "A")
38+
put(66, "B")
39+
put(67, "C")
40+
put(68, "D")
41+
put(69, "E")
42+
put(70, "F")
43+
put(71, "G")
44+
put(72, "H")
45+
put(73, "I")
46+
put(74, "J")
47+
put(75, "K")
48+
put(76, "L")
49+
put(77, "M")
50+
put(78, "N")
51+
put(79, "O")
52+
put(80, "P")
53+
put(81, "Q")
54+
put(82, "R")
55+
put(83, "S")
56+
put(84, "T")
57+
put(85, "U")
58+
put(86, "V")
59+
put(87, "W")
60+
put(88, "X")
61+
put(89, "Y")
62+
put(90, "Z")
63+
put(91, "[")
64+
put(92, "\\")
65+
put(93, "]")
66+
put(94, "^")
67+
put(95, "_")
68+
put(96, "`")
69+
put(97, "a")
70+
put(98, "b")
71+
put(99, "c")
72+
put(100, "d")
73+
put(101, "e")
74+
put(102, "f")
75+
put(103, "g")
76+
put(104, "h")
77+
put(105, "i")
78+
put(106, "j")
79+
put(107, "k")
80+
put(108, "l")
81+
put(109, "m")
82+
put(110, "n")
83+
put(111, "o")
84+
put(112, "p")
85+
put(113, "q")
86+
put(114, "r")
87+
put(115, "s")
88+
put(116, "t")
89+
put(117, "u")
90+
put(118, "v")
91+
put(119, "w")
92+
put(120, "x")
93+
put(121, "y")
94+
put(122, "z")
95+
put(123, "{")
96+
put(124, "|")
97+
put(125, "}")
98+
put(126, "~")
99+
put(161, "\u00a1")
100+
put(162, "\u00a2")
101+
put(163, "\u00a3")
102+
put(164, "\u00a4")
103+
put(165, "\u00a5")
104+
put(166, "\u00a6")
105+
put(167, "\u00a7")
106+
put(168, "\u00a8")
107+
put(169, "\u00a9")
108+
put(170, "\u00aa")
109+
put(171, "\u00ab")
110+
put(172, "\u00ac")
111+
put(174, "\u00ae")
112+
put(175, "\u00af")
113+
put(176, "\u00b0")
114+
put(177, "\u00b1")
115+
put(178, "\u00b2")
116+
put(179, "\u00b3")
117+
put(180, "\u00b4")
118+
put(181, "\u00b5")
119+
put(182, "\u00b6")
120+
put(183, "\u00b7")
121+
put(184, "\u00b8")
122+
put(185, "\u00b9")
123+
put(186, "\u00ba")
124+
put(187, "\u00bb")
125+
put(188, "\u00bc")
126+
put(189, "\u00bd")
127+
put(190, "\u00be")
128+
put(191, "\u00bf")
129+
put(192, "\u00c0")
130+
put(193, "\u00c1")
131+
put(194, "\u00c2")
132+
put(195, "\u00c3")
133+
put(196, "\u00c4")
134+
put(197, "\u00c5")
135+
put(198, "\u00c6")
136+
put(199, "\u00c7")
137+
put(200, "\u00c8")
138+
put(201, "\u00c9")
139+
put(202, "\u00ca")
140+
put(203, "\u00cb")
141+
put(204, "\u00cc")
142+
put(205, "\u00cd")
143+
put(206, "\u00ce")
144+
put(207, "\u00cf")
145+
put(208, "\u00d0")
146+
put(209, "\u00d1")
147+
put(210, "\u00d2")
148+
put(211, "\u00d3")
149+
put(212, "\u00d4")
150+
put(213, "\u00d5")
151+
put(214, "\u00d6")
152+
put(215, "\u00d7")
153+
put(216, "\u00d8")
154+
put(217, "\u00d9")
155+
put(218, "\u00da")
156+
put(219, "\u00db")
157+
put(220, "\u00dc")
158+
put(221, "\u00dd")
159+
put(222, "\u00de")
160+
put(223, "\u00df")
161+
put(224, "\u00e0")
162+
put(225, "\u00e1")
163+
put(226, "\u00e2")
164+
put(227, "\u00e3")
165+
put(228, "\u00e4")
166+
put(229, "\u00e5")
167+
put(230, "\u00e6")
168+
put(231, "\u00e7")
169+
put(232, "\u00e8")
170+
put(233, "\u00e9")
171+
put(234, "\u00ea")
172+
put(235, "\u00eb")
173+
put(236, "\u00ec")
174+
put(237, "\u00ed")
175+
put(238, "\u00ee")
176+
put(239, "\u00ef")
177+
put(240, "\u00f0")
178+
put(241, "\u00f1")
179+
put(242, "\u00f2")
180+
put(243, "\u00f3")
181+
put(244, "\u00f4")
182+
put(245, "\u00f5")
183+
put(246, "\u00f6")
184+
put(247, "\u00f7")
185+
put(248, "\u00f8")
186+
put(249, "\u00f9")
187+
put(250, "\u00fa")
188+
put(251, "\u00fb")
189+
put(252, "\u00fc")
190+
put(253, "\u00fd")
191+
put(254, "\u00fe")
192+
put(255, "\u00ff")
193+
put(0, "\u0100")
194+
put(1, "\u0101")
195+
put(2, "\u0102")
196+
put(3, "\u0103")
197+
put(4, "\u0104")
198+
put(5, "\u0105")
199+
put(6, "\u0106")
200+
put(7, "\u0107")
201+
put(8, "\u0108")
202+
put(9, "\u0109")
203+
put(10, "\u010a")
204+
put(11, "\u010b")
205+
put(12, "\u010c")
206+
put(13, "\u010d")
207+
put(14, "\u010e")
208+
put(15, "\u010f")
209+
put(16, "\u0110")
210+
put(17, "\u0111")
211+
put(18, "\u0112")
212+
put(19, "\u0113")
213+
put(20, "\u0114")
214+
put(21, "\u0115")
215+
put(22, "\u0116")
216+
put(23, "\u0117")
217+
put(24, "\u0118")
218+
put(25, "\u0119")
219+
put(26, "\u011a")
220+
put(27, "\u011b")
221+
put(28, "\u011c")
222+
put(29, "\u011d")
223+
put(30, "\u011e")
224+
put(31, "\u011f")
225+
put(32, "\u0120")
226+
put(127, "\u0121")
227+
put(128, "\u0122")
228+
put(129, "\u0123")
229+
put(130, "\u0124")
230+
put(131, "\u0125")
231+
put(132, "\u0126")
232+
put(133, "\u0127")
233+
put(134, "\u0128")
234+
put(135, "\u0129")
235+
put(136, "\u012a")
236+
put(137, "\u012b")
237+
put(138, "\u012c")
238+
put(139, "\u012d")
239+
put(140, "\u012e")
240+
put(141, "\u012f")
241+
put(142, "\u0130")
242+
put(143, "\u0131")
243+
put(144, "\u0132")
244+
put(145, "\u0133")
245+
put(146, "\u0134")
246+
put(147, "\u0135")
247+
put(148, "\u0136")
248+
put(149, "\u0137")
249+
put(150, "\u0138")
250+
put(151, "\u0139")
251+
put(152, "\u013a")
252+
put(153, "\u013b")
253+
put(154, "\u013c")
254+
put(155, "\u013d")
255+
put(156, "\u013e")
256+
put(157, "\u013f")
257+
put(158, "\u0140")
258+
put(159, "\u0141")
259+
put(160, "\u0142")
260+
put(173, "\u0143")
261+
}
262+
}
263+
264+
internal val byteDecoder by lazy {
265+
byteEncoder.entries.associateBy({ it.value }) { it.key }
266+
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package co.huggingface.android_transformers.gpt2.tokenization
2+
3+
import android.content.Context
4+
import android.util.JsonReader
5+
import java.io.BufferedReader
6+
import java.io.InputStreamReader
7+
8+
private const val VOCAB_PATH = "gpt2-vocab.json"
9+
private const val MERGES_PATH = "gpt2-merges.txt"
10+
11+
class GPT2Tokenizer(private val context: Context) {
12+
private val encoder: Map<String, Int>
13+
private val decoder: Map<Int, String>
14+
private val bpeRanks: Map<Pair<String, String>, Int>
15+
private val encodeRegex = Regex("""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
16+
17+
init {
18+
encoder = hashMapOf<String, Int>().apply {
19+
val vocabStream = context.assets.open(VOCAB_PATH)
20+
vocabStream.use {
21+
val vocabReader = JsonReader(InputStreamReader(it, "UTF-8"))
22+
vocabReader.beginObject();
23+
while (vocabReader.hasNext()) {
24+
val key = vocabReader.nextName()
25+
val value = vocabReader.nextInt()
26+
put(key, value)
27+
}
28+
vocabReader.close()
29+
}
30+
}
31+
32+
decoder = encoder.entries.associateBy({ it.value }, { it.key })
33+
34+
bpeRanks = hashMapOf<Pair<String, String>, Int>().apply {
35+
val mergesStream = context.assets.open(MERGES_PATH)
36+
mergesStream.use { stream ->
37+
val mergesReader = BufferedReader(InputStreamReader(stream))
38+
mergesReader.useLines { seq ->
39+
seq.drop(1).forEachIndexed { i, s ->
40+
val list = s.split(" ")
41+
val keyTuple = list[0] to list[1]
42+
put(keyTuple, i)
43+
}
44+
}
45+
}
46+
}
47+
}
48+
49+
fun decode(tokens: List<Int>): String {
50+
val text = tokens.joinToString("") { decoder.getOrDefault(it, "") }
51+
val utfCodepoints = text.map { byteDecoder[it.toString()]!! }
52+
return String(utfCodepoints.toIntArray(), 0, utfCodepoints.size)
53+
}
54+
55+
fun encode(text: String): List<Int> {
56+
val tokens = encodeRegex.findAll(text).map {
57+
it.value.codePoints()
58+
.boxed()
59+
.map { byteEncoder[it]!! }
60+
.toArray()
61+
.joinToString("")
62+
}
63+
64+
return tokens
65+
.map { bpe(it) }
66+
.flatten()
67+
.map { encoder[it]!! }
68+
.toList()
69+
}
70+
71+
private fun bpe(token: String): List<String> {
72+
if (token.length <= 1) return listOf(token)
73+
74+
var word = token.map { it.toString() }
75+
var pairs = getPairs(word)
76+
77+
while (true) {
78+
val (first, second) = pairs.minBy { bpeRanks.getOrDefault(it, Int.MAX_VALUE) } ?: break
79+
80+
var i = 0
81+
val newWord = mutableListOf<String>()
82+
while (i < word.size) {
83+
val j = word.subList(i, word.size).indexOf(first)
84+
if (j != -1) {
85+
newWord.addAll(word.subList(i, j))
86+
i = j
87+
} else {
88+
newWord.addAll(word.subList(i, word.size))
89+
break
90+
}
91+
92+
if (word[i] == first && i < word.size-1 && word[i+1] == second) {
93+
newWord.add(first+second)
94+
i += 2
95+
} else {
96+
newWord.add(word[i])
97+
i += 1
98+
}
99+
}
100+
101+
word = newWord
102+
if (word.size == 1) {
103+
break
104+
} else {
105+
pairs = getPairs(word)
106+
}
107+
}
108+
109+
return word
110+
}
111+
112+
private fun getPairs(word: List<String>): Set<Pair<String, String>> {
113+
return mutableSetOf<Pair<String, String>>().apply {
114+
for (i in 0 until word.size-1) {
115+
add(word[i] to word[i+1])
116+
}
117+
}
118+
}
119+
}

0 commit comments

Comments
 (0)