55 * This source code is licensed under the BSD-style license found in the
66 * LICENSE file in the root directory of this source tree.
77 */
8- package org.pytorch.executorch
8+
9+ package org.pytorch.executorch.training
910
1011import android.Manifest
1112import android.util.Log
1213import androidx.test.ext.junit.runners.AndroidJUnit4
1314import androidx.test.rule.GrantPermissionRule
14- import java.io.File
15- import java.io.IOException
16- import java.net.URISyntaxException
1715import org.apache.commons.io.FileUtils
1816import org.junit.Assert
1917import org.junit.Rule
2018import org.junit.Test
2119import org.junit.runner.RunWith
22- import org.pytorch.executorch.TestFileUtils.getTestFilePath
23- import org.pytorch.executorch.training.SGD
24- import org.pytorch.executorch.training.TrainingModule
20+ import org.pytorch.executorch.EValue
21+ import org.pytorch.executorch.Tensor
22+ import org.pytorch.executorch.TestFileUtils
23+ import java.io.File
24+ import java.io.IOException
25+ import java.net.URISyntaxException
2526import kotlin.random.Random
2627import kotlin.test.assertContains
2728
@@ -38,17 +39,20 @@ class TrainingModuleE2ETest {
3839 val pteFilePath = " /xor.pte"
3940 val ptdFilePath = " /xor.ptd"
4041
41- val pteFile = File (getTestFilePath(pteFilePath))
42+ val pteFile = File (TestFileUtils . getTestFilePath(pteFilePath))
4243 val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
4344 FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
4445 pteInputStream.close()
4546
46- val ptdFile = File (getTestFilePath(ptdFilePath))
47+ val ptdFile = File (TestFileUtils . getTestFilePath(ptdFilePath))
4748 val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
4849 FileUtils .copyInputStreamToFile(ptdInputStream, ptdFile)
4950 ptdInputStream.close()
5051
51- val module = TrainingModule .load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
52+ val module = TrainingModule .load(
53+ TestFileUtils .getTestFilePath(pteFilePath),
54+ TestFileUtils .getTestFilePath(ptdFilePath)
55+ )
5256 val params = module.namedParameters(" forward" )
5357
5458 Assert .assertEquals(4 , params.size)
@@ -77,7 +81,10 @@ class TrainingModuleE2ETest {
7781 val targetDex = inputDex + 1
7882 val input = dataset.get(inputDex)
7983 val target = dataset.get(targetDex)
80- val out = module.executeForwardBackward(" forward" , EValue .from(input), EValue .from(target))
84+ val out = module.executeForwardBackward(" forward" ,
85+ EValue .from(input),
86+ EValue .from(target)
87+ )
8188 val gradients = module.namedGradients(" forward" )
8289
8390 if (i == 0 ) {
@@ -98,7 +105,9 @@ class TrainingModuleE2ETest {
98105 input.getDataAsFloatArray()[0 ],
99106 input.getDataAsFloatArray()[1 ],
100107 out [1 ].toTensor().getDataAsLongArray()[0 ],
101- target.getDataAsLongArray()[0 ]));
108+ target.getDataAsLongArray()[0 ]
109+ )
110+ );
102111 }
103112
104113 sgd.step(gradients)
@@ -115,12 +124,12 @@ class TrainingModuleE2ETest {
115124 fun testTrainXOR_PTEOnly () {
116125 val pteFilePath = " /xor_full.pte"
117126
118- val pteFile = File (getTestFilePath(pteFilePath))
127+ val pteFile = File (TestFileUtils . getTestFilePath(pteFilePath))
119128 val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
120129 FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
121130 pteInputStream.close()
122131
123- val module = TrainingModule .load(getTestFilePath(pteFilePath));
132+ val module = TrainingModule .load(TestFileUtils . getTestFilePath(pteFilePath));
124133 val params = module.namedParameters(" forward" )
125134
126135 Assert .assertEquals(4 , params.size)
@@ -149,7 +158,10 @@ class TrainingModuleE2ETest {
149158 val targetDex = inputDex + 1
150159 val input = dataset.get(inputDex)
151160 val target = dataset.get(targetDex)
152- val out = module.executeForwardBackward(" forward" , EValue .from(input), EValue .from(target))
161+ val out = module.executeForwardBackward(" forward" ,
162+ EValue .from(input),
163+ EValue .from(target)
164+ )
153165 val gradients = module.namedGradients(" forward" )
154166
155167 if (i == 0 ) {
@@ -170,7 +182,9 @@ class TrainingModuleE2ETest {
170182 input.getDataAsFloatArray()[0 ],
171183 input.getDataAsFloatArray()[1 ],
172184 out [1 ].toTensor().getDataAsLongArray()[0 ],
173- target.getDataAsLongArray()[0 ]));
185+ target.getDataAsLongArray()[0 ]
186+ )
187+ );
174188 }
175189
176190 sgd.step(gradients)
@@ -186,24 +200,33 @@ class TrainingModuleE2ETest {
186200 @Throws(IOException ::class )
187201 fun testMissingPteFile () {
188202 val exception = Assert .assertThrows(RuntimeException ::class .java) {
189- TrainingModule .load(getTestFilePath(MISSING_PTE_NAME ))
203+ TrainingModule .load(TestFileUtils . getTestFilePath(MISSING_PTE_NAME ))
190204 }
191- Assert .assertEquals(exception.message, " Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME ))
205+ Assert .assertEquals(
206+ exception.message,
207+ " Cannot load model path!! " + TestFileUtils .getTestFilePath(MISSING_PTE_NAME )
208+ )
192209 }
193210
194211 @Test
195212 @Throws(IOException ::class )
196213 fun testMissingPtdFile () {
197214 val exception = Assert .assertThrows(RuntimeException ::class .java) {
198215 val pteFilePath = " /xor.pte"
199- val pteFile = File (getTestFilePath(pteFilePath))
216+ val pteFile = File (TestFileUtils . getTestFilePath(pteFilePath))
200217 val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
201218 FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
202219 pteInputStream.close()
203220
204- TrainingModule .load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME ))
221+ TrainingModule .load(
222+ TestFileUtils .getTestFilePath(pteFilePath),
223+ TestFileUtils .getTestFilePath(MISSING_PTD_NAME )
224+ )
205225 }
206- Assert .assertEquals(exception.message, " Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME ))
226+ Assert .assertEquals(
227+ exception.message,
228+ " Cannot load data path!! " + TestFileUtils .getTestFilePath(MISSING_PTD_NAME )
229+ )
207230 }
208231
209232 companion object {
@@ -214,4 +237,4 @@ class TrainingModuleE2ETest {
214237 private const val MISSING_PTE_NAME = " /missing.pte"
215238 private const val MISSING_PTD_NAME = " /missing.ptd"
216239 }
217- }
240+ }
0 commit comments