Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ pip-out/
*.bin
*.model
tokenizer.json
*.pte
!test_bpe_tokenizer.bin
!test_tiktoken_tokenizer.model

Expand Down
2 changes: 2 additions & 0 deletions extension/android/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ non_fbcode_target(_kind = fb_android_library,
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
],
autoglob = False,
Expand Down
3 changes: 2 additions & 1 deletion extension/android/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch)
find_package(executorch CONFIG REQUIRED)
target_link_options_shared_lib(executorch)

add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp)
add_library(executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp jni/jni_layer_training.cpp)

set(link_libraries)
list(
Expand All @@ -77,6 +77,7 @@ list(
extension_runner_util
extension_tensor
extension_threadpool
extension_training
fbjni
)

Expand Down
2 changes: 2 additions & 0 deletions extension/android/executorch_android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@ dependencies {
implementation libs.core.ktx
testImplementation 'junit:junit:4.12'
testImplementation 'org.assertj:assertj-core:3.27.2'
testImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
androidTestImplementation 'androidx.test:rules:1.2.0'
androidTestImplementation 'commons-io:commons-io:2.4'
androidTestImplementation 'org.json:json:20250107'
androidTestImplementation 'org.jetbrains.kotlin:kotlin-test:1.9.23'
}

import com.vanniktech.maven.publish.SonatypeHost
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
package org.pytorch.executorch

import android.Manifest
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.rule.GrantPermissionRule
import java.io.File
import java.io.IOException
import java.net.URISyntaxException
import org.apache.commons.io.FileUtils
import org.junit.Assert
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TestFileUtils.getTestFilePath
import kotlin.random.Random
import kotlin.test.assertContains

/** End-to-end tests for [TrainingModule]: trains a tiny XOR model with [SGD] and checks convergence. */
@RunWith(AndroidJUnit4::class)
class TrainingModuleE2ETest {
    @get:Rule
    var runtimePermissionRule: GrantPermissionRule =
        GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)

    /**
     * Copies the bundled test resource at [resourcePath] into local app storage and returns its
     * on-disk path. Fails fast with a clear message (instead of a late NPE) if the resource is
     * missing, and always closes the resource stream.
     */
    @Throws(IOException::class)
    private fun stageResource(resourcePath: String): String {
        val destination = File(getTestFilePath(resourcePath))
        val stream =
            requireNotNull(javaClass.getResourceAsStream(resourcePath)) {
                "Missing test resource: $resourcePath"
            }
        // use {} closes the stream even if the copy throws.
        stream.use { FileUtils.copyInputStreamToFile(it, destination) }
        return destination.path
    }

    @Test
    @Throws(IOException::class, URISyntaxException::class)
    fun testTrainXOR() {
        val ptePath = stageResource("/xor.pte")
        val ptdPath = stageResource("/xor.ptd")

        val module = TrainingModule.load(ptePath, ptdPath)
        val params = module.namedParameters("forward")

        // The XOR net has two linear layers => four trainable parameters.
        Assert.assertEquals(4, params.size)
        assertContains(params, LIN_WEIGHT)
        assertContains(params, LIN_BIAS)
        assertContains(params, LIN2_WEIGHT)
        assertContains(params, LIN2_BIAS)

        val sgd = SGD.create(params, 0.5)

        // Flat list of alternating (input, label) tensors covering all four XOR cases.
        val dataset =
            listOf(
                Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
                Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
                Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
                Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
                Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
                Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
                Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
                Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
            )

        val numEpochs = 5000
        var finalLoss = Float.MAX_VALUE

        for (i in 0 until numEpochs) {
            // Pick a random (input, label) pair; even indices are inputs, odd are labels.
            val inputIndex = 2 * Random.nextInt(dataset.size / 2)
            val input = dataset[inputIndex]
            val target = dataset[inputIndex + 1]
            val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
            val gradients = module.namedGradients("forward")

            if (i == 0) {
                // Gradients mirror the parameter set on the first backward pass.
                Assert.assertEquals(4, gradients.size)
                assertContains(gradients, LIN_WEIGHT)
                assertContains(gradients, LIN_BIAS)
                assertContains(gradients, LIN2_WEIGHT)
                assertContains(gradients, LIN2_BIAS)
            }

            if (i % 500 == 0 || i == numEpochs - 1) {
                Log.i(
                    "testTrainXOR",
                    String.format(
                        "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d",
                        i,
                        out[0].toTensor().dataAsFloatArray[0],
                        input.dataAsFloatArray[0],
                        input.dataAsFloatArray[1],
                        out[1].toTensor().dataAsLongArray[0],
                        target.dataAsLongArray[0],
                    ),
                )
            }

            sgd.step(gradients)

            if (i == numEpochs - 1) {
                finalLoss = out[0].toTensor().dataAsFloatArray[0]
            }
        }
        // Training should have converged well below the 0.1 loss threshold.
        Assert.assertTrue(finalLoss < 0.1f)
    }

    companion object {
        private const val LIN_WEIGHT = "net.linear.weight"
        private const val LIN_BIAS = "net.linear.bias"
        private const val LIN2_WEIGHT = "net.linear2.weight"
        private const val LIN2_BIAS = "net.linear2.bias"
    }
}
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
import org.pytorch.executorch.annotations.Experimental;

/**
 * Java wrapper for the ExecuTorch SGD (stochastic gradient descent) optimizer.
 *
 * <p>Instances are created via the static {@link #create} factories and hold a native peer through
 * fbjni's {@link HybridData}; {@link #step} forwards to the native optimizer.
 *
 * <p>Warning: These APIs are experimental and subject to change without notice
 */
@Experimental
public class SGD {

  static {
    if (!NativeLoader.isInitialized()) {
      NativeLoader.init(new SystemDelegate());
    }
    // Loads libexecutorch.so from jniLibs
    NativeLoader.loadLibrary("executorch");
  }

  // Handle to the native optimizer instance.
  // NOTE(review): the field name appears to be part of the fbjni binding convention —
  // confirm against the C++ HybridClass before renaming.
  private final HybridData mHybridData;

  @DoNotStrip
  private static native HybridData initHybrid(
      Map<String, Tensor> namedParameters,
      double learningRate,
      double momentum,
      double dampening,
      double weightDecay,
      boolean nesterov);

  // Private: construction goes through the create(...) factories below.
  private SGD(
      Map<String, Tensor> namedParameters,
      double learningRate,
      double momentum,
      double dampening,
      double weightDecay,
      boolean nesterov) {
    mHybridData =
        initHybrid(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
  }

  /**
   * Creates a new SGD optimizer with the specified parameters and options.
   *
   * @param namedParameters Map of parameter names to tensors to be optimized
   * @param learningRate The learning rate for the optimizer
   * @param momentum The momentum factor (default: 0)
   * @param dampening The dampening for momentum (default: 0)
   * @param weightDecay The weight decay (L2 penalty) (default: 0)
   * @param nesterov Whether to use Nesterov momentum (default: false)
   * @return new {@link org.pytorch.executorch.SGD} object
   */
  public static SGD create(
      Map<String, Tensor> namedParameters,
      double learningRate,
      double momentum,
      double dampening,
      double weightDecay,
      boolean nesterov) {
    return new SGD(namedParameters, learningRate, momentum, dampening, weightDecay, nesterov);
  }

  /**
   * Creates a new SGD optimizer with default options (no momentum, no dampening, no weight decay,
   * Nesterov disabled).
   *
   * @param namedParameters Map of parameter names to tensors to be optimized
   * @param learningRate The learning rate for the optimizer
   * @return new {@link org.pytorch.executorch.SGD} object
   */
  public static SGD create(Map<String, Tensor> namedParameters, double learningRate) {
    return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false);
  }

  /**
   * Performs a single optimization step using the provided gradients.
   *
   * @param namedGradients Map of parameter names to gradient tensors
   * @throws RuntimeException if the native peer has already been destroyed
   */
  public void step(Map<String, Tensor> namedGradients) {
    if (!mHybridData.isValid()) {
      throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
    }
    stepNative(namedGradients);
  }

  @DoNotStrip
  private native void stepNative(Map<String, Tensor> namedGradients);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;

import android.util.Log;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import org.pytorch.executorch.annotations.Experimental;

/**
 * Java wrapper for ExecuTorch TrainingModule.
 *
 * <p>A TrainingModule is loaded from a serialized .pte program plus a separate data file, and
 * exposes forward+backward execution along with access to named parameters and gradients. The
 * native peer is held via fbjni's {@link HybridData}.
 *
 * <p>Warning: These APIs are experimental and subject to change without notice
 */
@Experimental
public class TrainingModule {

  static {
    if (!NativeLoader.isInitialized()) {
      NativeLoader.init(new SystemDelegate());
    }
    // Loads libexecutorch.so from jniLibs
    NativeLoader.loadLibrary("executorch");
  }

  // Handle to the native training module instance.
  // NOTE(review): the field name appears to be part of the fbjni binding convention —
  // confirm against the C++ HybridClass before renaming.
  private final HybridData mHybridData;

  @DoNotStrip
  private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath);

  // Private: construction goes through load(...), which validates the paths first.
  private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
    mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath);
  }

  /**
   * Loads a serialized ExecuTorch module from the specified path on the disk.
   *
   * @param modelPath path to file that contains the serialized ExecuTorch module.
   * @param dataPath path to file that contains the module's external data (e.g. weights).
   * @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
   * @throws RuntimeException if either path is not a readable file.
   */
  public static TrainingModule load(final String modelPath, final String dataPath) {
    File modelFile = new File(modelPath);
    if (!modelFile.canRead() || !modelFile.isFile()) {
      throw new RuntimeException("Cannot load model path!! " + modelPath);
    }
    File dataFile = new File(dataPath);
    if (!dataFile.canRead() || !dataFile.isFile()) {
      throw new RuntimeException("Cannot load data path!! " + dataPath);
    }
    return new TrainingModule(modelPath, dataPath);
  }

  /**
   * Runs the specified method of this module with the specified arguments, performing both the
   * forward and backward pass.
   *
   * @param methodName name of the ExecuTorch method to run.
   * @param inputs arguments that will be passed to ExecuTorch method.
   * @return return value from the method; an empty array if the native peer was destroyed.
   */
  public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
    if (!mHybridData.isValid()) {
      Log.e("ExecuTorch", "Attempt to use a destroyed module");
      return new EValue[0];
    }
    return executeForwardBackwardNative(methodName, inputs);
  }

  @DoNotStrip
  private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);

  /**
   * Returns the named trainable parameters of the given method.
   *
   * @param methodName name of the ExecuTorch method.
   * @return map of parameter name to tensor; an empty map if the native peer was destroyed.
   */
  public Map<String, Tensor> namedParameters(String methodName) {
    if (!mHybridData.isValid()) {
      Log.e("ExecuTorch", "Attempt to use a destroyed module");
      return new HashMap<String, Tensor>();
    }
    return namedParametersNative(methodName);
  }

  @DoNotStrip
  private native Map<String, Tensor> namedParametersNative(String methodName);

  /**
   * Returns the named gradients of the given method.
   *
   * <p>NOTE(review): presumably only meaningful after {@link #executeForwardBackward} has run —
   * confirm against the native implementation.
   *
   * @param methodName name of the ExecuTorch method.
   * @return map of parameter name to gradient tensor; an empty map if the native peer was
   *     destroyed.
   */
  public Map<String, Tensor> namedGradients(String methodName) {
    if (!mHybridData.isValid()) {
      Log.e("ExecuTorch", "Attempt to use a destroyed module");
      return new HashMap<String, Tensor>();
    }
    return namedGradientsNative(methodName);
  }

  @DoNotStrip
  private native Map<String, Tensor> namedGradientsNative(String methodName);
}
Loading
Loading