This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 02e24d8

gpt2 generation with topk
1 parent 29466c7 commit 02e24d8

File tree

2 files changed: +50 −7 lines


app/src/main/java/co/huggingface/android_transformers/gpt2/ml/GPT2Client.kt

Lines changed: 48 additions & 5 deletions
@@ -3,15 +3,15 @@ package co.huggingface.android_transformers.gpt2.ml
 import android.app.Application
 import android.util.JsonReader
 import androidx.lifecycle.AndroidViewModel
-import androidx.lifecycle.liveData
-import androidx.lifecycle.viewModelScope
 import co.huggingface.android_transformers.gpt2.tokenization.GPT2Tokenizer
-import kotlinx.coroutines.Dispatchers
 import org.tensorflow.lite.Interpreter
 import java.io.BufferedReader
 import java.io.FileInputStream
 import java.io.InputStreamReader
 import java.nio.channels.FileChannel
+import kotlin.math.exp
+import kotlin.math.min
+import kotlin.random.Random

 private const val SEQUENCE_LENGTH = 64
 private const val VOCAB_SIZE = 50257
@@ -23,10 +23,15 @@ private const val MERGES_PATH = "gpt2-merges.txt"

 private typealias Predictions = Array<Array<FloatArray>>

+enum class GPT2StrategyEnum { GREEDY, TOPK }
+data class GPT2Strategy(val strategy: GPT2StrategyEnum, val value: Int = 0)
+
 class GPT2Client(application: Application) : AndroidViewModel(application) {
     private lateinit var tokenizer: GPT2Tokenizer
     private lateinit var tflite: Interpreter

+    var strategy = GPT2Strategy(GPT2StrategyEnum.TOPK, 40)
+
     fun init() {
         if (!::tokenizer.isInitialized) {
             val encoder = loadEncoder()
@@ -46,7 +51,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
     fun generate(text: String, nbTokens: Int = 10) { // = liveData<String>(
         //viewModelScope.coroutineContext+Dispatchers.Default) {

-        var tokens = tokenizer.encode(text)
+        val tokens = tokenizer.encode(text)
         repeat (nbTokens) {
             val maxTokens = tokens.takeLast(SEQUENCE_LENGTH).toIntArray()
             val paddedTokens = maxTokens + IntArray(SEQUENCE_LENGTH - maxTokens.size)
@@ -57,7 +62,26 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {

             tflite.runForMultipleInputsOutputs(arrayOf(inputIds), outputs)
             val outputLogits = predictions[0][maxTokens.size-1]
-            val nextToken = outputLogits.argmax()
+
+            val nextToken: Int = when (strategy.strategy) {
+                GPT2StrategyEnum.TOPK -> {
+                    val finalTopK = min(strategy.value, outputLogits.size)
+                    val filteredLogits = outputLogits
+                        .mapIndexed { index, fl -> (index to fl) }
+                        .sortedByDescending { it.second }
+                        .take(finalTopK)
+
+                    // Softmax computation on filtered logits
+                    val maxLogitValue = outputLogits.max()!!
+                    val logitsExp = filteredLogits.map { exp(it.second - maxLogitValue) }
+                    val sumExp = logitsExp.sum()
+                    val probs = logitsExp.map { it.div(sumExp) }
+
+                    val logitsIndexes = filteredLogits.map { it.first }
+                    sample(logitsIndexes, probs)
+                }
+                else -> outputLogits.argmax()
+            }

             tokens.add(nextToken)
             val decodedToken = tokenizer.decode(listOf(nextToken))
@@ -111,6 +135,25 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
         }
     }

+    private fun randomIndex(probs: List<Float>): Int {
+        val rnd = Random.nextFloat()
+        var acc = 0f
+
+        probs.forEachIndexed { i, fl ->
+            acc += fl
+            if (rnd < acc) {
+                return i
+            }
+        }
+
+        return probs.size - 1
+    }
+
+    private fun sample(indexes: List<Int>, probs: List<Float>): Int {
+        val i = randomIndex(probs)
+        return indexes[i]
+    }
+
     private fun FloatArray.argmax(): Int {
         var bestIndex = 0
         repeat(size) {
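
For reference, the sampling path added above can be exercised outside the app. Here is a minimal, self-contained Kotlin sketch of the same idea (keep the k highest logits, renormalize them with a softmax, then draw from the resulting distribution); the function name topKSample, the rng parameter, and the toy logits in main are illustrative and not part of this commit.

import kotlin.math.exp
import kotlin.random.Random

// Illustrative helper (not from the commit): pick a token id from a logits
// array by keeping the k highest logits, applying a softmax to them, and
// sampling from the resulting distribution.
fun topKSample(logits: FloatArray, k: Int, rng: Random = Random.Default): Int {
    val topK = logits
        .mapIndexed { index, logit -> index to logit }
        .sortedByDescending { it.second }
        .take(k.coerceAtMost(logits.size))

    // Softmax over the kept logits; subtracting the max keeps exp() stable.
    val maxLogit = topK.first().second
    val expValues = topK.map { exp(it.second - maxLogit) }
    val sumExp = expValues.sum()
    val probs = expValues.map { it / sumExp }

    // Inverse-CDF sampling, mirroring randomIndex()/sample() in the diff above.
    val r = rng.nextFloat()
    var acc = 0f
    probs.forEachIndexed { i, p ->
        acc += p
        if (r < acc) return topK[i].first
    }
    return topK.last().first
}

fun main() {
    // Toy logits for a 6-token vocabulary; with k = 3 only indexes 1, 2 and 5
    // can ever be sampled.
    val logits = floatArrayOf(0.1f, 2.0f, 1.5f, -0.3f, 0.0f, 1.9f)
    println(topKSample(logits, k = 3))
}

With the GPT2Strategy field introduced above, greedy decoding remains available by setting strategy = GPT2Strategy(GPT2StrategyEnum.GREEDY), which falls through to the argmax() branch.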

app/src/main/java/co/huggingface/android_transformers/gpt2/tokenization/GPT2Tokenizer.kt

Lines changed: 2 additions & 2 deletions
@@ -13,8 +13,8 @@ class GPT2Tokenizer(
     }

     fun encode(text: String): MutableList<Int> {
-        val tokens = encodeRegex.findAll(text).map {
-            it.value.codePoints()
+        val tokens = encodeRegex.findAll(text).map { result ->
+            result.value.codePoints()
                 .boxed()
                 .map { byteEncoder[it]!! }
                 .toArray()
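
The tokenizer change only gives the outer lambda parameter an explicit name (result) so the nested map can keep using the implicit it without shadowing the outer MatchResult. A small illustration of the same pattern, with made-up strings that are not from the tokenizer:

fun main() {
    // Naming the outer parameter (`word`) keeps the implicit `it` unambiguous
    // in the nested lambda, mirroring the `result ->` rename above.
    val encoded = listOf("hi", "ok").map { word ->
        word.map { it.code }       // `it` is each Char of `word`
    }
    println(encoded)               // [[104, 105], [111, 107]]
}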
