diff --git a/extension/android/BUCK b/extension/android/BUCK index 962271d2594..191e6ce4714 100644 --- a/extension/android/BUCK +++ b/extension/android/BUCK @@ -13,9 +13,9 @@ non_fbcode_target(_kind = fb_android_library, "executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java", "executorch_android/src/main/java/org/pytorch/executorch/Module.java", "executorch_android/src/main/java/org/pytorch/executorch/Tensor.java", - "executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java", - "executorch_android/src/main/java/org/pytorch/executorch/SGD.java", "executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java", + "executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java", + "executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java", ], autoglob = False, language = "JAVA", diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt similarity index 80% rename from extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt rename to extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt index fe519659f5f..d71cc6aaedd 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt @@ -5,21 +5,24 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -package org.pytorch.executorch + +package org.pytorch.executorch.training import android.Manifest import android.util.Log import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule -import java.io.File -import java.io.IOException -import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith -import org.pytorch.executorch.TestFileUtils.getTestFilePath +import org.pytorch.executorch.EValue +import org.pytorch.executorch.Tensor +import org.pytorch.executorch.TestFileUtils +import java.io.File +import java.io.IOException +import java.net.URISyntaxException import kotlin.random.Random import kotlin.test.assertContains @@ -36,17 +39,20 @@ class TrainingModuleE2ETest { val pteFilePath = "/xor.pte" val ptdFilePath = "/xor.ptd" - val pteFile = File(getTestFilePath(pteFilePath)) + val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) val pteInputStream = javaClass.getResourceAsStream(pteFilePath) FileUtils.copyInputStreamToFile(pteInputStream, pteFile) pteInputStream.close() - val ptdFile = File(getTestFilePath(ptdFilePath)) + val ptdFile = File(TestFileUtils.getTestFilePath(ptdFilePath)) val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath) FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile) ptdInputStream.close() - val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath)) + val module = TrainingModule.load( + TestFileUtils.getTestFilePath(pteFilePath), + TestFileUtils.getTestFilePath(ptdFilePath) + ) val params = module.namedParameters("forward") Assert.assertEquals(4, params.size) @@ -75,7 +81,10 @@ class TrainingModuleE2ETest { val targetDex = inputDex + 1 val input = dataset.get(inputDex) val target = dataset.get(targetDex) - val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target)) + val out = module.executeForwardBackward("forward", + EValue.from(input), + EValue.from(target) + ) val gradients = module.namedGradients("forward") if (i == 0) { @@ -96,7 +105,9 @@ class TrainingModuleE2ETest { input.getDataAsFloatArray()[0], input.getDataAsFloatArray()[1], out[1].toTensor().getDataAsLongArray()[0], - target.getDataAsLongArray()[0])); + target.getDataAsLongArray()[0] + ) + ); } sgd.step(gradients) @@ -113,12 +124,12 @@ class TrainingModuleE2ETest { fun testTrainXOR_PTEOnly() { val pteFilePath = "/xor_full.pte" - val pteFile = File(getTestFilePath(pteFilePath)) + val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) val pteInputStream = javaClass.getResourceAsStream(pteFilePath) FileUtils.copyInputStreamToFile(pteInputStream, pteFile) pteInputStream.close() - val module = TrainingModule.load(getTestFilePath(pteFilePath)); + val module = TrainingModule.load(TestFileUtils.getTestFilePath(pteFilePath)); val params = module.namedParameters("forward") Assert.assertEquals(4, params.size) @@ -147,7 +158,10 @@ class TrainingModuleE2ETest { val targetDex = inputDex + 1 val input = dataset.get(inputDex) val target = dataset.get(targetDex) - val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target)) + val out = module.executeForwardBackward("forward", + EValue.from(input), + EValue.from(target) + ) val gradients = module.namedGradients("forward") if (i == 0) { @@ -168,7 +182,9 @@ class TrainingModuleE2ETest { input.getDataAsFloatArray()[0], input.getDataAsFloatArray()[1], out[1].toTensor().getDataAsLongArray()[0], - target.getDataAsLongArray()[0])); + target.getDataAsLongArray()[0] + ) + ); } sgd.step(gradients) @@ -184,9 +200,12 @@ class TrainingModuleE2ETest { @Throws(IOException::class) fun testMissingPteFile() { val exception = Assert.assertThrows(RuntimeException::class.java) { - TrainingModule.load(getTestFilePath(MISSING_PTE_NAME)) + TrainingModule.load(TestFileUtils.getTestFilePath(MISSING_PTE_NAME)) } - Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME)) + Assert.assertEquals( + exception.message, + "Cannot load model path!! " + TestFileUtils.getTestFilePath(MISSING_PTE_NAME) + ) } @Test @@ -194,14 +213,20 @@ class TrainingModuleE2ETest { fun testMissingPtdFile() { val exception = Assert.assertThrows(RuntimeException::class.java) { val pteFilePath = "/xor.pte" - val pteFile = File(getTestFilePath(pteFilePath)) + val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) val pteInputStream = javaClass.getResourceAsStream(pteFilePath) FileUtils.copyInputStreamToFile(pteInputStream, pteFile) pteInputStream.close() - TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME)) + TrainingModule.load( + TestFileUtils.getTestFilePath(pteFilePath), + TestFileUtils.getTestFilePath(MISSING_PTD_NAME) + ) } - Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME)) + Assert.assertEquals( + exception.message, + "Cannot load data path!! " + TestFileUtils.getTestFilePath(MISSING_PTD_NAME) + ) } companion object { @@ -212,4 +237,4 @@ class TrainingModuleE2ETest { private const val MISSING_PTE_NAME = "/missing.pte" private const val MISSING_PTD_NAME = "/missing.ptd" } -} +} \ No newline at end of file diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java similarity index 95% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java index 35dbf5cc54c..8f4292c1bc8 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java @@ -6,13 +6,14 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.executorch; +package org.pytorch.executorch.training; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; import java.util.Map; +import org.pytorch.executorch.Tensor; import org.pytorch.executorch.annotations.Experimental; /** @@ -62,7 +63,7 @@ private SGD( * @param dampening The dampening value * @param weightDecay The weight decay value * @param nesterov Whether to use Nesterov momentum - * @return new {@link org.pytorch.executorch.SGD} object + * @return new {@link SGD} object */ public static SGD create( Map namedParameters, @@ -79,7 +80,7 @@ public static SGD create( * * @param namedParameters Map of parameter names to tensors to be optimized * @param learningRate The learning rate for the optimizer - * @return new {@link org.pytorch.executorch.SGD} object + * @return new {@link SGD} object */ public static SGD create(Map namedParameters, double learningRate) { return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false); diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java similarity index 93% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java index f3c3cdc1219..3735fb6f426 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java @@ -6,7 +6,7 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.executorch; +package org.pytorch.executorch.training; import android.util.Log; import com.facebook.jni.HybridData; @@ -16,6 +16,8 @@ import java.io.File; import java.util.HashMap; import java.util.Map; +import org.pytorch.executorch.EValue; +import org.pytorch.executorch.Tensor; import org.pytorch.executorch.annotations.Experimental; /** @@ -48,7 +50,7 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) { * * @param modelPath path to file that contains the serialized ExecuTorch module. * @param dataPath path to file that contains the ExecuTorch module external weights. - * @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module. + * @return new {@link TrainingModule} object which owns the model module. */ public static TrainingModule load(final String modelPath, final String dataPath) { File modelFile = new File(modelPath); @@ -67,7 +69,7 @@ public static TrainingModule load(final String modelPath, final String dataPath) * * @param modelPath path to file that contains the serialized ExecuTorch module. This PTE does not * rely on external weights. - * @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module. + * @return new {@link TrainingModule} object which owns the model module. */ public static TrainingModule load(final String modelPath) { File modelFile = new File(modelPath); diff --git a/extension/android/jni/jni_layer_training.cpp b/extension/android/jni/jni_layer_training.cpp index 7c66884dcff..5a5e9f24d2f 100644 --- a/extension/android/jni/jni_layer_training.cpp +++ b/extension/android/jni/jni_layer_training.cpp @@ -67,7 +67,7 @@ class ExecuTorchTrainingJni public: constexpr static auto kJavaDescriptor = - "Lorg/pytorch/executorch/TrainingModule;"; + "Lorg/pytorch/executorch/training/TrainingModule;"; ExecuTorchTrainingJni( facebook::jni::alias_ref modelPath, @@ -226,7 +226,8 @@ class ExecuTorchTrainingJni class SGDHybrid : public facebook::jni::HybridClass { public: - constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/SGD;"; + constexpr static const char* kJavaDescriptor = + "Lorg/pytorch/executorch/training/SGD;"; static facebook::jni::local_ref initHybrid( facebook::jni::alias_ref,