@@ -3,15 +3,15 @@ package co.huggingface.android_transformers.gpt2.ml
33import android.app.Application
44import android.util.JsonReader
55import androidx.lifecycle.AndroidViewModel
6- import androidx.lifecycle.liveData
7- import androidx.lifecycle.viewModelScope
86import co.huggingface.android_transformers.gpt2.tokenization.GPT2Tokenizer
9- import kotlinx.coroutines.Dispatchers
107import org.tensorflow.lite.Interpreter
118import java.io.BufferedReader
129import java.io.FileInputStream
1310import java.io.InputStreamReader
1411import java.nio.channels.FileChannel
12+ import kotlin.math.exp
13+ import kotlin.math.min
14+ import kotlin.random.Random
1515
// Model input is padded/truncated to exactly this many tokens
// (generate() takes the last SEQUENCE_LENGTH tokens and zero-pads the rest).
private const val SEQUENCE_LENGTH = 64
// Size of the GPT-2 BPE vocabulary; length of each per-position logits vector.
private const val VOCAB_SIZE = 50257
@@ -23,10 +23,15 @@ private const val MERGES_PATH = "gpt2-merges.txt"
2323
2424private typealias Predictions = Array <Array <FloatArray >>
2525
26+ enum class GPT2StrategyEnum { GREEDY , TOPK }
27+ data class GPT2Strategy (val strategy : GPT2StrategyEnum , val value : Int = 0 )
28+
2629class GPT2Client (application : Application ) : AndroidViewModel(application) {
2730 private lateinit var tokenizer: GPT2Tokenizer
2831 private lateinit var tflite: Interpreter
2932
33+ var strategy = GPT2Strategy (GPT2StrategyEnum .TOPK , 40 )
34+
3035 fun init () {
3136 if (! ::tokenizer.isInitialized) {
3237 val encoder = loadEncoder()
@@ -46,7 +51,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
4651 fun generate (text : String , nbTokens : Int = 10) { // = liveData<String>(
4752 // viewModelScope.coroutineContext+Dispatchers.Default) {
4853
49- var tokens = tokenizer.encode(text)
54+ val tokens = tokenizer.encode(text)
5055 repeat (nbTokens) {
5156 val maxTokens = tokens.takeLast(SEQUENCE_LENGTH ).toIntArray()
5257 val paddedTokens = maxTokens + IntArray (SEQUENCE_LENGTH - maxTokens.size)
@@ -57,7 +62,26 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
5762
5863 tflite.runForMultipleInputsOutputs(arrayOf(inputIds), outputs)
5964 val outputLogits = predictions[0 ][maxTokens.size- 1 ]
60- val nextToken = outputLogits.argmax()
65+
66+ val nextToken: Int = when (strategy.strategy) {
67+ GPT2StrategyEnum .TOPK -> {
68+ val finalTopK = min(strategy.value, outputLogits.size)
69+ val filteredLogits = outputLogits
70+ .mapIndexed { index, fl -> (index to fl) }
71+ .sortedBy { it.second }
72+ .takeWhile { it.second < finalTopK }
73+
74+ // Softmax computation on filtered logits
75+ val maxLogitValue = outputLogits.max()!!
76+ val logitsExp = filteredLogits.map { exp(it.second - maxLogitValue) }
77+ val sumExp = logitsExp.sum()
78+ val probs = logitsExp.map { it.div(sumExp) }
79+
80+ val logitsIndexes = filteredLogits.map { it.first }
81+ sample(logitsIndexes, probs)
82+ }
83+ else -> outputLogits.argmax()
84+ }
6185
6286 tokens.add(nextToken)
6387 val decodedToken = tokenizer.decode(listOf (nextToken))
@@ -111,6 +135,25 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
111135 }
112136}
113137
/**
 * Draws a random index distributed according to the weights in [probs].
 *
 * A single uniform draw in [0, 1) is compared against the running cumulative
 * sum; the first index whose cumulative probability exceeds the draw wins.
 * Falls back to the last index so floating-point round-off (weights summing
 * to slightly under 1) can never leave the function without a result.
 */
private fun randomIndex(probs: List<Float>): Int {
    val draw = Random.nextFloat()
    var cumulative = 0f
    for ((index, weight) in probs.withIndex()) {
        cumulative += weight
        if (draw < cumulative) {
            return index
        }
    }
    return probs.size - 1
}
151+
/**
 * Samples one element of [indexes], weighted by the parallel probability
 * list [probs] (delegates the weighted draw to [randomIndex]).
 */
private fun sample(indexes: List<Int>, probs: List<Float>): Int = indexes[randomIndex(probs)]
156+
114157private fun FloatArray.argmax (): Int {
115158 var bestIndex = 0
116159 repeat(size) {
0 commit comments