diff --git a/extension/android/build.gradle b/extension/android/build.gradle index de243154d65..b40f08e0c45 100644 --- a/extension/android/build.gradle +++ b/extension/android/build.gradle @@ -20,6 +20,5 @@ task makeJar(type: Jar) { dependencies { implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2' implementation 'com.facebook.soloader:nativeloader:0.10.5' - testImplementation 'junit:junit:4.13.2' } } diff --git a/extension/android_test/add_model.py b/extension/android_test/add_model.py new file mode 100644 index 00000000000..5c7cf4770e2 --- /dev/null +++ b/extension/android_test/add_model.py @@ -0,0 +1,26 @@ +import torch +from executorch.exir import to_edge +from torch.export import export + + +# Start with a PyTorch model that adds two input tensors (matrices) +class Add(torch.nn.Module): + def __init__(self): + super(Add, self).__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x + y + + +# 1. torch.export: Defines the program with the ATen operator set. +aten_dialect = export(Add(), (torch.ones(1), torch.ones(1))) + +# 2. to_edge: Make optimizations for Edge devices +edge_program = to_edge(aten_dialect) + +# 3. to_executorch: Convert the graph to an ExecuTorch program +executorch_program = edge_program.to_executorch() + +# 4. Save the compiled .pte program +with open("add.pte", "wb") as file: + file.write(executorch_program.buffer) diff --git a/extension/android_test/setup.sh b/extension/android_test/setup.sh index a12f76c1f35..d83aeeebb45 100755 --- a/extension/android_test/setup.sh +++ b/extension/android_test/setup.sh @@ -21,10 +21,13 @@ build_native_library() { -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \ -DANDROID_ABI="${ANDROID_ABI}" \ -DEXECUTORCH_BUILD_XNNPACK=ON \ + -DEXECUTORCH_XNNPACK_SHARED_WORKSPACE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ -B"${CMAKE_OUT}" cmake --build "${CMAKE_OUT}" -j16 --target install @@ -33,6 +36,7 @@ build_native_library() { -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}"/build/cmake/android.toolchain.cmake \ -DANDROID_ABI="${ANDROID_ABI}" \ -DCMAKE_INSTALL_PREFIX=c"${CMAKE_OUT}" \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ -B"${CMAKE_OUT}"/extension/android @@ -48,6 +52,10 @@ build_jar build_native_library "arm64-v8a" build_native_library "x86_64" build_aar +source ".ci/scripts/test_llama.sh" stories110M cmake fp16 portable ${BUILD_AAR_DIR} popd mkdir -p "$BASEDIR"/src/libs cp "$BUILD_AAR_DIR/executorch.aar" "$BASEDIR"/src/libs/executorch.aar +python add_model.py +mv "add.pte" "$BASEDIR"/src/androidTest/resources/add.pte +unzip -o "$BUILD_AAR_DIR"/model.zip -d "$BASEDIR"/src/androidTest/resources diff --git a/extension/android_test/src/androidTest/java/org/pytorch/executorch/LlamaModuleInstrumentationTest.java b/extension/android_test/src/androidTest/java/org/pytorch/executorch/LlamaModuleInstrumentationTest.java new file mode 100644 index 00000000000..940e34d684f --- /dev/null +++ b/extension/android_test/src/androidTest/java/org/pytorch/executorch/LlamaModuleInstrumentationTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; + +import android.os.Environment; +import androidx.test.rule.GrantPermissionRule; +import android.Manifest; +import android.content.Context; +import org.junit.Test; +import org.junit.Before; +import org.junit.Rule; +import org.junit.runner.RunWith; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.ArrayList; +import java.io.IOException; +import java.io.File; +import java.io.FileOutputStream; +import org.junit.runners.JUnit4; +import org.apache.commons.io.FileUtils; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.InstrumentationRegistry; +import org.pytorch.executorch.LlamaModule; +import org.pytorch.executorch.LlamaCallback; +import org.pytorch.executorch.Module; +import org.pytorch.executorch.EValue; +import org.pytorch.executorch.Tensor; + +/** Unit tests for {@link LlamaModule}. */ +@RunWith(AndroidJUnit4.class) +public class LlamaModuleInstrumentationTest implements LlamaCallback { + private static String TEST_FILE_NAME = "/tinyllama_portable_fp16_h.pte"; + private static String TOKENIZER_FILE_NAME = "/tokenizer.bin"; + private static String TEST_PROMPT = "Hello"; + private static int OK = 0x00; + private static int SEQ_LEN = 32; + + private final List results = new ArrayList<>(); + private final List tokensPerSecond = new ArrayList<>(); + private LlamaModule mModule; + + private static String getTestFilePath(String fileName) { + return InstrumentationRegistry.getInstrumentation().getTargetContext().getExternalCacheDir() + fileName; + } + + @Before + public void setUp() throws IOException { + // copy zipped test resources to local device + File addPteFile = new File(getTestFilePath(TEST_FILE_NAME)); + InputStream inputStream = getClass().getResourceAsStream(TEST_FILE_NAME); + FileUtils.copyInputStreamToFile(inputStream, addPteFile); + inputStream.close(); + + File tokenizerFile = new File(getTestFilePath(TOKENIZER_FILE_NAME)); + inputStream = getClass().getResourceAsStream(TOKENIZER_FILE_NAME); + FileUtils.copyInputStreamToFile(inputStream, tokenizerFile); + inputStream.close(); + + mModule = new LlamaModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f); + } + + @Rule + public GrantPermissionRule mRuntimePermissionRule = GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE); + + @Test + public void testGenerate() throws IOException, URISyntaxException{ + int loadResult = mModule.load(); + // Check that the model can be load successfully + assertEquals(OK, loadResult); + + mModule.generate(TEST_PROMPT, SEQ_LEN, LlamaModuleInstrumentationTest.this); + assertEquals(results.size(), SEQ_LEN); + assertTrue(tokensPerSecond.get(tokensPerSecond.size() - 1) > 0); + } + + @Test + public void testGenerateAndStop() throws IOException, URISyntaxException{ + int seqLen = 32; + mModule.generate(TEST_PROMPT, SEQ_LEN, new LlamaCallback() { + @Override + public void onResult(String result) { + LlamaModuleInstrumentationTest.this.onResult(result); + mModule.stop(); + } + + @Override + public void onStats(float tps) { + LlamaModuleInstrumentationTest.this.onStats(tps); + } + }); + + int stoppedResultSize = results.size(); + assertTrue(stoppedResultSize < SEQ_LEN); + } + + @Override + public void onResult(String result) { + results.add(result); + } + + @Override + public void onStats(float tps) { + tokensPerSecond.add(tps); + } +} diff --git a/extension/android/src/test/java/org/pytorch/executorch/EValueTest.java b/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java similarity index 99% rename from extension/android/src/test/java/org/pytorch/executorch/EValueTest.java rename to extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java index 35367883efe..29cabae75fa 100644 --- a/extension/android/src/test/java/org/pytorch/executorch/EValueTest.java +++ b/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java @@ -129,7 +129,7 @@ public void testOptionalTensorListValue() { Optional.of(Tensor.fromBlob(data[1], shape[1]))); assertTrue(evalue.isOptionalTensorList()); - assertTrue(evalue.toOptionalTensorList()[0].isEmpty()); + assertTrue(!evalue.toOptionalTensorList()[0].isPresent()); assertTrue(evalue.toOptionalTensorList()[1].isPresent()); assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().shape, shape[0])); diff --git a/extension/android/src/test/java/org/pytorch/executorch/TensorTest.java b/extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java similarity index 100% rename from extension/android/src/test/java/org/pytorch/executorch/TensorTest.java rename to extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java