Skip to content

Commit b9c72bf

Browse files
committed
[Android] Move training part to its own package
1 parent 439bb6c commit b9c72bf

File tree

4 files changed

+28
-16
lines changed

4 files changed

+28
-16
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import org.junit.Rule
2020
import org.junit.Test
2121
import org.junit.runner.RunWith
2222
import org.pytorch.executorch.TestFileUtils.getTestFilePath
23+
import org.pytorch.executorch.training.SGD
24+
import org.pytorch.executorch.training.TrainingModule
2325
import kotlin.random.Random
2426
import kotlin.test.assertContains
2527

extension/android/executorch_android/src/main/java/org/pytorch/executorch/SGD.java renamed to extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
/*
2-
* Copyright (c) Meta Platforms, Inc. and affiliates.
3-
* All rights reserved.
42
*
5-
* This source code is licensed under the BSD-style license found in the
6-
* LICENSE file in the root directory of this source tree.
3+
* * Copyright (c) Meta Platforms, Inc. and affiliates.
4+
* * All rights reserved.
5+
* *
6+
* * This source code is licensed under the BSD-style license found in the
7+
* * LICENSE file in the root directory of this source tree.
8+
*
9+
*
710
*/
811

9-
package org.pytorch.executorch;
12+
package org.pytorch.executorch.training;
1013

1114
import com.facebook.jni.HybridData;
1215
import com.facebook.jni.annotations.DoNotStrip;
1316
import com.facebook.soloader.nativeloader.NativeLoader;
1417
import com.facebook.soloader.nativeloader.SystemDelegate;
1518
import java.util.Map;
19+
import org.pytorch.executorch.Tensor;
1620
import org.pytorch.executorch.annotations.Experimental;
1721

1822
/**
@@ -62,7 +66,7 @@ private SGD(
6266
* @param dampening The dampening value
6367
* @param weightDecay The weight decay value
6468
* @param nesterov Whether to use Nesterov momentum
65-
* @return new {@link org.pytorch.executorch.SGD} object
69+
* @return new {@link SGD} object
6670
*/
6771
public static SGD create(
6872
Map<String, Tensor> namedParameters,
@@ -79,7 +83,7 @@ public static SGD create(
7983
*
8084
* @param namedParameters Map of parameter names to tensors to be optimized
8185
* @param learningRate The learning rate for the optimizer
82-
* @return new {@link org.pytorch.executorch.SGD} object
86+
* @return new {@link SGD} object
8387
*/
8488
public static SGD create(Map<String, Tensor> namedParameters, double learningRate) {
8589
return create(namedParameters, learningRate, 0.0, 0.0, 0.0, false);

extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java renamed to extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
/*
2-
* Copyright (c) Meta Platforms, Inc. and affiliates.
3-
* All rights reserved.
42
*
5-
* This source code is licensed under the BSD-style license found in the
6-
* LICENSE file in the root directory of this source tree.
3+
* * Copyright (c) Meta Platforms, Inc. and affiliates.
4+
* * All rights reserved.
5+
* *
6+
* * This source code is licensed under the BSD-style license found in the
7+
* * LICENSE file in the root directory of this source tree.
8+
*
9+
*
710
*/
811

9-
package org.pytorch.executorch;
12+
package org.pytorch.executorch.training;
1013

1114
import android.util.Log;
1215
import com.facebook.jni.HybridData;
@@ -16,6 +19,8 @@
1619
import java.io.File;
1720
import java.util.HashMap;
1821
import java.util.Map;
22+
import org.pytorch.executorch.EValue;
23+
import org.pytorch.executorch.Tensor;
1924
import org.pytorch.executorch.annotations.Experimental;
2025

2126
/**
@@ -48,7 +53,7 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
4853
*
4954
* @param modelPath path to file that contains the serialized ExecuTorch module.
5055
* @param dataPath path to file that contains the ExecuTorch module external weights.
51-
* @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
56+
* @return new {@link TrainingModule} object which owns the model module.
5257
*/
5358
public static TrainingModule load(final String modelPath, final String dataPath) {
5459
File modelFile = new File(modelPath);
@@ -67,7 +72,7 @@ public static TrainingModule load(final String modelPath, final String dataPath)
6772
*
6873
* @param modelPath path to file that contains the serialized ExecuTorch module. This PTE does not
6974
* rely on external weights.
70-
* @return new {@link org.pytorch.executorch.TrainingModule} object which owns the model module.
75+
* @return new {@link TrainingModule} object which owns the model module.
7176
*/
7277
public static TrainingModule load(final String modelPath) {
7378
File modelFile = new File(modelPath);

extension/android/jni/jni_layer_training.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class ExecuTorchTrainingJni
6767

6868
public:
6969
constexpr static auto kJavaDescriptor =
70-
"Lorg/pytorch/executorch/TrainingModule;";
70+
"Lorg/pytorch/executorch/training/TrainingModule;";
7171

7272
ExecuTorchTrainingJni(
7373
facebook::jni::alias_ref<jstring> modelPath,
@@ -226,7 +226,8 @@ class ExecuTorchTrainingJni
226226

227227
class SGDHybrid : public facebook::jni::HybridClass<SGDHybrid> {
228228
public:
229-
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/SGD;";
229+
constexpr static const char* kJavaDescriptor =
230+
"Lorg/pytorch/executorch/training/SGD;";
230231

231232
static facebook::jni::local_ref<jhybriddata> initHybrid(
232233
facebook::jni::alias_ref<jclass>,

0 commit comments

Comments
 (0)