Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 29466c7

Browse files
committed
basic working greedy gpt2
1 parent 0720353 commit 29466c7

File tree

2 files changed

+104
-68
lines changed

2 files changed

+104
-68
lines changed
Lines changed: 98 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,123 @@
11
package co.huggingface.android_transformers.gpt2.ml
22

33
import android.app.Application
4+
import android.util.JsonReader
45
import androidx.lifecycle.AndroidViewModel
56
import androidx.lifecycle.liveData
67
import androidx.lifecycle.viewModelScope
78
import co.huggingface.android_transformers.gpt2.tokenization.GPT2Tokenizer
89
import kotlinx.coroutines.Dispatchers
910
import org.tensorflow.lite.Interpreter
11+
import java.io.BufferedReader
1012
import java.io.FileInputStream
13+
import java.io.InputStreamReader
1114
import java.nio.channels.FileChannel
1215

1316
// Model geometry — fixed by the converted "gpt2-64" TFLite graph.
private const val SEQUENCE_LENGTH = 64
private const val VOCAB_SIZE = 50257
// NOTE(review): NUM_HEAD is not referenced anywhere in this file yet — presumably
// reserved for past key/value cache shapes; confirm before removing.
private const val NUM_HEAD = 12

// Interpreter configuration and asset file names.
private const val NUM_LITE_THREADS = 4
private const val MODEL_PATH = "gpt2-64.tflite"
private const val VOCAB_PATH = "gpt2-vocab.json"
private const val MERGES_PATH = "gpt2-merges.txt"

// Logits tensor shape: [batch, sequence, vocab].
private typealias Predictions = Array<Array<FloatArray>>
1625

1726
/**
 * ViewModel wrapping a GPT-2 TFLite model plus its BPE tokenizer.
 *
 * Call [init] once before [generate]; both the tokenizer tables and the
 * interpreter are loaded lazily from the app's assets.
 */
class GPT2Client(application: Application) : AndroidViewModel(application) {
    // Initialized in init() rather than at construction: loading requires asset I/O.
    private lateinit var tokenizer: GPT2Tokenizer
    private lateinit var tflite: Interpreter

    /**
     * Loads the tokenizer (vocab + merges) and the TFLite interpreter if not
     * already loaded. Safe to call repeatedly; work happens only on first call.
     */
    fun init() {
        if (!::tokenizer.isInitialized) {
            val encoder = loadEncoder()
            // Invert the vocab map for id -> token lookups during decoding.
            val decoder = encoder.entries.associateBy({ it.value }, { it.key })
            val bpeRanks = loadBpeRanks()
            tokenizer = GPT2Tokenizer(encoder, decoder, bpeRanks)
        }

        if (!::tflite.isInitialized) {
            tflite = loadModel()
        }

        // TODO(review): debug smoke-test left in from the "basic working greedy gpt2"
        // commit — remove once generation is driven from the UI / liveData.
        generate("My name is")
    }

    /**
     * Greedy decoding: appends [nbTokens] tokens to [text], one interpreter run
     * per token. Each decoded token is currently printed; the commented-out
     * liveData plumbing suggests this will become an observable stream.
     */
    fun generate(text: String, nbTokens: Int = 10) { // = liveData<String>(
        //viewModelScope.coroutineContext+Dispatchers.Default) {

        // encode() returns a MutableList we extend in place; never reassigned -> val.
        val tokens = tokenizer.encode(text)
        repeat(nbTokens) {
            // Keep at most the SEQUENCE_LENGTH most recent tokens, right-padded with 0s.
            val maxTokens = tokens.takeLast(SEQUENCE_LENGTH).toIntArray()
            val paddedTokens = maxTokens + IntArray(SEQUENCE_LENGTH - maxTokens.size)
            val inputIds = Array(1) { paddedTokens }

            // Output 0 holds the logits: [1, SEQUENCE_LENGTH, VOCAB_SIZE].
            val predictions: Predictions = Array(1) { Array(SEQUENCE_LENGTH) { FloatArray(VOCAB_SIZE) } }
            val outputs = mutableMapOf<Int, Any>(0 to predictions)

            tflite.runForMultipleInputsOutputs(arrayOf(inputIds), outputs)

            // Greedy step: argmax over logits at the last real (non-padding) position.
            val outputLogits = predictions[0][maxTokens.size - 1]
            val nextToken = outputLogits.argmax()

            tokens.add(nextToken)
            val decodedToken = tokenizer.decode(listOf(nextToken))
            print(decodedToken)
            // emit(decodedToken)
        }
    }

    /**
     * Memory-maps the model asset and builds the interpreter.
     * Fix: the FileInputStream (and its channel) was previously leaked — it is
     * now closed via use(); the mapped buffer stays valid after the channel closes.
     */
    private fun loadModel(): Interpreter {
        val assetFileDescriptor = getApplication<Application>().assets.openFd(MODEL_PATH)
        return assetFileDescriptor.use { fd ->
            FileInputStream(fd.fileDescriptor).use { input ->
                val modelBuffer = input.channel.map(
                    FileChannel.MapMode.READ_ONLY, fd.startOffset, fd.declaredLength)

                val opts = Interpreter.Options().apply { setNumThreads(NUM_LITE_THREADS) }
                Interpreter(modelBuffer, opts)
            }
        }
    }

    /** Reads the vocab JSON asset into a token -> id map. */
    private fun loadEncoder(): Map<String, Int> {
        return hashMapOf<String, Int>().apply {
            getApplication<Application>().assets.open(VOCAB_PATH).use { stream ->
                // JsonReader is Closeable: close via use() instead of a manual close()
                // that would be skipped on an exception.
                JsonReader(InputStreamReader(stream, "UTF-8")).use { reader ->
                    reader.beginObject()
                    while (reader.hasNext()) {
                        put(reader.nextName(), reader.nextInt())
                    }
                }
            }
        }
    }

    /**
     * Reads the merges asset into a (left, right) -> rank map.
     * The first line is a header and is skipped; rank is the 0-based merge order.
     */
    private fun loadBpeRanks(): Map<Pair<String, String>, Int> {
        return hashMapOf<Pair<String, String>, Int>().apply {
            getApplication<Application>().assets.open(MERGES_PATH).use { stream ->
                BufferedReader(InputStreamReader(stream)).useLines { lines ->
                    lines.drop(1).forEachIndexed { rank, line ->
                        val pair = line.split(" ")
                        put(pair[0] to pair[1], rank)
                    }
                }
            }
        }
    }
}
113+
114+
/**
 * Index of the largest element; ties resolve to the lowest index.
 * Returns 0 for an empty array (callers index fixed-size logits, never empty).
 */
private fun FloatArray.argmax(): Int {
    var best = 0
    for (i in indices) {
        if (this[i] > this[best]) best = i
    }
    return best
}

app/src/main/java/co/huggingface/android_transformers/gpt2/tokenization/GPT2Tokenizer.kt

Lines changed: 6 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,18 @@
11
package co.huggingface.android_transformers.gpt2.tokenization
22

3-
import android.content.Context
4-
import android.util.JsonReader
5-
import java.io.BufferedReader
6-
import java.io.InputStreamReader
7-
8-
private const val VOCAB_PATH = "gpt2-vocab.json"
9-
private const val MERGES_PATH = "gpt2-merges.txt"
10-
11-
class GPT2Tokenizer(private val context: Context) {
12-
private val encoder: Map<String, Int>
13-
private val decoder: Map<Int, String>
14-
private val bpeRanks: Map<Pair<String, String>, Int>
3+
class GPT2Tokenizer(
4+
private val encoder: Map<String, Int>,
5+
private val decoder: Map<Int, String>,
6+
private val bpeRanks: Map<Pair<String, String>, Int>) {
157
private val encodeRegex = Regex("""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
168

17-
init {
18-
encoder = hashMapOf<String, Int>().apply {
19-
val vocabStream = context.assets.open(VOCAB_PATH)
20-
vocabStream.use {
21-
val vocabReader = JsonReader(InputStreamReader(it, "UTF-8"))
22-
vocabReader.beginObject();
23-
while (vocabReader.hasNext()) {
24-
val key = vocabReader.nextName()
25-
val value = vocabReader.nextInt()
26-
put(key, value)
27-
}
28-
vocabReader.close()
29-
}
30-
}
31-
32-
decoder = encoder.entries.associateBy({ it.value }, { it.key })
33-
34-
bpeRanks = hashMapOf<Pair<String, String>, Int>().apply {
35-
val mergesStream = context.assets.open(MERGES_PATH)
36-
mergesStream.use { stream ->
37-
val mergesReader = BufferedReader(InputStreamReader(stream))
38-
mergesReader.useLines { seq ->
39-
seq.drop(1).forEachIndexed { i, s ->
40-
val list = s.split(" ")
41-
val keyTuple = list[0] to list[1]
42-
put(keyTuple, i)
43-
}
44-
}
45-
}
46-
}
47-
}
48-
499
/**
 * Maps token ids back to text: each id becomes its BPE string via [decoder]
 * (unknown ids decode to ""), then every character is mapped back to its
 * original unicode code point via byteDecoder (defined elsewhere in this file).
 */
fun decode(tokens: List<Int>): String {
    val bpeText = tokens.joinToString("") { decoder.getOrDefault(it, "") }
    val codePoints = bpeText.map { byteDecoder[it.toString()]!! }.toIntArray()
    return String(codePoints, 0, codePoints.size)
}
5414

55-
fun encode(text: String): List<Int> {
15+
fun encode(text: String): MutableList<Int> {
5616
val tokens = encodeRegex.findAll(text).map {
5717
it.value.codePoints()
5818
.boxed()
@@ -65,7 +25,7 @@ class GPT2Tokenizer(private val context: Context) {
6525
.map { bpe(it) }
6626
.flatten()
6727
.map { encoder[it]!! }
68-
.toList()
28+
.toMutableList()
6929
}
7030

7131
private fun bpe(token: String): List<String> {

0 commit comments

Comments (0)