Commit 1d81b4b

Add JNI wrapper for saving training parameters as PTD
Also adjusts the serialize.h utility so that it is visible to the JNI layer.
1 parent fee2bd9 commit 1d81b4b
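In practice, the new exportWeights API enables an export-then-resume round trip from the Android side. Below is a minimal Kotlin sketch of that flow, mirroring the E2E test added in this commit; the file paths and the elided training loop are illustrative placeholders, not part of the change.

import java.io.File
import java.nio.ByteBuffer
import org.pytorch.executorch.TrainingModule

// Hypothetical paths and the method name "forward"; the training loop itself is elided.
fun exportAndResume() {
  val module = TrainingModule.load("/data/local/tmp/xor.pte")
  // ... run module.executeForwardBackward("forward", ...) and SGD.step(...) for some epochs ...

  // Export the current parameters as PTD bytes and write them to disk.
  val checkpoint: ByteBuffer = module.exportWeights("forward")
  val bytes = ByteArray(checkpoint.remaining())
  checkpoint.duplicate().get(bytes) // copy out without disturbing the buffer's position
  File("/data/local/tmp/xor_checkpoint.ptd").writeBytes(bytes)

  // Resume later by loading the program together with the exported checkpoint.
  val resumed = TrainingModule.load("/data/local/tmp/xor.pte", "/data/local/tmp/xor_checkpoint.ptd")
  val params = resumed.namedParameters("forward")
}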

6 files changed: +153 −47 lines

extension/android/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ endif()
 
 if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
   target_sources(executorch_jni PRIVATE jni/jni_layer_training.cpp jni/log.cpp)
-  list(APPEND link_libraries extension_training)
+  list(APPEND link_libraries extension_training extension_flat_tensor_serialize)
   target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_EXTENSION_TRAINING=1)
 endif()
 

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TrainingModuleE2ETest.kt

Lines changed: 90 additions & 45 deletions
@@ -11,17 +11,18 @@ import android.Manifest
 import android.util.Log
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.rule.GrantPermissionRule
+import java.io.ByteArrayInputStream
 import java.io.File
 import java.io.IOException
 import java.net.URISyntaxException
+import kotlin.random.Random
+import kotlin.test.assertContains
 import org.apache.commons.io.FileUtils
 import org.junit.Assert
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
 import org.pytorch.executorch.TestFileUtils.getTestFilePath
-import kotlin.random.Random
-import kotlin.test.assertContains
 
 /** Unit tests for [TrainingModule]. */
 @RunWith(AndroidJUnit4::class)
@@ -55,27 +56,29 @@ class TrainingModuleE2ETest {
     assertContains(params, LIN2_WEIGHT)
     assertContains(params, LIN2_BIAS)
 
-    val sgd = SGD.create(params, 0.5);
-    val dataset = listOf<Tensor>(
-      Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
-      Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
-      Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
-      Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
-      Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
-      Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
-      Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
-      Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
-    )
+    val sgd = SGD.create(params, 0.5)
+    val dataset =
+      listOf<Tensor>(
+        Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
+        Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
+        Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
+        Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
+        Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
+        Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
+        Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
+        Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
+      )
 
-    val numEpochs = 5000;
+    val numEpochs = 5000
     var finalLoss = Float.MAX_VALUE
 
     for (i in 0 until numEpochs) {
       val inputDex = 2 * Random.nextInt(dataset.size / 2)
       val targetDex = inputDex + 1
       val input = dataset.get(inputDex)
       val target = dataset.get(targetDex)
-      val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
+      val out =
+        module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
       val gradients = module.namedGradients("forward")
 
       if (i == 0) {
@@ -96,7 +99,9 @@ class TrainingModuleE2ETest {
            input.getDataAsFloatArray()[0],
            input.getDataAsFloatArray()[1],
            out[1].toTensor().getDataAsLongArray()[0],
-            target.getDataAsLongArray()[0]));
+            target.getDataAsLongArray()[0],
+          ),
+        )
       }
 
       sgd.step(gradients)
@@ -106,6 +111,34 @@ class TrainingModuleE2ETest {
      }
    }
    Assert.assertTrue(finalLoss < 0.1f)
+
+    // Check training performance continuity when exporting and loading from PTD checkpoint.
+    val checkpoint = module.exportWeights("forward")
+    val bytes = ByteArray(checkpoint.remaining())
+    checkpoint.duplicate().get(bytes)
+
+    val ptdCheckpointFilePath = "/xor_checkpoint.ptd"
+    val ptdCheckpointFile = File(getTestFilePath(ptdCheckpointFilePath))
+    val checkpointInputStream = ByteArrayInputStream(bytes)
+    FileUtils.copyInputStreamToFile(checkpointInputStream, ptdCheckpointFile)
+    checkpointInputStream.close()
+
+    val trainedModule =
+      TrainingModule.load(
+        getTestFilePath(pteFilePath),
+        getTestFilePath(ptdCheckpointFilePath),
+      )
+    for (inputDex in 0..(dataset.size - 1) step 2) {
+      val targetDex = inputDex + 1
+      val out =
+        trainedModule.executeForwardBackward(
+          "forward",
+          EValue.from(dataset.get(inputDex)),
+          EValue.from(dataset.get(targetDex)),
+        )
+      val outLoss = out[0].toTensor().dataAsFloatArray[0]
+      Assert.assertTrue(outLoss < 0.1f)
+    }
   }
 
   @Test
@@ -118,7 +151,7 @@ class TrainingModuleE2ETest {
     FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
     pteInputStream.close()
 
-    val module = TrainingModule.load(getTestFilePath(pteFilePath));
+    val module = TrainingModule.load(getTestFilePath(pteFilePath))
     val params = module.namedParameters("forward")
 
     Assert.assertEquals(4, params.size)
@@ -127,27 +160,29 @@ class TrainingModuleE2ETest {
     assertContains(params, LIN2_WEIGHT)
     assertContains(params, LIN2_BIAS)
 
-    val sgd = SGD.create(params, 0.5);
-    val dataset = listOf<Tensor>(
-      Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
-      Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
-      Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
-      Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
-      Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
-      Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
-      Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
-      Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
-    )
+    val sgd = SGD.create(params, 0.5)
+    val dataset =
+      listOf<Tensor>(
+        Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)),
+        Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
+        Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)),
+        Tensor.fromBlob(longArrayOf(0), longArrayOf(1)),
+        Tensor.fromBlob(floatArrayOf(1.0f, 0.0f), longArrayOf(1, 2)),
+        Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
+        Tensor.fromBlob(floatArrayOf(0.0f, 1.0f), longArrayOf(1, 2)),
+        Tensor.fromBlob(longArrayOf(1), longArrayOf(1)),
+      )
 
-    val numEpochs = 5000;
+    val numEpochs = 5000
     var finalLoss = Float.MAX_VALUE
 
     for (i in 0 until numEpochs) {
       val inputDex = 2 * Random.nextInt(dataset.size / 2)
       val targetDex = inputDex + 1
       val input = dataset.get(inputDex)
       val target = dataset.get(targetDex)
-      val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
+      val out =
+        module.executeForwardBackward("forward", EValue.from(input), EValue.from(target))
       val gradients = module.namedGradients("forward")
 
       if (i == 0) {
@@ -168,7 +203,9 @@ class TrainingModuleE2ETest {
            input.getDataAsFloatArray()[0],
            input.getDataAsFloatArray()[1],
            out[1].toTensor().getDataAsLongArray()[0],
-            target.getDataAsLongArray()[0]));
+            target.getDataAsLongArray()[0],
+          ),
+        )
       }
 
       sgd.step(gradients)
@@ -183,25 +220,33 @@ class TrainingModuleE2ETest {
   @Test
   @Throws(IOException::class)
   fun testMissingPteFile() {
-    val exception = Assert.assertThrows(RuntimeException::class.java) {
-      TrainingModule.load(getTestFilePath(MISSING_PTE_NAME))
-    }
-    Assert.assertEquals(exception.message, "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME))
+    val exception =
+      Assert.assertThrows(RuntimeException::class.java) {
+        TrainingModule.load(getTestFilePath(MISSING_PTE_NAME))
+      }
+    Assert.assertEquals(
+      exception.message,
+      "Cannot load model path!! " + getTestFilePath(MISSING_PTE_NAME),
+    )
   }
 
   @Test
   @Throws(IOException::class)
   fun testMissingPtdFile() {
-    val exception = Assert.assertThrows(RuntimeException::class.java) {
-      val pteFilePath = "/xor.pte"
-      val pteFile = File(getTestFilePath(pteFilePath))
-      val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
-      FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
-      pteInputStream.close()
-
-      TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME))
-    }
-    Assert.assertEquals(exception.message, "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME))
+    val exception =
+      Assert.assertThrows(RuntimeException::class.java) {
+        val pteFilePath = "/xor.pte"
+        val pteFile = File(getTestFilePath(pteFilePath))
+        val pteInputStream = javaClass.getResourceAsStream(pteFilePath)
+        FileUtils.copyInputStreamToFile(pteInputStream, pteFile)
+        pteInputStream.close()
+
+        TrainingModule.load(getTestFilePath(pteFilePath), getTestFilePath(MISSING_PTD_NAME))
+      }
+    Assert.assertEquals(
+      exception.message,
+      "Cannot load data path!! " + getTestFilePath(MISSING_PTD_NAME),
+    )
   }
 
   companion object {

extension/android/executorch_android/src/main/java/org/pytorch/executorch/TrainingModule.java

Lines changed: 18 additions & 1 deletion
@@ -14,6 +14,7 @@
 import com.facebook.soloader.nativeloader.NativeLoader;
 import com.facebook.soloader.nativeloader.SystemDelegate;
 import java.io.File;
+import java.nio.ByteBuffer;
 import java.util.HashMap;
 import java.util.Map;
 import org.pytorch.executorch.annotations.Experimental;
@@ -114,6 +115,22 @@ public Map<String, Tensor> namedGradients(String methodName) {
     return namedGradientsNative(methodName);
   }
 
-  @DoNotStrip
   private native Map<String, Tensor> namedGradientsNative(String methodName);
+
+  /**
+   * Exports the parameters of the specified method as a buffer that can be saved as a PTD file.
+   *
+   * @param methodName name of the ExecuTorch module method to export weights from.
+   * @return buffer that contains the weights of the specified method
+   */
+  public ByteBuffer exportWeights(String methodName) {
+    if (!mHybridData.isValid()) {
+      Log.e("ExecuTorch", "Attempt to use a destroyed module");
+      return ByteBuffer.allocateDirect(0);
+    }
+    return exportWeightsNative(methodName);
+  }
+
+  @DoNotStrip
+  private native ByteBuffer exportWeightsNative(String methodName);
 }
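Since exportWeights returns an empty direct buffer when the module has already been destroyed, callers may want to check the result before persisting it. A hedged Kotlin sketch of that pattern follows; the helper name writeCheckpoint is illustrative and not part of this commit.

import java.io.File
import java.nio.ByteBuffer
import org.pytorch.executorch.TrainingModule

// Illustrative helper: copy the exported buffer and write it to a .ptd file.
fun writeCheckpoint(module: TrainingModule, methodName: String, path: String): Boolean {
  val buffer: ByteBuffer = module.exportWeights(methodName)
  if (buffer.remaining() == 0) {
    // Empty buffer: the module was destroyed (or there was nothing to export).
    return false
  }
  val bytes = ByteArray(buffer.remaining())
  buffer.duplicate().get(bytes) // duplicate() leaves the original buffer's position untouched
  File(path).writeBytes(bytes)
  return true
}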

extension/android/jni/BUCK

Lines changed: 1 addition & 0 deletions
@@ -118,6 +118,7 @@ non_fbcode_target(_kind = fb_android_cxx_library,
         "//xplat/executorch/backends/xnnpack:xnnpack_backend_static",
         "//xplat/executorch/examples/models/llama/runner:runner_static",
         "//xplat/executorch/examples/models/llava/runner:runner_static",
+        "//xplat/executorch/extension/flat_tensor/serialize:serialize_static",
         "//xplat/executorch/extension/module:module_static",
         "//xplat/executorch/extension/runner_util:inputs_static",
         "//xplat/executorch/extension/tensor:tensor_static",

extension/android/jni/jni_layer_training.cpp

Lines changed: 34 additions & 0 deletions
@@ -9,6 +9,7 @@
 #include <executorch/extension/android/jni/jni_layer_constants.h>
 #include <executorch/extension/android/jni/log.h>
 #include <executorch/extension/data_loader/file_data_loader.h>
+#include <executorch/extension/flat_tensor/serialize/serialize.h>
 #include <executorch/extension/tensor/tensor.h>
 #include <executorch/extension/training/module/training_module.h>
 #include <executorch/extension/training/optimizer/sgd.h>
@@ -210,6 +211,37 @@ class ExecuTorchTrainingJni
     return gradients;
   }
 
+  facebook::jni::local_ref<facebook::jni::JByteBuffer> exportWeights(
+      facebook::jni::alias_ref<jstring> methodName) {
+    auto method = methodName->toStdString();
+    auto result = module_->named_parameters(method);
+    if (!result.ok()) {
+      facebook::jni::throwNewJavaException(
+          "java/lang/Exception",
+          "Getting named parameters for method %s failed with status 0x%" PRIx32,
+          method.c_str(),
+          static_cast<error_code_t>(result.error()));
+    }
+    std::map<std::string, executorch::aten::Tensor> tensorMap;
+    for (auto& [layer, tensor] : result.get()) {
+      tensorMap.emplace(std::string(layer), tensor);
+    }
+    std::ostringstream oss;
+    auto saveError = executorch::extension::flat_tensor::save_ptd(
+        oss, tensorMap, 16 /* tensor_alignment */);
+
+    if (saveError != executorch::runtime::Error::Ok) {
+      facebook::jni::throwNewJavaException(
+          "java/lang/Exception",
+          "Saving parameters for method %s failed with status 0x%" PRIx32,
+          method.c_str(),
+          static_cast<error_code_t>(saveError));
+    }
+    std::string exportedWeights = oss.str();
+    return facebook::jni::JByteBuffer::wrapBytes(
+        (uint8_t*)exportedWeights.data(), exportedWeights.size());
+  }
+
   static void registerNatives() {
     registerHybrid({
         makeNativeMethod("initHybrid", ExecuTorchTrainingJni::initHybrid),
@@ -220,6 +252,8 @@ class ExecuTorchTrainingJni
             "namedParametersNative", ExecuTorchTrainingJni::namedParameters),
         makeNativeMethod(
            "namedGradientsNative", ExecuTorchTrainingJni::namedGradients),
+        makeNativeMethod(
+            "exportWeightsNative", ExecuTorchTrainingJni::exportWeights),
     });
   }
 };

extension/flat_tensor/serialize/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
@@ -63,3 +63,12 @@ generate_flat_tensor_schema("${scalar_type_schema_srcs}" "scalar_type_schema")
 set(flat_tensor_schema_srcs flat_tensor.fbs)
 generate_flat_tensor_schema("${flat_tensor_schema_srcs}" "flat_tensor_schema")
 add_dependencies(flat_tensor_schema scalar_type_schema)
+
+add_library(extension_flat_tensor_serialize serialize.cpp)
+target_include_directories(
+  extension_flat_tensor_serialize
+  PRIVATE ${_common_include_directories}
+          ${TORCH_INCLUDE_DIRS})
+target_link_libraries(extension_flat_tensor_serialize
+  PRIVATE flat_tensor_schema
+)
