Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 6aea0e8

Browse files
committed
move gpt2 in its own app
1 parent 2d17620 commit 6aea0e8

File tree

29 files changed

+446
-10
lines changed

29 files changed

+446
-10
lines changed

gpt2/build.gradle

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
apply plugin: 'com.android.application'
2+
apply plugin: 'kotlin-android'
3+
apply plugin: 'kotlin-android-extensions'
4+
5+
android {
6+
compileSdkVersion 29
7+
buildToolsVersion "29.0.2"
8+
9+
10+
defaultConfig {
11+
applicationId "co.huggingface.android_transformers.gpt2"
12+
minSdkVersion 26
13+
targetSdkVersion 29
14+
versionCode 1
15+
versionName "1.0"
16+
17+
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
18+
}
19+
20+
aaptOptions {
21+
noCompress "tflite"
22+
}
23+
24+
kotlinOptions {
25+
jvmTarget = JavaVersion.VERSION_1_8
26+
}
27+
28+
buildTypes {
29+
release {
30+
minifyEnabled false
31+
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
32+
}
33+
}
34+
35+
sourceSets { main { assets.srcDirs = ['src/main/assets', 'src/main/assets/'] } }
36+
}
37+
38+
apply from: 'download.gradle'
39+
40+
dependencies {
41+
implementation fileTree(dir: 'libs', include: ['*.jar'])
42+
implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlin_version"
43+
implementation 'androidx.appcompat:appcompat:1.1.0'
44+
implementation 'androidx.core:core-ktx:1.1.0'
45+
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
46+
testImplementation 'junit:junit:4.12'
47+
androidTestImplementation 'androidx.test.ext:junit:1.1.1'
48+
androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0'
49+
50+
implementation 'org.tensorflow:tensorflow-lite:2.0.0'
51+
52+
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.3.2'
53+
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.3.2'
54+
implementation 'androidx.lifecycle:lifecycle-livedata-ktx:2.2.0-rc02'
55+
implementation 'androidx.lifecycle:lifecycle-viewmodel-ktx:2.2.0-rc02'
56+
implementation 'androidx.activity:activity-ktx:1.1.0-rc02'
57+
}

gpt2/download.gradle

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
apply plugin: 'de.undercouch.download'
2+
3+
task downloadLiteModel {
4+
def downloadFiles = [
5+
// 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-384.tflite': 'model.tflite',
6+
]
7+
downloadFiles.each { key, value ->
8+
download {
9+
src key
10+
dest "$projectDir/src/main/assets/" + value
11+
overwrite false
12+
}
13+
}
14+
}

gpt2/proguard-rules.pro

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Add project specific ProGuard rules here.
2+
# You can control the set of applied configuration files using the
3+
# proguardFiles setting in build.gradle.
4+
#
5+
# For more details, see
6+
# http://developer.android.com/guide/developing/tools/proguard.html
7+
8+
# If your project uses WebView with JS, uncomment the following
9+
# and specify the fully qualified class name to the JavaScript interface
10+
# class:
11+
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
12+
# public *;
13+
#}
14+
15+
# Uncomment this to preserve the line number information for
16+
# debugging stack traces.
17+
#-keepattributes SourceFile,LineNumberTable
18+
19+
# If you keep the line number information, uncomment this to
20+
# hide the original source file name.
21+
#-renamesourcefileattribute SourceFile
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package co.huggingface.android_transformers.gpt2
2+
3+
import androidx.test.platform.app.InstrumentationRegistry
4+
import androidx.test.ext.junit.runners.AndroidJUnit4
5+
6+
import org.junit.Test
7+
import org.junit.runner.RunWith
8+
9+
import org.junit.Assert.*
10+
11+
/**
12+
* Instrumented test, which will execute on an Android device.
13+
*
14+
* See [testing documentation](http://d.android.com/tools/testing).
15+
*/
16+
@RunWith(AndroidJUnit4::class)
17+
class ExampleInstrumentedTest {
18+
@Test
19+
fun useAppContext() {
20+
// Context of the app under test.
21+
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
22+
assertEquals("co.huggingface.android_transformers.gpt2", appContext.packageName)
23+
}
24+
}

gpt2/src/main/AndroidManifest.xml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
3+
xmlns:tools="http://schemas.android.com/tools"
4+
package="co.huggingface.android_transformers.gpt2">
5+
6+
<application
7+
android:allowBackup="true"
8+
android:icon="@mipmap/ic_launcher"
9+
android:label="@string/app_name"
10+
android:roundIcon="@mipmap/ic_launcher_round"
11+
android:supportsRtl="true"
12+
android:theme="@style/AppTheme"
13+
tools:ignore="GoogleAppIndexingWarning">
14+
<activity android:name=".MainActivity">
15+
<intent-filter>
16+
<action android:name="android.intent.action.MAIN" />
17+
18+
<category android:name="android.intent.category.LAUNCHER" />
19+
</intent-filter>
20+
</activity>
21+
</application>
22+
23+
</manifest>

gpt2/src/main/assets/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.tflite
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package co.huggingface.android_transformers.gpt2
2+
3+
import androidx.appcompat.app.AppCompatActivity
4+
import android.os.Bundle
5+
import android.os.Handler
6+
import android.os.HandlerThread
7+
import androidx.activity.viewModels
8+
import androidx.lifecycle.observe
9+
10+
class MainActivity : AppCompatActivity() {
11+
private val gpt2: co.huggingface.android_transformers.gpt2.ml.GPT2Client by viewModels()
12+
private val handlerThread by lazy { HandlerThread("GPT2Client") }
13+
private val handler by lazy {
14+
handlerThread.start()
15+
Handler(handlerThread.looper)
16+
}
17+
18+
override fun onCreate(savedInstanceState: Bundle?) {
19+
super.onCreate(savedInstanceState)
20+
setContentView(R.layout.activity_main)
21+
22+
handler.post {
23+
gpt2.init()
24+
val generation = gpt2.generate("My name is")
25+
26+
runOnUiThread {
27+
generation.observe(this) {
28+
print(it)
29+
}
30+
}
31+
}
32+
}
33+
34+
override fun onDestroy() {
35+
super.onDestroy()
36+
handlerThread.quit()
37+
}
38+
}

app/src/main/java/co/huggingface/android_transformers/gpt2/ml/GPT2Client.kt renamed to gpt2/src/main/java/co/huggingface/android_transformers/gpt2/ml/GPT2Client.kt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package co.huggingface.android_transformers.gpt2.ml
33
import android.app.Application
44
import android.util.JsonReader
55
import androidx.lifecycle.AndroidViewModel
6+
import androidx.lifecycle.liveData
7+
import androidx.lifecycle.viewModelScope
68
import co.huggingface.android_transformers.gpt2.tokenization.GPT2Tokenizer
9+
import kotlinx.coroutines.Dispatchers
710
import org.tensorflow.lite.Interpreter
811
import java.io.BufferedReader
912
import java.io.FileInputStream
@@ -17,7 +20,7 @@ private const val SEQUENCE_LENGTH = 64
1720
private const val VOCAB_SIZE = 50257
1821
private const val NUM_HEAD = 12
1922
private const val NUM_LITE_THREADS = 4
20-
private const val MODEL_PATH = "gpt2-64.tflite"
23+
private const val MODEL_PATH = "model.tflite"
2124
private const val VOCAB_PATH = "gpt2-vocab.json"
2225
private const val MERGES_PATH = "gpt2-merges.txt"
2326

@@ -44,12 +47,10 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
4447
if (!::tflite.isInitialized) {
4548
tflite = loadModel()
4649
}
47-
48-
generate("My name is")
4950
}
5051

51-
fun generate(text: String, nbTokens: Int = 10) { // = liveData<String>(
52-
//viewModelScope.coroutineContext+Dispatchers.Default) {
52+
fun generate(text: String, nbTokens: Int = 10) = liveData<String>(
53+
viewModelScope.coroutineContext+Dispatchers.Default) {
5354

5455
val tokens = tokenizer.encode(text)
5556
repeat (nbTokens) {
@@ -85,8 +86,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) {
8586

8687
tokens.add(nextToken)
8788
val decodedToken = tokenizer.decode(listOf(nextToken))
88-
print(decodedToken)
89-
// emit(decodedToken)
89+
emit(decodedToken)
9090
}
9191
}
9292

0 commit comments

Comments
 (0)