@@ -3,7 +3,10 @@ package co.huggingface.android_transformers.gpt2.ml
 import android.app.Application
 import android.util.JsonReader
 import androidx.lifecycle.AndroidViewModel
+import androidx.lifecycle.liveData
+import androidx.lifecycle.viewModelScope
 import co.huggingface.android_transformers.gpt2.tokenization.GPT2Tokenizer
+import kotlinx.coroutines.Dispatchers
 import org.tensorflow.lite.Interpreter
 import java.io.BufferedReader
 import java.io.FileInputStream
@@ -17,7 +20,7 @@ private const val SEQUENCE_LENGTH = 64
 private const val VOCAB_SIZE = 50257
 private const val NUM_HEAD = 12
 private const val NUM_LITE_THREADS = 4
-private const val MODEL_PATH = "gpt2-64.tflite"
+private const val MODEL_PATH = "model.tflite"
 private const val VOCAB_PATH = "gpt2-vocab.json"
 private const val MERGES_PATH = "gpt2-merges.txt"
 
@@ -44,12 +47,10 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
         if (!::tflite.isInitialized) {
             tflite = loadModel()
         }
-
-        generate("My name is")
     }
 
-    fun generate(text: String, nbTokens: Int = 10) { // = liveData<String>(
-    //        viewModelScope.coroutineContext+Dispatchers.Default) {
+    fun generate(text: String, nbTokens: Int = 10) = liveData<String>(
+        viewModelScope.coroutineContext + Dispatchers.Default) {
 
         val tokens = tokenizer.encode(text)
         repeat (nbTokens) {
@@ -85,8 +86,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
 
             tokens.add(nextToken)
             val decodedToken = tokenizer.decode(listOf(nextToken))
-            print(decodedToken)
-            // emit(decodedToken)
+            emit(decodedToken)
         }
     }
 
0 commit comments