55 * This source code is licensed under the BSD-style license found in the
66 * LICENSE file in the root directory of this source tree.
77 */
8- package org.pytorch.executorch
8+
9+ package org.pytorch.executorch.training
910
1011import android.Manifest
1112import android.util.Log
1213import androidx.test.ext.junit.runners.AndroidJUnit4
1314import androidx.test.rule.GrantPermissionRule
14- import java.io.File
15- import java.io.IOException
16- import java.net.URISyntaxException
1715import org.apache.commons.io.FileUtils
1816import org.junit.Assert
1917import org.junit.Rule
2018import org.junit.Test
2119import org.junit.runner.RunWith
22- import org.pytorch.executorch.TestFileUtils.getTestFilePath
20+ import org.pytorch.executorch.EValue
21+ import org.pytorch.executorch.Tensor
22+ import org.pytorch.executorch.TestFileUtils
23+ import java.io.File
24+ import java.io.IOException
25+ import java.net.URISyntaxException
2326import kotlin.random.Random
2427import kotlin.test.assertContains
2528
@@ -36,17 +39,20 @@ class TrainingModuleE2ETest {
3639 val pteFilePath = " /xor.pte"
3740 val ptdFilePath = " /xor.ptd"
3841
39- val pteFile = File (getTestFilePath(pteFilePath))
42+ val pteFile = File (TestFileUtils . getTestFilePath(pteFilePath))
4043 val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
4144 FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
4245 pteInputStream.close()
4346
44- val ptdFile = File (getTestFilePath(ptdFilePath))
47+ val ptdFile = File (TestFileUtils . getTestFilePath(ptdFilePath))
4548 val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath)
4649 FileUtils .copyInputStreamToFile(ptdInputStream, ptdFile)
4750 ptdInputStream.close()
4851
49- val module = TrainingModule .load(getTestFilePath(pteFilePath), getTestFilePath(ptdFilePath))
52+ val module = TrainingModule .load(
53+ TestFileUtils .getTestFilePath(pteFilePath),
54+ TestFileUtils .getTestFilePath(ptdFilePath)
55+ )
5056 val params = module.namedParameters(" forward" )
5157
5258 Assert .assertEquals(4 , params.size)
@@ -75,7 +81,10 @@ class TrainingModuleE2ETest {
7581 val targetDex = inputDex + 1
7682 val input = dataset.get(inputDex)
7783 val target = dataset.get(targetDex)
78- val out = module.executeForwardBackward(" forward" , EValue .from(input), EValue .from(target))
84+ val out = module.executeForwardBackward(" forward" ,
85+ EValue .from(input),
86+ EValue .from(target)
87+ )
7988 val gradients = module.namedGradients(" forward" )
8089
8190 if (i == 0 ) {
@@ -96,7 +105,9 @@ class TrainingModuleE2ETest {
96105 input.getDataAsFloatArray()[0 ],
97106 input.getDataAsFloatArray()[1 ],
98107 out [1 ].toTensor().getDataAsLongArray()[0 ],
99- target.getDataAsLongArray()[0 ]));
108+ target.getDataAsLongArray()[0 ]
109+ )
110+ );
100111 }
101112
102113 sgd.step(gradients)
@@ -113,12 +124,12 @@ class TrainingModuleE2ETest {
113124 fun testTrainXOR_PTEOnly () {
114125 val pteFilePath = " /xor_full.pte"
115126
116- val pteFile = File (getTestFilePath(pteFilePath))
127+ val pteFile = File (TestFileUtils . getTestFilePath(pteFilePath))
117128 val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
118129 FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
119130 pteInputStream.close()
120131
121- val module = TrainingModule .load(getTestFilePath(pteFilePath));
132+ val module = TrainingModule .load(TestFileUtils . getTestFilePath(pteFilePath));
122133 val params = module.namedParameters(" forward" )
123134
124135 Assert .assertEquals(4 , params.size)
@@ -147,7 +158,10 @@ class TrainingModuleE2ETest {
147158 val targetDex = inputDex + 1
148159 val input = dataset.get(inputDex)
149160 val target = dataset.get(targetDex)
150- val out = module.executeForwardBackward(" forward" , EValue .from(input), EValue .from(target))
161+ val out = module.executeForwardBackward(" forward" ,
162+ EValue .from(input),
163+ EValue .from(target)
164+ )
151165 val gradients = module.namedGradients(" forward" )
152166
153167 if (i == 0 ) {
@@ -168,7 +182,9 @@ class TrainingModuleE2ETest {
168182 input.getDataAsFloatArray()[0 ],
169183 input.getDataAsFloatArray()[1 ],
170184 out [1 ].toTensor().getDataAsLongArray()[0 ],
171- target.getDataAsLongArray()[0 ]));
185+ target.getDataAsLongArray()[0 ]
186+ )
187+ );
172188 }
173189
174190 sgd.step(gradients)
@@ -184,24 +200,33 @@ class TrainingModuleE2ETest {
184200 @Throws(IOException ::class )
185201 fun testMissingPteFile () {
186202 val exception = Assert .assertThrows(RuntimeException ::class .java) {
187- TrainingModule .load(getTestFilePath(MISSING_PTE_NAME ))
203+ TrainingModule .load(TestFileUtils . getTestFilePath(MISSING_PTE_NAME ))
188204 }
189- Assert .assertEquals(exception.message, " Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME ))
205+ Assert .assertEquals(
206+ exception.message,
207+ " Cannot load model path!! " + TestFileUtils .getTestFilePath(MISSING_PTE_NAME )
208+ )
190209 }
191210
192211 @Test
193212 @Throws(IOException ::class )
194213 fun testMissingPtdFile () {
195214 val exception = Assert .assertThrows(RuntimeException ::class .java) {
196215 val pteFilePath = " /xor.pte"
197- val pteFile = File (getTestFilePath(pteFilePath))
216+ val pteFile = File (TestFileUtils . getTestFilePath(pteFilePath))
198217 val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
199218 FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
200219 pteInputStream.close()
201220
202- TrainingModule .load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME ))
221+ TrainingModule .load(
222+ TestFileUtils .getTestFilePath(pteFilePath),
223+ TestFileUtils .getTestFilePath(MISSING_PTD_NAME )
224+ )
203225 }
204- Assert .assertEquals(exception.message, " Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME ))
226+ Assert .assertEquals(
227+ exception.message,
228+ " Cannot load data path!! " + TestFileUtils .getTestFilePath(MISSING_PTD_NAME )
229+ )
205230 }
206231
207232 companion object {
@@ -212,4 +237,4 @@ class TrainingModuleE2ETest {
212237 private const val MISSING_PTE_NAME = " /missing.pte"
213238 private const val MISSING_PTD_NAME = " /missing.ptd"
214239 }
215- }
240+ }
0 commit comments