Skip to content

Commit 0485e18

Browse files
Copilot and kirklandsign committed
Revert jni_layer_llama.cpp and LlmModule.java to split into separate PR
Co-authored-by: kirklandsign <[email protected]>
1 parent e18fe7b commit 0485e18

File tree

3 files changed: +320 -619 lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 27 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@
88

99
package org.pytorch.executorch.extension.llm;
1010

11+
import com.facebook.jni.HybridData;
12+
import com.facebook.jni.annotations.DoNotStrip;
1113
import java.io.File;
1214
import java.util.List;
1315
import org.pytorch.executorch.ExecuTorchRuntime;
@@ -26,19 +28,18 @@ public class LlmModule {
2628
public static final int MODEL_TYPE_TEXT_VISION = 2;
2729
public static final int MODEL_TYPE_MULTIMODAL = 2;
2830

29-
private long mNativeHandle;
31+
private final HybridData mHybridData;
3032
private static final int DEFAULT_SEQ_LEN = 128;
3133
private static final boolean DEFAULT_ECHO = true;
3234

33-
private static native long nativeCreate(
35+
@DoNotStrip
36+
private static native HybridData initHybrid(
3437
int modelType,
3538
String modulePath,
3639
String tokenizerPath,
3740
float temperature,
3841
List<String> dataFiles);
3942

40-
private static native void nativeDestroy(long nativeHandle);
41-
4243
/**
4344
* Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
4445
* dataFiles.
@@ -60,7 +61,7 @@ public LlmModule(
6061
throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath);
6162
}
6263

63-
mNativeHandle = nativeCreate(modelType, modulePath, tokenizerPath, temperature, dataFiles);
64+
mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataFiles);
6465
}
6566

6667
/**
@@ -106,16 +107,7 @@ public LlmModule(LlmModuleConfig config) {
106107
}
107108

108109
public void resetNative() {
109-
if (mNativeHandle != 0) {
110-
nativeDestroy(mNativeHandle);
111-
mNativeHandle = 0;
112-
}
113-
}
114-
115-
@Override
116-
protected void finalize() throws Throwable {
117-
resetNative();
118-
super.finalize();
110+
mHybridData.resetNative();
119111
}
120112

121113
/**
@@ -158,12 +150,7 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
158150
* @param llmCallback callback object to receive results
159151
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
160152
*/
161-
public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) {
162-
return nativeGenerate(mNativeHandle, prompt, seqLen, llmCallback, echo);
163-
}
164-
165-
private static native int nativeGenerate(
166-
long nativeHandle, String prompt, int seqLen, LlmCallback llmCallback, boolean echo);
153+
public native int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo);
167154

168155
/**
169156
* Start generating tokens from the module.
@@ -219,15 +206,14 @@ public int generate(
219206
*/
220207
@Experimental
221208
public long prefillImages(int[] image, int width, int height, int channels) {
222-
int nativeResult = nativeAppendImagesInput(mNativeHandle, image, width, height, channels);
209+
int nativeResult = appendImagesInput(image, width, height, channels);
223210
if (nativeResult != 0) {
224211
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
225212
}
226213
return 0;
227214
}
228215

229-
private static native int nativeAppendImagesInput(
230-
long nativeHandle, int[] image, int width, int height, int channels);
216+
private native int appendImagesInput(int[] image, int width, int height, int channels);
231217

232218
/**
233219
* Prefill a multimodal Module with the given images input.
@@ -242,16 +228,15 @@ private static native int nativeAppendImagesInput(
242228
*/
243229
@Experimental
244230
public long prefillImages(float[] image, int width, int height, int channels) {
245-
int nativeResult =
246-
nativeAppendNormalizedImagesInput(mNativeHandle, image, width, height, channels);
231+
int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
247232
if (nativeResult != 0) {
248233
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
249234
}
250235
return 0;
251236
}
252237

253-
private static native int nativeAppendNormalizedImagesInput(
254-
long nativeHandle, float[] image, int width, int height, int channels);
238+
private native int appendNormalizedImagesInput(
239+
float[] image, int width, int height, int channels);
255240

256241
/**
257242
* Prefill a multimodal Module with the given audio input.
@@ -266,15 +251,14 @@ private static native int nativeAppendNormalizedImagesInput(
266251
*/
267252
@Experimental
268253
public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
269-
int nativeResult = nativeAppendAudioInput(mNativeHandle, audio, batch_size, n_bins, n_frames);
254+
int nativeResult = appendAudioInput(audio, batch_size, n_bins, n_frames);
270255
if (nativeResult != 0) {
271256
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
272257
}
273258
return 0;
274259
}
275260

276-
private static native int nativeAppendAudioInput(
277-
long nativeHandle, byte[] audio, int batch_size, int n_bins, int n_frames);
261+
private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
278262

279263
/**
280264
* Prefill a multimodal Module with the given audio input.
@@ -289,16 +273,14 @@ private static native int nativeAppendAudioInput(
289273
*/
290274
@Experimental
291275
public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) {
292-
int nativeResult =
293-
nativeAppendAudioInputFloat(mNativeHandle, audio, batch_size, n_bins, n_frames);
276+
int nativeResult = appendAudioInputFloat(audio, batch_size, n_bins, n_frames);
294277
if (nativeResult != 0) {
295278
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
296279
}
297280
return 0;
298281
}
299282

300-
private static native int nativeAppendAudioInputFloat(
301-
long nativeHandle, float[] audio, int batch_size, int n_bins, int n_frames);
283+
private native int appendAudioInputFloat(float[] audio, int batch_size, int n_bins, int n_frames);
302284

303285
/**
304286
* Prefill a multimodal Module with the given raw audio input.
@@ -313,16 +295,15 @@ private static native int nativeAppendAudioInputFloat(
313295
*/
314296
@Experimental
315297
public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
316-
int nativeResult =
317-
nativeAppendRawAudioInput(mNativeHandle, audio, batch_size, n_channels, n_samples);
298+
int nativeResult = appendRawAudioInput(audio, batch_size, n_channels, n_samples);
318299
if (nativeResult != 0) {
319300
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
320301
}
321302
return 0;
322303
}
323304

324-
private static native int nativeAppendRawAudioInput(
325-
long nativeHandle, byte[] audio, int batch_size, int n_channels, int n_samples);
305+
private native int appendRawAudioInput(
306+
byte[] audio, int batch_size, int n_channels, int n_samples);
326307

327308
/**
328309
* Prefill a multimodal Module with the given text input.
@@ -334,38 +315,28 @@ private static native int nativeAppendRawAudioInput(
334315
*/
335316
@Experimental
336317
public long prefillPrompt(String prompt) {
337-
int nativeResult = nativeAppendTextInput(mNativeHandle, prompt);
318+
int nativeResult = appendTextInput(prompt);
338319
if (nativeResult != 0) {
339320
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
340321
}
341322
return 0;
342323
}
343324

344325
// returns status
345-
private static native int nativeAppendTextInput(long nativeHandle, String prompt);
326+
private native int appendTextInput(String prompt);
346327

347328
/**
348329
* Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
349330
*
350331
* <p>The startPos will be reset to 0.
351332
*/
352-
public void resetContext() {
353-
nativeResetContext(mNativeHandle);
354-
}
355-
356-
private static native void nativeResetContext(long nativeHandle);
333+
public native void resetContext();
357334

358335
/** Stop current generate() before it finishes. */
359-
public void stop() {
360-
nativeStop(mNativeHandle);
361-
}
362-
363-
private static native void nativeStop(long nativeHandle);
336+
@DoNotStrip
337+
public native void stop();
364338

365339
/** Force loading the module. Otherwise the model is loaded during first generate(). */
366-
public int load() {
367-
return nativeLoad(mNativeHandle);
368-
}
369-
370-
private static native int nativeLoad(long nativeHandle);
340+
@DoNotStrip
341+
public native int load();
371342
}

extension/android/jni/jni_layer.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -717,10 +717,10 @@ Java_org_pytorch_executorch_Module_nativeEtdump(
717717
} // extern "C"
718718

719719
#ifdef EXECUTORCH_BUILD_LLAMA_JNI
720-
extern void register_natives_for_llm(JNIEnv* env);
720+
extern void register_natives_for_llm();
721721
#else
722722
// No op if we don't build LLM
723-
void register_natives_for_llm(JNIEnv* /* env */) {}
723+
void register_natives_for_llm() {}
724724
#endif
725725

726726
#ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING
@@ -785,7 +785,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
785785

786786
// Register native methods
787787
register_natives_for_module(env);
788-
register_natives_for_llm(env);
788+
register_natives_for_llm();
789789
register_natives_for_runtime(env);
790790
register_natives_for_training(env);
791791

0 commit comments

Comments
 (0)