|
1 | 1 | package co.huggingface.android_transformers.gpt2.ml |
2 | 2 |
|
3 | 3 | import android.app.Application |
| 4 | +import android.util.JsonReader |
4 | 5 | import androidx.lifecycle.AndroidViewModel |
5 | 6 | import androidx.lifecycle.liveData |
6 | 7 | import androidx.lifecycle.viewModelScope |
7 | 8 | import co.huggingface.android_transformers.gpt2.tokenization.GPT2Tokenizer |
8 | 9 | import kotlinx.coroutines.Dispatchers |
9 | 10 | import org.tensorflow.lite.Interpreter |
| 11 | +import java.io.BufferedReader |
10 | 12 | import java.io.FileInputStream |
| 13 | +import java.io.InputStreamReader |
11 | 14 | import java.nio.channels.FileChannel |
12 | 15 |
|
// Model/tokenizer configuration for the bundled GPT-2 assets.
private const val SEQUENCE_LENGTH = 64          // fixed input length the TFLite model was exported with
private const val VOCAB_SIZE = 50257            // GPT-2 vocabulary size (logits per position)
private const val NUM_HEAD = 12                 // NOTE(review): unused in this file — confirm it is needed elsewhere
private const val NUM_LITE_THREADS = 4          // CPU threads given to the TFLite interpreter
private const val MODEL_PATH = "gpt2-64.tflite"
private const val VOCAB_PATH = "gpt2-vocab.json"
private const val MERGES_PATH = "gpt2-merges.txt"

// Model output: [batch=1][SEQUENCE_LENGTH][VOCAB_SIZE] logits.
private typealias Predictions = Array<Array<FloatArray>>
|
class GPT2Client(application: Application) : AndroidViewModel(application) {
    // Both require reading app assets, so they are built lazily in init()
    // rather than in the constructor; check with ::prop.isInitialized.
    private lateinit var tokenizer: GPT2Tokenizer
    private lateinit var tflite: Interpreter

    /**
     * Loads the tokenizer (vocab + BPE merge ranks) and the TFLite model from
     * the app's assets. Idempotent: already-initialized components are reused.
     */
    fun init() {
        if (!::tokenizer.isInitialized) {
            val encoder = loadEncoder()
            // Invert the vocab map (token -> id) into (id -> token) for decoding.
            val decoder = encoder.entries.associateBy({ it.value }, { it.key })
            val bpeRanks = loadBpeRanks()

            tokenizer = GPT2Tokenizer(encoder, decoder, bpeRanks)
        }

        if (!::tflite.isInitialized) {
            tflite = loadModel()
        }

        // NOTE(review): debug scaffolding — kicks off generation with a
        // hard-coded prompt on the caller's thread; replace with a
        // caller-driven entry point (see the commented-out liveData in generate).
        generate("My name is")
    }

    /**
     * Greedily generates [nbTokens] tokens continuing [text], printing each
     * decoded token as it is produced. Context is truncated to the last
     * [SEQUENCE_LENGTH] tokens and right-padded with zeros for the model.
     */
    fun generate(text: String, nbTokens: Int = 10) { // = liveData<String>(
        //viewModelScope.coroutineContext+Dispatchers.Default) {

        // `val` suffices: the list is mutated in place, the reference never changes.
        val tokens = tokenizer.encode(text)
        // Guard: an empty prompt would make maxTokens.size - 1 == -1 below.
        if (tokens.isEmpty()) return

        repeat (nbTokens) {
            val maxTokens = tokens.takeLast(SEQUENCE_LENGTH).toIntArray()
            val paddedTokens = maxTokens + IntArray(SEQUENCE_LENGTH - maxTokens.size)
            val inputIds = Array(1) { paddedTokens }

            val predictions: Predictions = Array(1) { Array(SEQUENCE_LENGTH) { FloatArray(VOCAB_SIZE) } }
            val outputs = mutableMapOf<Int, Any>(0 to predictions)

            tflite.runForMultipleInputsOutputs(arrayOf(inputIds), outputs)

            // Logits at the last real (non-padding) position; greedy decode.
            val outputLogits = predictions[0][maxTokens.size - 1]
            val nextToken = outputLogits.argmax()

            tokens.add(nextToken)
            val decodedToken = tokenizer.decode(listOf(nextToken))
            print(decodedToken)
//            emit(decodedToken)
        }
    }

    /**
     * Memory-maps the TFLite model from assets and wraps it in an [Interpreter].
     * The stream/channel used for mapping are closed once the mapping exists
     * (the mapped buffer remains valid after its channel is closed).
     */
    private fun loadModel(): Interpreter {
        val assetFileDescriptor = getApplication<Application>().assets.openFd(MODEL_PATH)
        return assetFileDescriptor.use { fd ->
            // `use` closes the FileInputStream (and its channel) — the original
            // leaked this file descriptor.
            val modelBuffer = FileInputStream(fd.fileDescriptor).use { stream ->
                stream.channel.map(FileChannel.MapMode.READ_ONLY, fd.startOffset, fd.declaredLength)
            }

            val opts = Interpreter.Options()
            opts.setNumThreads(NUM_LITE_THREADS)
            Interpreter(modelBuffer, opts)
        }
    }

    /** Parses the JSON vocab asset into a (token string -> token id) map. */
    private fun loadEncoder(): Map<String, Int> {
        return hashMapOf<String, Int>().apply {
            getApplication<Application>().assets.open(VOCAB_PATH).use { stream ->
                // JsonReader.close() also closes the wrapped stream; `use` makes
                // the close exception-safe (the original closed it manually and
                // never consumed the closing brace).
                JsonReader(InputStreamReader(stream, "UTF-8")).use { reader ->
                    reader.beginObject()
                    while (reader.hasNext()) {
                        put(reader.nextName(), reader.nextInt())
                    }
                    reader.endObject()
                }
            }
        }
    }

    /** Parses the BPE merges asset into a (token pair -> merge rank) map. */
    private fun loadBpeRanks(): Map<Pair<String, String>, Int> {
        return hashMapOf<Pair<String, String>, Int>().apply {
            getApplication<Application>().assets.open(MERGES_PATH).use { stream ->
                val mergesReader = BufferedReader(InputStreamReader(stream))
                mergesReader.useLines { seq ->
                    // First line of the merges file is a version header — skip it.
                    seq.drop(1).forEachIndexed { i, line ->
                        val parts = line.split(" ")
                        // Skip blank/malformed lines instead of crashing on parts[1].
                        if (parts.size >= 2) {
                            put(parts[0] to parts[1], i)
                        }
                    }
                }
            }
        }
    }
}
| 113 | + |
/**
 * Index of the largest element. Ties resolve to the first occurrence;
 * an empty array yields 0.
 */
private fun FloatArray.argmax(): Int {
    var winner = 0
    for (i in 1 until size) {
        if (this[i] > this[winner]) {
            winner = i
        }
    }
    return winner
}
0 commit comments