Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions extension/android/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ non_fbcode_target(_kind = fb_android_library,
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
"executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java",
"executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java",
],
autoglob = False,
language = "JAVA",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,24 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
package org.pytorch.executorch

package org.pytorch.executorch.training

import android.Manifest
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.rule.GrantPermissionRule
import java.io.File
import java.io.IOException
import java.net.URISyntaxException
import org.apache.commons.io.FileUtils
import org.junit.Assert
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TestFileUtils.getTestFilePath
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor
import org.pytorch.executorch.TestFileUtils
import java.io.File
import java.io.IOException
import java.net.URISyntaxException
import kotlin.random.Random
import kotlin.test.assertContains

Expand All @@ -36,17 +39,20 @@ class TrainingModuleE2ETest {
val pteFilePath = "/xor.pte"
val ptdFilePath = "/xor.ptd"

val pteFile = File(getTestFilePath(pteFilePath))
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

val ptdFile = File(getTestFilePath(ptdFilePath))
val ptdFile = File(TestFileUtils.getTestFilePath(ptdFilePath))
val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile)
ptdInputStream.close()

val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
val module = TrainingModule.load(
TestFileUtils.getTestFilePath(pteFilePath),
TestFileUtils.getTestFilePath(ptdFilePath)
)
val params = module.namedParameters("forward")

Assert.assertEquals(4, params.size)
Expand Down Expand Up @@ -75,7 +81,10 @@ class TrainingModuleE2ETest {
val targetDex = inputDex + 1
val input = dataset.get(inputDex)
val target = dataset.get(targetDex)
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
val out = module.executeForwardBackward("forward",
EValue.from(input),
EValue.from(target)
)
val gradients = module.namedGradients("forward")

if (i == 0) {
Expand All @@ -96,7 +105,9 @@ class TrainingModuleE2ETest {
input.getDataAsFloatArray()[0],
input.getDataAsFloatArray()[1],
out[1].toTensor().getDataAsLongArray()[0],
target.getDataAsLongArray()[0]));
target.getDataAsLongArray()[0]
)
);
}

sgd.step(gradients)
Expand All @@ -113,12 +124,12 @@ class TrainingModuleE2ETest {
fun testTrainXOR_PTEOnly() {
val pteFilePath = "/xor_full.pte"

val pteFile = File(getTestFilePath(pteFilePath))
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

val module = TrainingModule.load(getTestFilePath(pteFilePath));
val module = TrainingModule.load(TestFileUtils.getTestFilePath(pteFilePath));
val params = module.namedParameters("forward")

Assert.assertEquals(4, params.size)
Expand Down Expand Up @@ -147,7 +158,10 @@ class TrainingModuleE2ETest {
val targetDex = inputDex + 1
val input = dataset.get(inputDex)
val target = dataset.get(targetDex)
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
val out = module.executeForwardBackward("forward",
EValue.from(input),
EValue.from(target)
)
val gradients = module.namedGradients("forward")

if (i == 0) {
Expand All @@ -168,7 +182,9 @@ class TrainingModuleE2ETest {
input.getDataAsFloatArray()[0],
input.getDataAsFloatArray()[1],
out[1].toTensor().getDataAsLongArray()[0],
target.getDataAsLongArray()[0]));
target.getDataAsLongArray()[0]
)
);
}

sgd.step(gradients)
Expand All @@ -184,24 +200,33 @@ class TrainingModuleE2ETest {
@Throws(IOException::class)
fun testMissingPteFile() {
val exception = Assert.assertThrows(RuntimeException::class.java) {
TrainingModule.load(getTestFilePath(MISSING_PTE_NAME))
TrainingModule.load(TestFileUtils.getTestFilePath(MISSING_PTE_NAME))
}
Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME))
Assert.assertEquals(
exception.message,
"Cannot load model path!! " + TestFileUtils.getTestFilePath(MISSING_PTE_NAME)
)
}

@Test
@Throws(IOException::class)
fun testMissingPtdFile() {
val exception = Assert.assertThrows(RuntimeException::class.java) {
val pteFilePath = "/xor.pte"
val pteFile = File(getTestFilePath(pteFilePath))
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME))
TrainingModule.load(
TestFileUtils.getTestFilePath(pteFilePath),
TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
)
}
Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME))
Assert.assertEquals(
exception.message,
"Cannot load data path!! " + TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
)
}

companion object {
Expand All @@ -212,4 +237,4 @@ class TrainingModuleE2ETest {
private const val MISSING_PTE_NAME = "/missing.pte"
private const val MISSING_PTD_NAME = "/missing.ptd"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;
package org.pytorch.executorch.training;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
import org.pytorch.executorch.Tensor;
import org.pytorch.executorch.annotations.Experimental;

/**
Expand Down Expand Up @@ -62,7 +63,7 @@ private SGD(
* @param dampening The dampening value
* @param weightDecay The weight decay value
* @param nesterov Whether to use Nesterov momentum
* @return new {@link org.pytorch.executorch.SGD} object
* @return new {@link SGD} object
*/
public static SGD create(
Map<String, Tensor> namedParameters,
Expand All @@ -79,7 +80,7 @@ public static SGD create(
*
* @param namedParameters Map of parameter names to tensors to be optimized
* @param learningRate The learning rate for the optimizer
* @return new {@link org.pytorch.executorch.SGD} object
* @return new {@link SGD} object
*/
public static SGD create(Map<String, Tensor> namedParameters, double learningRate) {
return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;
package org.pytorch.executorch.training;

import android.util.Log;
import com.facebook.jni.HybridData;
Expand All @@ -16,6 +16,8 @@
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.Tensor;
import org.pytorch.executorch.annotations.Experimental;

/**
Expand Down Expand Up @@ -48,7 +50,7 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
*
* @param modelPath path to file that contains the serialized ExecuTorch module.
* @param dataPath path to file that contains the ExecuTorch module external weights.
* @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
* @return new {@link TrainingModule} object which owns the model module.
*/
public static TrainingModule load(final String modelPath, final String dataPath) {
File modelFile = new File(modelPath);
Expand All @@ -67,7 +69,7 @@ public static TrainingModule load(final String modelPath, final String dataPath)
*
* @param modelPath path to file that contains the serialized ExecuTorch module. This PTE does not
* rely on external weights.
* @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
* @return new {@link TrainingModule} object which owns the model module.
*/
public static TrainingModule load(final String modelPath) {
File modelFile = new File(modelPath);
Expand Down
5 changes: 3 additions & 2 deletions extension/android/jni/jni_layer_training.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ExecuTorchTrainingJni

public:
constexpr static auto kJavaDescriptor =
"Lorg/pytorch/executorch/TrainingModule;";
"Lorg/pytorch/executorch/training/TrainingModule;";

ExecuTorchTrainingJni(
facebook::jni::alias_ref<jstring> modelPath,
Expand Down Expand Up @@ -226,7 +226,8 @@ class ExecuTorchTrainingJni

class SGDHybrid : public facebook::jni::HybridClass<SGDHybrid> {
public:
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/SGD;";
constexpr static const char* kJavaDescriptor =
"Lorg/pytorch/executorch/training/SGD;";

static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
Expand Down
Loading