Skip to content

Commit 3395ed5

Browse files
committed
BUCK update
1 parent b9c72bf commit 3395ed5

File tree

2 files changed

+47
-24
lines changed

2 files changed

+47
-24
lines changed

extension/android/BUCK

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ non_fbcode_target(_kind = fb_android_library,
1313
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
1414
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
1515
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
16-
"executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java",
17-
"executorch_android/src/main/java/org/pytorch/executorch/SGD.java",
1816
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
17+
"executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java",
18+
"executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java",
1919
],
2020
autoglob = False,
2121
language = "JAVA",
Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,24 @@
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
77
*/
8-
package org.pytorch.executorch
8+
9+
package org.pytorch.executorch.training
910

1011
import android.Manifest
1112
import android.util.Log
1213
import androidx.test.ext.junit.runners.AndroidJUnit4
1314
import androidx.test.rule.GrantPermissionRule
14-
import java.io.File
15-
import java.io.IOException
16-
import java.net.URISyntaxException
1715
import org.apache.commons.io.FileUtils
1816
import org.junit.Assert
1917
import org.junit.Rule
2018
import org.junit.Test
2119
import org.junit.runner.RunWith
22-
import org.pytorch.executorch.TestFileUtils.getTestFilePath
23-
import org.pytorch.executorch.training.SGD
24-
import org.pytorch.executorch.training.TrainingModule
20+
import org.pytorch.executorch.EValue
21+
import org.pytorch.executorch.Tensor
22+
import org.pytorch.executorch.TestFileUtils
23+
import java.io.File
24+
import java.io.IOException
25+
import java.net.URISyntaxException
2526
import kotlin.random.Random
2627
import kotlin.test.assertContains
2728

@@ -38,17 +39,20 @@ class TrainingModuleE2ETest {
3839
val pteFilePath = "/xor.pte"
3940
val ptdFilePath = "/xor.ptd"
4041

41-
val pteFile = File(getTestFilePath(pteFilePath))
42+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
4243
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
4344
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
4445
pteInputStream.close()
4546

46-
val ptdFile = File(getTestFilePath(ptdFilePath))
47+
val ptdFile = File(TestFileUtils.getTestFilePath(ptdFilePath))
4748
val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
4849
FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile)
4950
ptdInputStream.close()
5051

51-
val module = TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
52+
val module = TrainingModule.load(
53+
TestFileUtils.getTestFilePath(pteFilePath),
54+
TestFileUtils.getTestFilePath(ptdFilePath)
55+
)
5256
val params = module.namedParameters("forward")
5357

5458
Assert.assertEquals(4, params.size)
@@ -77,7 +81,10 @@ class TrainingModuleE2ETest {
7781
val targetDex = inputDex + 1
7882
val input = dataset.get(inputDex)
7983
val target = dataset.get(targetDex)
80-
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
84+
val out = module.executeForwardBackward("forward",
85+
EValue.from(input),
86+
EValue.from(target)
87+
)
8188
val gradients = module.namedGradients("forward")
8289

8390
if (i == 0) {
@@ -98,7 +105,9 @@ class TrainingModuleE2ETest {
98105
input.getDataAsFloatArray()[0],
99106
input.getDataAsFloatArray()[1],
100107
out[1].toTensor().getDataAsLongArray()[0],
101-
target.getDataAsLongArray()[0]));
108+
target.getDataAsLongArray()[0]
109+
)
110+
);
102111
}
103112

104113
sgd.step(gradients)
@@ -115,12 +124,12 @@ class TrainingModuleE2ETest {
115124
fun testTrainXOR_PTEOnly() {
116125
val pteFilePath = "/xor_full.pte"
117126

118-
val pteFile = File(getTestFilePath(pteFilePath))
127+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
119128
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
120129
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
121130
pteInputStream.close()
122131

123-
val module = TrainingModule.load(getTestFilePath(pteFilePath));
132+
val module = TrainingModule.load(TestFileUtils.getTestFilePath(pteFilePath));
124133
val params = module.namedParameters("forward")
125134

126135
Assert.assertEquals(4, params.size)
@@ -149,7 +158,10 @@ class TrainingModuleE2ETest {
149158
val targetDex = inputDex + 1
150159
val input = dataset.get(inputDex)
151160
val target = dataset.get(targetDex)
152-
val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
161+
val out = module.executeForwardBackward("forward",
162+
EValue.from(input),
163+
EValue.from(target)
164+
)
153165
val gradients = module.namedGradients("forward")
154166

155167
if (i == 0) {
@@ -170,7 +182,9 @@ class TrainingModuleE2ETest {
170182
input.getDataAsFloatArray()[0],
171183
input.getDataAsFloatArray()[1],
172184
out[1].toTensor().getDataAsLongArray()[0],
173-
target.getDataAsLongArray()[0]));
185+
target.getDataAsLongArray()[0]
186+
)
187+
);
174188
}
175189

176190
sgd.step(gradients)
@@ -186,24 +200,33 @@ class TrainingModuleE2ETest {
186200
@Throws(IOException::class)
187201
fun testMissingPteFile() {
188202
val exception = Assert.assertThrows(RuntimeException::class.java) {
189-
TrainingModule.load(getTestFilePath(MISSING_PTE_NAME))
203+
TrainingModule.load(TestFileUtils.getTestFilePath(MISSING_PTE_NAME))
190204
}
191-
Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME))
205+
Assert.assertEquals(
206+
exception.message,
207+
"Cannot load model path!! " + TestFileUtils.getTestFilePath(MISSING_PTE_NAME)
208+
)
192209
}
193210

194211
@Test
195212
@Throws(IOException::class)
196213
fun testMissingPtdFile() {
197214
val exception = Assert.assertThrows(RuntimeException::class.java) {
198215
val pteFilePath = "/xor.pte"
199-
val pteFile = File(getTestFilePath(pteFilePath))
216+
val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath))
200217
val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
201218
FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
202219
pteInputStream.close()
203220

204-
TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME))
221+
TrainingModule.load(
222+
TestFileUtils.getTestFilePath(pteFilePath),
223+
TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
224+
)
205225
}
206-
Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME))
226+
Assert.assertEquals(
227+
exception.message,
228+
"Cannot load data path!! " + TestFileUtils.getTestFilePath(MISSING_PTD_NAME)
229+
)
207230
}
208231

209232
companion object {
@@ -214,4 +237,4 @@ class TrainingModuleE2ETest {
214237
private const val MISSING_PTE_NAME = "/missing.pte"
215238
private const val MISSING_PTD_NAME = "/missing.ptd"
216239
}
217-
}
240+
}

0 commit comments

Comments
 (0)