From b9c72bf03773c70879f63d2a7a5142b2524209b8 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 31 Jul 2025 11:45:14 -0700 Subject: [PATCH 1/3] [Android] Move training part to its own package --- .../executorch/TrainingModuleE2ETest.kt | 2 ++ .../executorch/{ => training}/SGD.java | 18 +++++++++++------- .../{ => training}/TrainingModule.java | 19 ++++++++++++------- extension/android/jni/jni_layer_training.cpp | 5 +++-- 4 files changed, 28 insertions(+), 16 deletions(-) rename extension/android/executorch_android/src/main/java/org/pytorch/executorch/{ => training}/SGD.java (87%) rename extension/android/executorch_android/src/main/java/org/pytorch/executorch/{ => training}/TrainingModule.java (88%) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt index fe519659f5f..e4cd96bcafb 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt @@ -20,6 +20,8 @@ import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TestFileUtils.getTestFilePath +import org.pytorch.executorch.training.SGD +import org.pytorch.executorch.training.TrainingModule import kotlin.random.Random import kotlin.test.assertContains diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java similarity index 87% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java index 35dbf5cc54c..535f084e9ad 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java @@ -1,18 +1,22 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. + * * Copyright (c) Meta Platforms, Inc. and affiliates. + * * All rights reserved. + * * + * * This source code is licensed under the BSD-style license found in the + * * LICENSE file in the root directory of this source tree. + * + * */ -package org.pytorch.executorch; +package org.pytorch.executorch.training; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; import java.util.Map; +import org.pytorch.executorch.Tensor; import org.pytorch.executorch.annotations.Experimental; /** @@ -62,7 +66,7 @@ private SGD( * @param dampening The dampening value * @param weightDecay The weight decay value * @param nesterov Whether to use Nesterov momentum - * @return new {@link org.pytorch.executorch.SGD} object + * @return new {@link SGD} object */ public static SGD create( Map namedParameters, @@ -79,7 +83,7 @@ public static SGD create( * * @param namedParameters Map of parameter names to tensors to be optimized * @param learningRate The learning rate for the optimizer - * @return new {@link org.pytorch.executorch.SGD} object + * @return new {@link SGD} object */ public static SGD create(Map namedParameters, double learningRate) { return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false); diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java similarity index 88% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java index f3c3cdc1219..95b4fe466d6 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java @@ -1,12 +1,15 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. + * * Copyright (c) Meta Platforms, Inc. and affiliates. + * * All rights reserved. + * * + * * This source code is licensed under the BSD-style license found in the + * * LICENSE file in the root directory of this source tree. + * + * */ -package org.pytorch.executorch; +package org.pytorch.executorch.training; import android.util.Log; import com.facebook.jni.HybridData; @@ -16,6 +19,8 @@ import java.io.File; import java.util.HashMap; import java.util.Map; +import org.pytorch.executorch.EValue; +import org.pytorch.executorch.Tensor; import org.pytorch.executorch.annotations.Experimental; /** @@ -48,7 +53,7 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) { * * @param modelPath path to file that contains the serialized ExecuTorch module. * @param dataPath path to file that contains the ExecuTorch module external weights. - * @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module. + * @return new {@link TrainingModule} object which owns the model module. */ public static TrainingModule load(final String modelPath, final String dataPath) { File modelFile = new File(modelPath); @@ -67,7 +72,7 @@ public static TrainingModule load(final String modelPath, final String dataPath) * * @param modelPath path to file that contains the serialized ExecuTorch module. This PTE does not * rely on external weights. - * @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module. + * @return new {@link TrainingModule} object which owns the model module. */ public static TrainingModule load(final String modelPath) { File modelFile = new File(modelPath); diff --git a/extension/android/jni/jni_layer_training.cpp b/extension/android/jni/jni_layer_training.cpp index 7c66884dcff..5a5e9f24d2f 100644 --- a/extension/android/jni/jni_layer_training.cpp +++ b/extension/android/jni/jni_layer_training.cpp @@ -67,7 +67,7 @@ class ExecuTorchTrainingJni public: constexpr static auto kJavaDescriptor = - "Lorg/pytorch/executorch/TrainingModule;"; + "Lorg/pytorch/executorch/training/TrainingModule;"; ExecuTorchTrainingJni( facebook::jni::alias_ref modelPath, @@ -226,7 +226,8 @@ class ExecuTorchTrainingJni class SGDHybrid : public facebook::jni::HybridClass { public: - constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/SGD;"; + constexpr static const char* kJavaDescriptor = + "Lorg/pytorch/executorch/training/SGD;"; static facebook::jni::local_ref initHybrid( facebook::jni::alias_ref, From 3395ed54c035034c56cb9ee1d17719b6dbfe9891 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 31 Jul 2025 13:44:49 -0700 Subject: [PATCH 2/3] BUCK update --- extension/android/BUCK | 4 +- .../{ => training}/TrainingModuleE2ETest.kt | 67 +++++++++++++------ 2 files changed, 47 insertions(+), 24 deletions(-) rename extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/{ => training}/TrainingModuleE2ETest.kt (80%) diff --git a/extension/android/BUCK b/extension/android/BUCK index 962271d2594..191e6ce4714 100644 --- a/extension/android/BUCK +++ b/extension/android/BUCK @@ -13,9 +13,9 @@ non_fbcode_target(_kind = fb_android_library, "executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java", "executorch_android/src/main/java/org/pytorch/executorch/Module.java", "executorch_android/src/main/java/org/pytorch/executorch/Tensor.java", - "executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java", - "executorch_android/src/main/java/org/pytorch/executorch/SGD.java", "executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java", + "executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java", + "executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java", ], autoglob = False, language = "JAVA", diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt similarity index 80% rename from extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt rename to extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt index e4cd96bcafb..d71cc6aaedd 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt @@ -5,23 +5,24 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -package org.pytorch.executorch + +package org.pytorch.executorch.training import android.Manifest import android.util.Log import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule -import java.io.File -import java.io.IOException -import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith -import org.pytorch.executorch.TestFileUtils.getTestFilePath -import org.pytorch.executorch.training.SGD -import org.pytorch.executorch.training.TrainingModule +import org.pytorch.executorch.EValue +import org.pytorch.executorch.Tensor +import org.pytorch.executorch.TestFileUtils +import java.io.File +import java.io.IOException +import java.net.URISyntaxException import kotlin.random.Random import kotlin.test.assertContains @@ -38,17 +39,20 @@ class TrainingModuleE2ETest { val pteFilePath = "/xor.pte" val ptdFilePath = "/xor.ptd" - val pteFile = File(getTestFilePath(pteFilePath)) + val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) val pteInputStream = javaClass.getResourceAsStream(pteFilePath) FileUtils.copyInputStreamToFile(pteInputStream, pteFile) pteInputStream.close() - val ptdFile = File(getTestFilePath(ptdFilePath)) + val ptdFile = File(TestFileUtils.getTestFilePath(ptdFilePath)) val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath) FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile) ptdInputStream.close() - val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath)) + val module = TrainingModule.load( + TestFileUtils.getTestFilePath(pteFilePath), + TestFileUtils.getTestFilePath(ptdFilePath) + ) val params = module.namedParameters("forward") Assert.assertEquals(4, params.size) @@ -77,7 +81,10 @@ class TrainingModuleE2ETest { val targetDex = inputDex + 1 val input = dataset.get(inputDex) val target = dataset.get(targetDex) - val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target)) + val out = module.executeForwardBackward("forward", + EValue.from(input), + EValue.from(target) + ) val gradients = module.namedGradients("forward") if (i == 0) { @@ -98,7 +105,9 @@ class TrainingModuleE2ETest { input.getDataAsFloatArray()[0], input.getDataAsFloatArray()[1], out[1].toTensor().getDataAsLongArray()[0], - target.getDataAsLongArray()[0])); + target.getDataAsLongArray()[0] + ) + ); } sgd.step(gradients) @@ -115,12 +124,12 @@ class TrainingModuleE2ETest { fun testTrainXOR_PTEOnly() { val pteFilePath = "/xor_full.pte" - val pteFile = File(getTestFilePath(pteFilePath)) + val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) val pteInputStream = javaClass.getResourceAsStream(pteFilePath) FileUtils.copyInputStreamToFile(pteInputStream, pteFile) pteInputStream.close() - val module = TrainingModule.load(getTestFilePath(pteFilePath)); + val module = TrainingModule.load(TestFileUtils.getTestFilePath(pteFilePath)); val params = module.namedParameters("forward") Assert.assertEquals(4, params.size) @@ -149,7 +158,10 @@ class TrainingModuleE2ETest { val targetDex = inputDex + 1 val input = dataset.get(inputDex) val target = dataset.get(targetDex) - val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target)) + val out = module.executeForwardBackward("forward", + EValue.from(input), + EValue.from(target) + ) val gradients = module.namedGradients("forward") if (i == 0) { @@ -170,7 +182,9 @@ class TrainingModuleE2ETest { input.getDataAsFloatArray()[0], input.getDataAsFloatArray()[1], out[1].toTensor().getDataAsLongArray()[0], - target.getDataAsLongArray()[0])); + target.getDataAsLongArray()[0] + ) + ); } sgd.step(gradients) @@ -186,9 +200,12 @@ class TrainingModuleE2ETest { @Throws(IOException::class) fun testMissingPteFile() { val exception = Assert.assertThrows(RuntimeException::class.java) { - TrainingModule.load(getTestFilePath(MISSING_PTE_NAME)) + TrainingModule.load(TestFileUtils.getTestFilePath(MISSING_PTE_NAME)) } - Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME)) + Assert.assertEquals( + exception.message, + "Cannot load model path!! " + TestFileUtils.getTestFilePath(MISSING_PTE_NAME) + ) } @Test @@ -196,14 +213,20 @@ class TrainingModuleE2ETest { fun testMissingPtdFile() { val exception = Assert.assertThrows(RuntimeException::class.java) { val pteFilePath = "/xor.pte" - val pteFile = File(getTestFilePath(pteFilePath)) + val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) val pteInputStream = javaClass.getResourceAsStream(pteFilePath) FileUtils.copyInputStreamToFile(pteInputStream, pteFile) pteInputStream.close() - TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME)) + TrainingModule.load( + TestFileUtils.getTestFilePath(pteFilePath), + TestFileUtils.getTestFilePath(MISSING_PTD_NAME) + ) } - Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME)) + Assert.assertEquals( + exception.message, + "Cannot load data path!! " + TestFileUtils.getTestFilePath(MISSING_PTD_NAME) + ) } companion object { @@ -214,4 +237,4 @@ class TrainingModuleE2ETest { private const val MISSING_PTE_NAME = "/missing.pte" private const val MISSING_PTD_NAME = "/missing.ptd" } -} +} \ No newline at end of file From 60c1f8fe3980d1505df4485a102ea76c61079cbf Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 31 Jul 2025 14:04:05 -0700 Subject: [PATCH 3/3] format --- .../java/org/pytorch/executorch/training/SGD.java | 11 ++++------- .../pytorch/executorch/training/TrainingModule.java | 11 ++++------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java index 535f084e9ad..8f4292c1bc8 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java @@ -1,12 +1,9 @@ /* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. * - * * Copyright (c) Meta Platforms, Inc. and affiliates. - * * All rights reserved. - * * - * * This source code is licensed under the BSD-style license found in the - * * LICENSE file in the root directory of this source tree. - * - * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. */ package org.pytorch.executorch.training; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java index 95b4fe466d6..3735fb6f426 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java @@ -1,12 +1,9 @@ /* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. * - * * Copyright (c) Meta Platforms, Inc. and affiliates. - * * All rights reserved. - * * - * * This source code is licensed under the BSD-style license found in the - * * LICENSE file in the root directory of this source tree. - * - * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. */ package org.pytorch.executorch.training;