Skip to content

Commit 320c2fa

Browse files
committed
Update android test package
1 parent 4d14c0d commit 320c2fa

File tree

11 files changed

+806
-12
lines changed

11 files changed

+806
-12
lines changed

.ci/scripts/build_android_instrumentation.sh

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ fi
1313
which "${PYTHON_EXECUTABLE}"
1414

1515
build_android_test() {
16-
pushd extension/android_test
17-
ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew testDebugUnitTest
18-
ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew build assembleAndroidTest
16+
mkdir -p extension/android/executorch_android/src/androidTest/resources
17+
cp extension/module/test/resources/add.pte extension/android/executorch_android/src/androidTest/resources
18+
pushd extension/android
19+
ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew :executorch_android:testDebugUnitTest
20+
ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew :executorch_android:assembleAndroidTest
1921
popd
2022
}
2123

@@ -24,8 +26,7 @@ collect_artifacts_to_be_uploaded() {
2426
# Collect Java library test
2527
JAVA_LIBRARY_TEST_DIR="${ARTIFACTS_DIR_NAME}/library_test_dir"
2628
mkdir -p "${JAVA_LIBRARY_TEST_DIR}"
27-
cp extension/android_test/build/outputs/apk/debug/*.apk "${JAVA_LIBRARY_TEST_DIR}"
28-
cp extension/android_test/build/outputs/apk/androidTest/debug/*.apk "${JAVA_LIBRARY_TEST_DIR}"
29+
cp executorch_android/build/outputs/apk/androidTest/debug/*.apk "${JAVA_LIBRARY_TEST_DIR}"
2930
}
3031

3132
main() {

.github/workflows/_android.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ jobs:
9595
curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/llm_demo/app-debug.apk
9696
curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/llm_demo/app-debug-androidTest.apk
9797
curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/fp32-xnnpack-custom/model.zip
98-
curl -o android-test-debug.apk https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/library_test_dir/executorch-debug.apk
9998
curl -o android-test-debug-androidTest.apk https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/library_test_dir/executorch-debug-androidTest.apk
10099
unzip model.zip
101100
mv *.pte model.pte

.github/workflows/pull.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
- runner: linux.arm64.2xlarge
6161
docker-image: executorch-ubuntu-22.04-clang12
6262
# TODO: Need to figure out why buck2 doesnt work on Graviton instances.
63-
- runner: linux.arm64.2xlarge
63+
- runner: linux.arm64.2xlarge
6464
build-tool: buck2
6565
fail-fast: false
6666
with:
@@ -420,7 +420,6 @@ jobs:
420420
permissions:
421421
id-token: write
422422
contents: read
423-
needs: test-llama-runner-linux
424423

425424
unittest:
426425
uses: ./.github/workflows/_unittest.yml

build/run_android_emulator.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ adb push model.pte /data/local/tmp/llama
2828
adb push tokenizer.bin /data/local/tmp/llama
2929
adb shell am instrument -w -r com.example.executorchllamademo.test/androidx.test.runner.AndroidJUnitRunner
3030

31-
adb uninstall org.pytorch.executorch || true
3231
adb uninstall org.pytorch.executorch.test || true
33-
adb install -t android-test-debug.apk
3432
adb install -t android-test-debug-androidTest.apk
3533

3634
adb shell am instrument -w -r org.pytorch.executorch.test/androidx.test.runner.AndroidJUnitRunner

extension/android/executorch_android/build.gradle

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,30 @@ android {
1616
compileSdk = 34
1717

1818
defaultConfig {
19-
minSdk = 19
19+
minSdk = 23
2020

21-
testInstrumentationRunner = "android.support.test.runner.AndroidJUnitRunner"
21+
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
2222
}
2323

2424
compileOptions {
2525
sourceCompatibility = JavaVersion.VERSION_1_8
2626
targetCompatibility = JavaVersion.VERSION_1_8
2727
}
28+
29+
sourceSets {
30+
androidTest {
31+
resources.srcDirs += [ 'src/androidTest/resources' ]
32+
}
33+
}
2834
}
2935

3036
dependencies {
3137
implementation 'com.facebook.fbjni:fbjni-java-only:0.5.1'
3238
implementation 'com.facebook.soloader:nativeloader:0.10.5'
39+
testImplementation 'junit:junit:4.12'
40+
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
41+
androidTestImplementation 'androidx.test:rules:1.2.0'
42+
androidTestImplementation 'commons-io:commons-io:2.4'
3343
}
3444

3545
import com.vanniktech.maven.publish.SonatypeHost
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
import static org.junit.Assert.assertEquals;
12+
import static org.junit.Assert.assertTrue;
13+
import static org.junit.Assert.assertFalse;
14+
import static org.junit.Assert.assertNotEquals;
15+
import static org.junit.Assert.fail;
16+
17+
import android.os.Environment;
18+
import androidx.test.rule.GrantPermissionRule;
19+
import android.Manifest;
20+
import android.content.Context;
21+
import org.junit.Test;
22+
import org.junit.Before;
23+
import org.junit.Rule;
24+
import org.junit.runner.RunWith;
25+
import java.io.InputStream;
26+
import java.net.URI;
27+
import java.net.URISyntaxException;
28+
import java.util.List;
29+
import java.util.ArrayList;
30+
import java.io.IOException;
31+
import java.io.File;
32+
import java.io.FileOutputStream;
33+
import org.junit.runners.JUnit4;
34+
import org.apache.commons.io.FileUtils;
35+
import androidx.test.ext.junit.runners.AndroidJUnit4;
36+
import androidx.test.InstrumentationRegistry;
37+
38+
/** Unit tests for {@link LlamaModule}. */
39+
@RunWith(AndroidJUnit4.class)
40+
public class LlamaModuleInstrumentationTest implements LlamaCallback {
41+
private static String TEST_FILE_NAME = "/tinyllama_portable_fp16_h.pte";
42+
private static String TOKENIZER_FILE_NAME = "/tokenizer.bin";
43+
private static String TEST_PROMPT = "Hello";
44+
private static int OK = 0x00;
45+
private static int SEQ_LEN = 32;
46+
47+
private final List<String> results = new ArrayList<>();
48+
private final List<Float> tokensPerSecond = new ArrayList<>();
49+
private LlamaModule mModule;
50+
51+
private static String getTestFilePath(String fileName) {
52+
return InstrumentationRegistry.getInstrumentation().getTargetContext().getExternalCacheDir() + fileName;
53+
}
54+
55+
@Before
56+
public void setUp() throws IOException {
57+
// copy zipped test resources to local device
58+
File addPteFile = new File(getTestFilePath(TEST_FILE_NAME));
59+
InputStream inputStream = getClass().getResourceAsStream(TEST_FILE_NAME);
60+
FileUtils.copyInputStreamToFile(inputStream, addPteFile);
61+
inputStream.close();
62+
63+
File tokenizerFile = new File(getTestFilePath(TOKENIZER_FILE_NAME));
64+
inputStream = getClass().getResourceAsStream(TOKENIZER_FILE_NAME);
65+
FileUtils.copyInputStreamToFile(inputStream, tokenizerFile);
66+
inputStream.close();
67+
68+
mModule = new LlamaModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f);
69+
}
70+
71+
@Rule
72+
public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE);
73+
74+
@Test
75+
public void testGenerate() throws IOException, URISyntaxException{
76+
int loadResult = mModule.load();
77+
// Check that the model can be load successfully
78+
assertEquals(OK, loadResult);
79+
80+
mModule.generate(TEST_PROMPT, SEQ_LEN, LlamaModuleInstrumentationTest.this);
81+
assertEquals(results.size(), SEQ_LEN);
82+
assertTrue(tokensPerSecond.get(tokensPerSecond.size() - 1) > 0);
83+
}
84+
85+
@Test
86+
public void testGenerateAndStop() throws IOException, URISyntaxException{
87+
int seqLen = 32;
88+
mModule.generate(TEST_PROMPT, SEQ_LEN, new LlamaCallback() {
89+
@Override
90+
public void onResult(String result) {
91+
LlamaModuleInstrumentationTest.this.onResult(result);
92+
mModule.stop();
93+
}
94+
95+
@Override
96+
public void onStats(float tps) {
97+
LlamaModuleInstrumentationTest.this.onStats(tps);
98+
}
99+
});
100+
101+
int stoppedResultSize = results.size();
102+
assertTrue(stoppedResultSize < SEQ_LEN);
103+
}
104+
105+
@Override
106+
public void onResult(String result) {
107+
results.add(result);
108+
}
109+
110+
@Override
111+
public void onStats(float tps) {
112+
tokensPerSecond.add(tps);
113+
}
114+
}
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
import static org.junit.Assert.assertEquals;
12+
import static org.junit.Assert.assertTrue;
13+
import static org.junit.Assert.assertFalse;
14+
import static org.junit.Assert.assertNotEquals;
15+
import static org.junit.Assert.fail;
16+
17+
import android.os.Environment;
18+
import androidx.test.rule.GrantPermissionRule;
19+
import android.Manifest;
20+
import android.content.Context;
21+
import org.junit.Test;
22+
import org.junit.Before;
23+
import org.junit.Rule;
24+
import org.junit.runner.RunWith;
25+
import java.io.InputStream;
26+
import java.net.URI;
27+
import java.net.URISyntaxException;
28+
import java.io.IOException;
29+
import java.io.File;
30+
import java.io.FileOutputStream;
31+
import org.junit.runners.JUnit4;
32+
import org.apache.commons.io.FileUtils;
33+
import androidx.test.ext.junit.runners.AndroidJUnit4;
34+
import androidx.test.InstrumentationRegistry;
35+
36+
/** Unit tests for {@link Module}. */
37+
@RunWith(AndroidJUnit4.class)
38+
public class ModuleInstrumentationTest {
39+
private static String TEST_FILE_NAME = "/add.pte";
40+
private static String MISSING_FILE_NAME = "/missing.pte";
41+
private static String NON_PTE_FILE_NAME = "/test.txt";
42+
private static String FORWARD_METHOD = "forward";
43+
private static String NONE_METHOD = "none";
44+
private static int OK = 0x00;
45+
private static int INVALID_ARGUMENT = 0x12;
46+
private static int ACCESS_FAILED = 0x22;
47+
48+
private static String getTestFilePath(String fileName) {
49+
return InstrumentationRegistry.getInstrumentation().getTargetContext().getExternalCacheDir() + fileName;
50+
}
51+
52+
@Before
53+
public void setUp() throws IOException {
54+
// copy zipped test resources to local device
55+
File addPteFile = new File(getTestFilePath(TEST_FILE_NAME));
56+
InputStream inputStream = getClass().getResourceAsStream(TEST_FILE_NAME);
57+
FileUtils.copyInputStreamToFile(inputStream, addPteFile);
58+
inputStream.close();
59+
60+
File nonPteFile = new File(getTestFilePath(NON_PTE_FILE_NAME));
61+
inputStream = getClass().getResourceAsStream(NON_PTE_FILE_NAME);
62+
FileUtils.copyInputStreamToFile(inputStream, nonPteFile);
63+
inputStream.close();
64+
}
65+
66+
@Rule
67+
public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE);
68+
69+
@Test
70+
public void testModuleLoadAndForward() throws IOException, URISyntaxException{
71+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
72+
73+
EValue[] results = module.forward();
74+
assertTrue(results[0].isTensor());
75+
}
76+
77+
@Test
78+
public void testModuleLoadMethodAndForward() throws IOException{
79+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
80+
81+
int loadMethod = module.loadMethod(FORWARD_METHOD);
82+
assertEquals(loadMethod, OK);
83+
84+
EValue[] results = module.forward();
85+
assertTrue(results[0].isTensor());
86+
}
87+
88+
@Test
89+
public void testModuleLoadForwardExplicit() throws IOException{
90+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
91+
92+
EValue[] results = module.execute(FORWARD_METHOD);
93+
assertTrue(results[0].isTensor());
94+
}
95+
96+
@Test
97+
public void testModuleLoadNonExistantFile() throws IOException{
98+
Module module = Module.load(getTestFilePath(MISSING_FILE_NAME));
99+
100+
EValue[] results = module.forward();
101+
assertEquals(null, results);
102+
}
103+
104+
@Test
105+
public void testModuleLoadMethodNonExistantFile() throws IOException{
106+
Module module = Module.load(getTestFilePath(MISSING_FILE_NAME));
107+
108+
int loadMethod = module.loadMethod(FORWARD_METHOD);
109+
assertEquals(loadMethod, ACCESS_FAILED);
110+
}
111+
112+
@Test
113+
public void testModuleLoadMethodNonExistantMethod() throws IOException{
114+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
115+
116+
int loadMethod = module.loadMethod(NONE_METHOD);
117+
assertEquals(loadMethod, INVALID_ARGUMENT);
118+
}
119+
120+
@Test
121+
public void testNonPteFile() throws IOException{
122+
Module module = Module.load(getTestFilePath(NON_PTE_FILE_NAME));
123+
124+
int loadMethod = module.loadMethod(FORWARD_METHOD);
125+
assertEquals(loadMethod, INVALID_ARGUMENT);
126+
}
127+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
3+
<application>
4+
</application>
5+
<instrumentation
6+
android:name="androidx.test.runner.AndroidJUnitRunner"
7+
android:targetPackage="org.pytorch.executorch"
8+
android:label="Tests for ExecuTorch Modules" />
9+
</manifest>

0 commit comments

Comments
 (0)