Merged

Changes from all commits
1 change: 1 addition & 0 deletions .gitignore
@@ -26,6 +26,7 @@ pip-out/
*.model
tokenizer.json
*.pte
*.ptd
!test_bpe_tokenizer.bin
!test_tiktoken_tokenizer.model

2 changes: 2 additions & 0 deletions extension/android/BUCK
@@ -13,6 +13,8 @@ non_fbcode_target(_kind = fb_android_library,
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
],
autoglob = False,
6 changes: 6 additions & 0 deletions extension/android/CMakeLists.txt
@@ -145,6 +145,12 @@ if(EXECUTORCH_JNI_CUSTOM_LIBRARY)
)
endif()

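# When the training extension is enabled, compile the training JNI layer into
# executorch_jni, link against extension_training, and expose a compile-time flag.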
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
target_sources(executorch_jni PRIVATE jni/jni_layer_training.cpp jni/log.cpp)
list(APPEND link_libraries extension_training)
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_TRAINING=1)
endif()

if(EXECUTORCH_BUILD_LLAMA_JNI)
target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp)
list(APPEND link_libraries llama_runner llava_runner)
11 changes: 11 additions & 0 deletions extension/android/executorch_android/android_test_setup.sh
@@ -15,7 +15,17 @@ which "${PYTHON_EXECUTABLE}"
BASEDIR=$(dirname "$(realpath $0)")

prepare_add() {
pushd "${BASEDIR}/../../../"
python3 -m test.models.export_program --modules "ModuleAdd" --outdir "${BASEDIR}/src/androidTest/resources/"
popd
}

prepare_xor() {
pushd "${BASEDIR}/../../training/"
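# Export a self-contained program first and stash it as xor_full.pte, then
# re-export with --external so the weights land in a separate xor.ptd.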
python3 -m examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/"
mv "${BASEDIR}/src/androidTest/resources/xor.pte" "${BASEDIR}/src/androidTest/resources/xor_full.pte"
python3 -m examples.XOR.export_model --outdir "${BASEDIR}/src/androidTest/resources/" --external
popd
}

prepare_tinyllama() {
@@ -43,5 +53,6 @@ prepare_vision() {
}

prepare_add
prepare_xor
prepare_tinyllama
prepare_vision
2 changes: 2 additions & 0 deletions extension/android/executorch_android/build.gradle
@@ -50,10 +50,12 @@ dependencies {
implementation libs.core.ktx
testImplementation 'junit:junit:4.12'
testImplementation 'org.assertj:assertj-core:3.27.2'
testImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
androidTestImplementation 'androidx.test:rules:1.2.0'
androidTestImplementation 'commons-io:commons-io:2.4'
androidTestImplementation 'org.json:json:20250107'
androidTestImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
}

import com.vanniktech.maven.publish.SonatypeHost
215 changes: 215 additions & 0 deletions extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt
@@ -0,0 +1,215 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
package org.pytorch.executorch

import android.Manifest
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.rule.GrantPermissionRule
import java.io.File
import java.io.IOException
import java.net.URISyntaxException
import org.apache.commons.io.FileUtils
import org.junit.Assert
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TestFileUtils.getTestFilePath
import kotlin.random.Random
import kotlin.test.assertContains

/** Unit tests for [TrainingModule]. */
@RunWith(AndroidJUnit4::class)
class TrainingModuleE2ETest {
@get:Rule
var runtimePermissionRule: GrantPermissionRule =
GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)

@Test
@Throws(IOException::class, URISyntaxException::class)
fun testTrainXOR() {
val pteFilePath = "/xor.pte"
val ptdFilePath = "/xor.ptd"

val pteFile = File(getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

val ptdFile = File(getTestFilePath(ptdFilePath))
val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile)
ptdInputStream.close()

val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
val params = module.namedParameters("forward")

Assert.assertEquals(4, params.size)
assertContains(params, LIN_WEIGHT)
assertContains(params, LIN_BIAS)
assertContains(params, LIN2_WEIGHT)
assertContains(params, LIN2_BIAS)

val sgd = SGD.create(params, 0.5);
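// The dataset interleaves examples: inputs at even indices, labels at odd.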
val dataset = listOf<Tensor>(
Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
)

val numEpochs = 5000;
var finalLoss = Float.MAX_VALUE

for (i in 0 until numEpochs) {
val inputDex = 2 * Random.nextInt(dataset.size / 2)
val targetDex = inputDex + 1
val input = dataset.get(inputDex)
val target = dataset.get(targetDex)
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
val gradients = module.namedGradients("forward")

if (i == 0) {
Assert.assertEquals(4, gradients.size)
assertContains(gradients, LIN_WEIGHT)
assertContains(gradients, LIN_BIAS)
assertContains(gradients, LIN2_WEIGHT)
assertContains(gradients, LIN2_BIAS)
}

if (i % 500 == 0 || i == numEpochs - 1) {
Log.i(
"testTrainXOR",
String.format(
"Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d",
i,
out[0].toTensor().getDataAsFloatArray()[0],
input.getDataAsFloatArray()[0],
input.getDataAsFloatArray()[1],
out[1].toTensor().getDataAsLongArray()[0],
target.getDataAsLongArray()[0]));
}

sgd.step(gradients)

if (i == numEpochs - 1) {
finalLoss = out[0].toTensor().dataAsFloatArray[0]
}
}
Assert.assertTrue(finalLoss < 0.1f)
}

@Test
@Throws(IOException::class, URISyntaxException::class)
fun testTrainXOR_PTEOnly() {
val pteFilePath = "/xor_full.pte"

val pteFile = File(getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

val module = TrainingModule.load(getTestFilePath(pteFilePath));
val params = module.namedParameters("forward")

Assert.assertEquals(4, params.size)
assertContains(params, LIN_WEIGHT)
assertContains(params, LIN_BIAS)
assertContains(params, LIN2_WEIGHT)
assertContains(params, LIN2_BIAS)

val sgd = SGD.create(params, 0.5);
val dataset = listOf<Tensor>(
Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
)

val numEpochs = 5000;
var finalLoss = Float.MAX_VALUE

for (i in 0 until numEpochs) {
val inputDex = 2 * Random.nextInt(dataset.size / 2)
val targetDex = inputDex + 1
val input = dataset.get(inputDex)
val target = dataset.get(targetDex)
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
val gradients = module.namedGradients("forward")

if (i == 0) {
Assert.assertEquals(4, gradients.size)
assertContains(gradients, LIN_WEIGHT)
assertContains(gradients, LIN_BIAS)
assertContains(gradients, LIN2_WEIGHT)
assertContains(gradients, LIN2_BIAS)
}

if (i % 500 == 0 || i == numEpochs - 1) {
Log.i(
"testTrainXOR_PTEOnly",
String.format(
"Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d",
i,
out[0].toTensor().getDataAsFloatArray()[0],
input.getDataAsFloatArray()[0],
input.getDataAsFloatArray()[1],
out[1].toTensor().getDataAsLongArray()[0],
target.getDataAsLongArray()[0]));
}

sgd.step(gradients)

if (i == numEpochs - 1) {
finalLoss = out[0].toTensor().dataAsFloatArray[0]
}
}
Assert.assertTrue(finalLoss < 0.1f)
}

@Test
@Throws(IOException::class)
fun testMissingPteFile() {
val exception = Assert.assertThrows(RuntimeException::class.java) {
TrainingModule.load(getTestFilePath(MISSING_PTE_NAME))
}
Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME))
}

@Test
@Throws(IOException::class)
fun testMissingPtdFile() {
val exception = Assert.assertThrows(RuntimeException::class.java) {
val pteFilePath = "/xor.pte"
val pteFile = File(getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME))
}
Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME))
}

companion object {
private const val LIN_WEIGHT = "net.linear.weight"
private const val LIN_BIAS = "net.linear.bias"
private const val LIN2_WEIGHT = "net.linear2.weight"
private const val LIN2_BIAS = "net.linear2.bias"
private const val MISSING_PTE_NAME = "/missing.pte"
private const val MISSING_PTD_NAME = "/missing.ptd"
}
}
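The test above exercises the whole training surface; distilled, the call pattern is small. Below is a minimal Java sketch of the same flow, assuming executeForwardBackward returns an EValue array, as the test's indexing suggests; paths and the "forward" method name follow the test.

import java.util.Map;
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.SGD;
import org.pytorch.executorch.Tensor;
import org.pytorch.executorch.TrainingModule;

public class XorTrainingSketch {
  public static void main(String[] args) {
    // Load the program; the second argument (the .ptd) is only needed when
    // the weights were exported externally.
    TrainingModule module = TrainingModule.load("xor.pte", "xor.ptd");
    Map<String, Tensor> params = module.namedParameters("forward");
    SGD sgd = SGD.create(params, 0.5);

    // One hard-coded XOR example; the real test samples from four pairs.
    Tensor input = Tensor.fromBlob(new float[] {1.0f, 0.0f}, new long[] {1, 2});
    Tensor label = Tensor.fromBlob(new long[] {1}, new long[] {1});

    for (int i = 0; i < 5000; i++) {
      // Forward and backward in one call; per the test's logging, out[0] is
      // the loss and out[1] the prediction.
      EValue[] out = module.executeForwardBackward(
          "forward", EValue.from(input), EValue.from(label));
      sgd.step(module.namedGradients("forward"));
    }
  }
}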
102 changes: 102 additions & 0 deletions extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java
@@ -0,0 +1,102 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
import org.pytorch.executorch.annotations.Experimental;

/**
* Java wrapper for ExecuTorch SGD Optimizer.
*
* <p>Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
public class SGD {

static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
// Loads libexecutorch.so from jniLibs
NativeLoader.loadLibrary("executorch");
}

private final HybridData mHybridData;

@DoNotStrip
private static native HybridData initHybrid(
Map<String, Tensor> namedParameters,
double learningRate,
double momentum,
double dampening,
double weightDecay,
boolean nesterov);

private SGD(
Map<String, Tensor> namedParameters,
double learningRate,
double momentum,
double dampening,
double weightDecay,
boolean nesterov) {
mHybridData =
initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
}

/**
* Creates a new SGD optimizer with the specified parameters and options.
*
* @param namedParameters Map of parameter names to tensors to be optimized
* @param learningRate The learning rate for the optimizer
* @param momentum The momentum value
* @param dampening The dampening value
* @param weightDecay The weight decay value
* @param nesterov Whether to use Nesterov momentum
* @return new {@link org.pytorch.executorch.SGD} object
*/
public static SGD create(
[Review comment from a Contributor] Doesn't have to be in this diff, but would it be more "java-y" to have builder classes? new SGDBuilder().learning_rate().buildSGD();

[Reply from the Contributor Author] Yes, that sounds good - having an SGDBuilder() sounds like a great follow-up to me.
Map<String, Tensor> namedParameters,
double learningRate,
double momentum,
double dampening,
double weightDecay,
boolean nesterov) {
return new SGD(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
}

/**
* Creates a new SGD optimizer with default options.
*
* @param namedParameters Map of parameter names to tensors to be optimized
* @param learningRate The learning rate for the optimizer
* @return new {@link org.pytorch.executorch.SGD} object
*/
public static SGD create(Map<String, Tensor> namedParameters, double learningRate) {
return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false);
}

/**
* Performs a single optimization step using the provided gradients.
*
* @param namedGradients Map of parameter names to gradient tensors
*/
public void step(Map<String, Tensor> namedGradients) {
if (!mHybridData.isValid()) {
throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
}
stepNative(namedGradients);
}

@DoNotStrip
private native void stepNative(Map<String, Tensor> namedGradients);
}
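Following up on the review thread above: a hypothetical SGDBuilder could wrap the existing create overloads without touching the native layer. Nothing below is part of this PR; every name except SGD.create is illustrative.

package org.pytorch.executorch;

import java.util.Map;

// Hypothetical builder (not in this PR): fluent options over SGD.create(...).
public final class SGDBuilder {
  private final Map<String, Tensor> namedParameters;
  private double learningRate = 0.1; // illustrative default
  private double momentum = 0.0;
  private double dampening = 0.0;
  private double weightDecay = 0.0;
  private boolean nesterov = false;

  public SGDBuilder(Map<String, Tensor> namedParameters) {
    this.namedParameters = namedParameters;
  }

  public SGDBuilder learningRate(double v) { learningRate = v; return this; }
  public SGDBuilder momentum(double v) { momentum = v; return this; }
  public SGDBuilder dampening(double v) { dampening = v; return this; }
  public SGDBuilder weightDecay(double v) { weightDecay = v; return this; }
  public SGDBuilder nesterov(boolean v) { nesterov = v; return this; }

  // Delegate to the confirmed factory method.
  public SGD build() {
    return SGD.create(
        namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
  }
}

Call sites would then read SGD sgd = new SGDBuilder(params).learningRate(0.5).build();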