Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions extension/android/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ non_fbcode_target(_kind = fb_android_library,
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
"executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java",
"executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java",
],
autoglob = False,
language = "JAVA",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,24 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
package org.pytorch.executorch

package org.pytorch.executorch.training

import android.Manifest
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.rule.GrantPermissionRule
import java.io.File
import java.io.IOException
import java.net.URISyntaxException
import org.apache.commons.io.FileUtils
import org.junit.Assert
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.pytorch.executorch.TestFileUtils.getTestFilePath
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor
import org.pytorch.executorch.TestFileUtils
import java.io.File
import java.io.IOException
import java.net.URISyntaxException
import kotlin.random.Random
import kotlin.test.assertContains

Expand All @@ -36,17 +39,20 @@ class TrainingModuleE2ETest {
val pteFilePath = "/xor.pte"
val ptdFilePath = "/xor.ptd"

val pteFile = File(getTestFilePath(pteFilePath))
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

val ptdFile = File(getTestFilePath(ptdFilePath))
val ptdFile = File(TestFileUtils.getTestFilePath(ptdFilePath))
val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile)
ptdInputStream.close()

val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
val module = TrainingModule.load(
TestFileUtils.getTestFilePath(pteFilePath),
TestFileUtils.getTestFilePath(ptdFilePath)
)
val params = module.namedParameters("forward")

Assert.assertEquals(4, params.size)
Expand Down Expand Up @@ -75,7 +81,10 @@ class TrainingModuleE2ETest {
val targetDex = inputDex + 1
val input = dataset.get(inputDex)
val target = dataset.get(targetDex)
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
val out = module.executeForwardBackward("forward",
EValue.from(input),
EValue.from(target)
)
val gradients = module.namedGradients("forward")

if (i == 0) {
Expand All @@ -96,7 +105,9 @@ class TrainingModuleE2ETest {
input.getDataAsFloatArray()[0],
input.getDataAsFloatArray()[1],
out[1].toTensor().getDataAsLongArray()[0],
target.getDataAsLongArray()[0]));
target.getDataAsLongArray()[0]
)
);
}

sgd.step(gradients)
Expand All @@ -113,12 +124,12 @@ class TrainingModuleE2ETest {
fun testTrainXOR_PTEOnly() {
val pteFilePath = "/xor_full.pte"

val pteFile = File(getTestFilePath(pteFilePath))
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

val module = TrainingModule.load(getTestFilePath(pteFilePath));
val module = TrainingModule.load(TestFileUtils.getTestFilePath(pteFilePath));
val params = module.namedParameters("forward")

Assert.assertEquals(4, params.size)
Expand Down Expand Up @@ -147,7 +158,10 @@ class TrainingModuleE2ETest {
val targetDex = inputDex + 1
val input = dataset.get(inputDex)
val target = dataset.get(targetDex)
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
val out = module.executeForwardBackward("forward",
EValue.from(input),
EValue.from(target)
)
val gradients = module.namedGradients("forward")

if (i == 0) {
Expand All @@ -168,7 +182,9 @@ class TrainingModuleE2ETest {
input.getDataAsFloatArray()[0],
input.getDataAsFloatArray()[1],
out[1].toTensor().getDataAsLongArray()[0],
target.getDataAsLongArray()[0]));
target.getDataAsLongArray()[0]
)
);
}

sgd.step(gradients)
Expand All @@ -184,24 +200,33 @@ class TrainingModuleE2ETest {
@Throws(IOException::class)
fun testMissingPteFile() {
val exception = Assert.assertThrows(RuntimeException::class.java) {
TrainingModule.load(getTestFilePath(MISSING_PTE_NAME))
TrainingModule.load(TestFileUtils.getTestFilePath(MISSING_PTE_NAME))
}
Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME))
Assert.assertEquals(
exception.message,
"Cannot load model path!! " + TestFileUtils.getTestFilePath(MISSING_PTE_NAME)
)
}

@Test
@Throws(IOException::class)
fun testMissingPtdFile() {
val exception = Assert.assertThrows(RuntimeException::class.java) {
val pteFilePath = "/xor.pte"
val pteFile = File(getTestFilePath(pteFilePath))
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
pteInputStream.close()

TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME))
TrainingModule.load(
TestFileUtils.getTestFilePath(pteFilePath),
TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
)
}
Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME))
Assert.assertEquals(
exception.message,
"Cannot load data path!! " + TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
)
}

companion object {
Expand All @@ -212,4 +237,4 @@ class TrainingModuleE2ETest {
private const val MISSING_PTE_NAME = "/missing.pte"
private const val MISSING_PTD_NAME = "/missing.ptd"
}
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
* * Copyright (c) Meta Platforms, Inc. and affiliates.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reasoning behind the extra layer of * for the license? This is also done for the TrainingModule, so just wanted to understand the distinction.

* * All rights reserved.
* *
* * This source code is licensed under the BSD-style license found in the
* * LICENSE file in the root directory of this source tree.
*
*
*/

package org.pytorch.executorch;
package org.pytorch.executorch.training;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.util.Map;
import org.pytorch.executorch.Tensor;
import org.pytorch.executorch.annotations.Experimental;

/**
Expand Down Expand Up @@ -62,7 +66,7 @@ private SGD(
* @param dampening The dampening value
* @param weightDecay The weight decay value
* @param nesterov Whether to use Nesterov momentum
* @return new {@link org.pytorch.executorch.SGD} object
* @return new {@link SGD} object
*/
public static SGD create(
Map<String, Tensor> namedParameters,
Expand All @@ -79,7 +83,7 @@ public static SGD create(
*
* @param namedParameters Map of parameter names to tensors to be optimized
* @param learningRate The learning rate for the optimizer
* @return new {@link org.pytorch.executorch.SGD} object
* @return new {@link SGD} object
*/
public static SGD create(Map<String, Tensor> namedParameters, double learningRate) {
return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
* * Copyright (c) Meta Platforms, Inc. and affiliates.
* * All rights reserved.
* *
* * This source code is licensed under the BSD-style license found in the
* * LICENSE file in the root directory of this source tree.
*
*
*/

package org.pytorch.executorch;
package org.pytorch.executorch.training;

import android.util.Log;
import com.facebook.jni.HybridData;
Expand All @@ -16,6 +19,8 @@
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.Tensor;
import org.pytorch.executorch.annotations.Experimental;

/**
Expand Down Expand Up @@ -48,7 +53,7 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
*
* @param modelPath path to file that contains the serialized ExecuTorch module.
* @param dataPath path to file that contains the ExecuTorch module external weights.
* @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
* @return new {@link TrainingModule} object which owns the model module.
*/
public static TrainingModule load(final String modelPath, final String dataPath) {
File modelFile = new File(modelPath);
Expand All @@ -67,7 +72,7 @@ public static TrainingModule load(final String modelPath, final String dataPath)
*
* @param modelPath path to file that contains the serialized ExecuTorch module. This PTE does not
* rely on external weights.
* @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
* @return new {@link TrainingModule} object which owns the model module.
*/
public static TrainingModule load(final String modelPath) {
File modelFile = new File(modelPath);
Expand Down
5 changes: 3 additions & 2 deletions extension/android/jni/jni_layer_training.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ExecuTorchTrainingJni

public:
constexpr static auto kJavaDescriptor =
"Lorg/pytorch/executorch/TrainingModule;";
"Lorg/pytorch/executorch/training/TrainingModule;";

ExecuTorchTrainingJni(
facebook::jni::alias_ref<jstring> modelPath,
Expand Down Expand Up @@ -226,7 +226,8 @@ class ExecuTorchTrainingJni

class SGDHybrid : public facebook::jni::HybridClass<SGDHybrid> {
public:
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/SGD;";
constexpr static const char* kJavaDescriptor =
"Lorg/pytorch/executorch/training/SGD;";

static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
Expand Down
Loading