@@ -11,17 +11,18 @@ import android.Manifest
1111import android.util.Log
1212import androidx.test.ext.junit.runners.AndroidJUnit4
1313import androidx.test.rule.GrantPermissionRule
14+ import java.io.ByteArrayInputStream
1415import java.io.File
1516import java.io.IOException
1617import java.net.URISyntaxException
18+ import kotlin.random.Random
19+ import kotlin.test.assertContains
1720import org.apache.commons.io.FileUtils
1821import org.junit.Assert
1922import org.junit.Rule
2023import org.junit.Test
2124import org.junit.runner.RunWith
2225import org.pytorch.executorch.TestFileUtils.getTestFilePath
23- import kotlin.random.Random
24- import kotlin.test.assertContains
2526
2627/* * Unit tests for [TrainingModule]. */
2728@RunWith(AndroidJUnit4 ::class )
@@ -55,27 +56,29 @@ class TrainingModuleE2ETest {
5556 assertContains(params, LIN2_WEIGHT )
5657 assertContains(params, LIN2_BIAS )
5758
58- val sgd = SGD .create(params, 0.5 );
59- val dataset = listOf<Tensor >(
60- Tensor .fromBlob(floatArrayOf(1.0f , 1.0f ), longArrayOf(1 , 2 )),
61- Tensor .fromBlob(longArrayOf(0 ), longArrayOf(1 )),
62- Tensor .fromBlob(floatArrayOf(0.0f , 0.0f ), longArrayOf(1 , 2 )),
63- Tensor .fromBlob(longArrayOf(0 ), longArrayOf(1 )),
64- Tensor .fromBlob(floatArrayOf(1.0f , 0.0f ), longArrayOf(1 , 2 )),
65- Tensor .fromBlob(longArrayOf(1 ), longArrayOf(1 )),
66- Tensor .fromBlob(floatArrayOf(0.0f , 1.0f ), longArrayOf(1 , 2 )),
67- Tensor .fromBlob(longArrayOf(1 ), longArrayOf(1 )),
68- )
59+ val sgd = SGD .create(params, 0.5 )
60+ val dataset =
61+ listOf<Tensor >(
62+ Tensor .fromBlob(floatArrayOf(1.0f , 1.0f ), longArrayOf(1 , 2 )),
63+ Tensor .fromBlob(longArrayOf(0 ), longArrayOf(1 )),
64+ Tensor .fromBlob(floatArrayOf(0.0f , 0.0f ), longArrayOf(1 , 2 )),
65+ Tensor .fromBlob(longArrayOf(0 ), longArrayOf(1 )),
66+ Tensor .fromBlob(floatArrayOf(1.0f , 0.0f ), longArrayOf(1 , 2 )),
67+ Tensor .fromBlob(longArrayOf(1 ), longArrayOf(1 )),
68+ Tensor .fromBlob(floatArrayOf(0.0f , 1.0f ), longArrayOf(1 , 2 )),
69+ Tensor .fromBlob(longArrayOf(1 ), longArrayOf(1 )),
70+ )
6971
70- val numEpochs = 5000 ;
72+ val numEpochs = 5000
7173 var finalLoss = Float .MAX_VALUE
7274
7375 for (i in 0 until numEpochs) {
7476 val inputDex = 2 * Random .nextInt(dataset.size / 2 )
7577 val targetDex = inputDex + 1
7678 val input = dataset.get(inputDex)
7779 val target = dataset.get(targetDex)
78- val out = module.executeForwardBackward(" forward" , EValue .from(input), EValue .from(target))
80+ val out =
81+ module.executeForwardBackward(" forward" , EValue .from(input), EValue .from(target))
7982 val gradients = module.namedGradients(" forward" )
8083
8184 if (i == 0 ) {
@@ -96,7 +99,9 @@ class TrainingModuleE2ETest {
9699 input.getDataAsFloatArray()[0 ],
97100 input.getDataAsFloatArray()[1 ],
98101 out [1 ].toTensor().getDataAsLongArray()[0 ],
99- target.getDataAsLongArray()[0 ]));
102+ target.getDataAsLongArray()[0 ],
103+ ),
104+ )
100105 }
101106
102107 sgd.step(gradients)
@@ -106,6 +111,34 @@ class TrainingModuleE2ETest {
106111 }
107112 }
108113 Assert .assertTrue(finalLoss < 0.1f )
114+
115+ // Check training performance continuity when exporting and loading from PTD checkpoint.
116+ val checkpoint = module.exportWeights(" forward" )
117+ val bytes = ByteArray (checkpoint.remaining())
118+ checkpoint.duplicate().get(bytes)
119+
120+ val ptdCheckpointFilePath = " /xor_checkpoint.ptd"
121+ val ptdCheckpointFile = File (getTestFilePath(ptdCheckpointFilePath))
122+ val checkpointInputStream = ByteArrayInputStream (bytes)
123+ FileUtils .copyInputStreamToFile(checkpointInputStream, ptdCheckpointFile)
124+ checkpointInputStream.close()
125+
126+ val trainedModule =
127+ TrainingModule .load(
128+ getTestFilePath(pteFilePath),
129+ getTestFilePath(ptdCheckpointFilePath),
130+ )
131+ for (inputDex in 0 .. (dataset.size - 1 ) step 2 ) {
132+ val targetDex = inputDex + 1
133+ val out =
134+ trainedModule.executeForwardBackward(
135+ " forward" ,
136+ EValue .from(dataset.get(inputDex)),
137+ EValue .from(dataset.get(targetDex)),
138+ )
139+ val outLoss = out [0 ].toTensor().dataAsFloatArray[0 ]
140+ Assert .assertTrue(outLoss < 0.1f )
141+ }
109142 }
110143
111144 @Test
@@ -118,7 +151,7 @@ class TrainingModuleE2ETest {
118151 FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
119152 pteInputStream.close()
120153
121- val module = TrainingModule .load(getTestFilePath(pteFilePath));
154+ val module = TrainingModule .load(getTestFilePath(pteFilePath))
122155 val params = module.namedParameters(" forward" )
123156
124157 Assert .assertEquals(4 , params.size)
@@ -127,27 +160,29 @@ class TrainingModuleE2ETest {
127160 assertContains(params, LIN2_WEIGHT )
128161 assertContains(params, LIN2_BIAS )
129162
130- val sgd = SGD .create(params, 0.5 );
131- val dataset = listOf<Tensor >(
132- Tensor .fromBlob(floatArrayOf(1.0f , 1.0f ), longArrayOf(1 , 2 )),
133- Tensor .fromBlob(longArrayOf(0 ), longArrayOf(1 )),
134- Tensor .fromBlob(floatArrayOf(0.0f , 0.0f ), longArrayOf(1 , 2 )),
135- Tensor .fromBlob(longArrayOf(0 ), longArrayOf(1 )),
136- Tensor .fromBlob(floatArrayOf(1.0f , 0.0f ), longArrayOf(1 , 2 )),
137- Tensor .fromBlob(longArrayOf(1 ), longArrayOf(1 )),
138- Tensor .fromBlob(floatArrayOf(0.0f , 1.0f ), longArrayOf(1 , 2 )),
139- Tensor .fromBlob(longArrayOf(1 ), longArrayOf(1 )),
140- )
163+ val sgd = SGD .create(params, 0.5 )
164+ val dataset =
165+ listOf<Tensor >(
166+ Tensor .fromBlob(floatArrayOf(1.0f , 1.0f ), longArrayOf(1 , 2 )),
167+ Tensor .fromBlob(longArrayOf(0 ), longArrayOf(1 )),
168+ Tensor .fromBlob(floatArrayOf(0.0f , 0.0f ), longArrayOf(1 , 2 )),
169+ Tensor .fromBlob(longArrayOf(0 ), longArrayOf(1 )),
170+ Tensor .fromBlob(floatArrayOf(1.0f , 0.0f ), longArrayOf(1 , 2 )),
171+ Tensor .fromBlob(longArrayOf(1 ), longArrayOf(1 )),
172+ Tensor .fromBlob(floatArrayOf(0.0f , 1.0f ), longArrayOf(1 , 2 )),
173+ Tensor .fromBlob(longArrayOf(1 ), longArrayOf(1 )),
174+ )
141175
142- val numEpochs = 5000 ;
176+ val numEpochs = 5000
143177 var finalLoss = Float .MAX_VALUE
144178
145179 for (i in 0 until numEpochs) {
146180 val inputDex = 2 * Random .nextInt(dataset.size / 2 )
147181 val targetDex = inputDex + 1
148182 val input = dataset.get(inputDex)
149183 val target = dataset.get(targetDex)
150- val out = module.executeForwardBackward(" forward" , EValue .from(input), EValue .from(target))
184+ val out =
185+ module.executeForwardBackward(" forward" , EValue .from(input), EValue .from(target))
151186 val gradients = module.namedGradients(" forward" )
152187
153188 if (i == 0 ) {
@@ -168,7 +203,9 @@ class TrainingModuleE2ETest {
168203 input.getDataAsFloatArray()[0 ],
169204 input.getDataAsFloatArray()[1 ],
170205 out [1 ].toTensor().getDataAsLongArray()[0 ],
171- target.getDataAsLongArray()[0 ]));
206+ target.getDataAsLongArray()[0 ],
207+ ),
208+ )
172209 }
173210
174211 sgd.step(gradients)
@@ -183,25 +220,33 @@ class TrainingModuleE2ETest {
183220 @Test
184221 @Throws(IOException ::class )
185222 fun testMissingPteFile () {
186- val exception = Assert .assertThrows(RuntimeException ::class .java) {
187- TrainingModule .load(getTestFilePath(MISSING_PTE_NAME ))
188- }
189- Assert .assertEquals(exception.message, " Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME ))
223+ val exception =
224+ Assert .assertThrows(RuntimeException ::class .java) {
225+ TrainingModule .load(getTestFilePath(MISSING_PTE_NAME ))
226+ }
227+ Assert .assertEquals(
228+ exception.message,
229+ " Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME ),
230+ )
190231 }
191232
192233 @Test
193234 @Throws(IOException ::class )
194235 fun testMissingPtdFile () {
195- val exception = Assert .assertThrows(RuntimeException ::class .java) {
196- val pteFilePath = " /xor.pte"
197- val pteFile = File (getTestFilePath(pteFilePath))
198- val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
199- FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
200- pteInputStream.close()
201-
202- TrainingModule .load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME ))
203- }
204- Assert .assertEquals(exception.message, " Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME ))
236+ val exception =
237+ Assert .assertThrows(RuntimeException ::class .java) {
238+ val pteFilePath = " /xor.pte"
239+ val pteFile = File (getTestFilePath(pteFilePath))
240+ val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
241+ FileUtils .copyInputStreamToFile(pteInputStream, pteFile)
242+ pteInputStream.close()
243+
244+ TrainingModule .load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME ))
245+ }
246+ Assert .assertEquals(
247+ exception.message,
248+ " Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME ),
249+ )
205250 }
206251
207252 companion object {
0 commit comments