88
99package org .pytorch .executorch .extension .llm ;
1010
11+ import com .facebook .jni .HybridData ;
12+ import com .facebook .jni .annotations .DoNotStrip ;
1113import java .io .File ;
1214import java .util .List ;
1315import org .pytorch .executorch .ExecuTorchRuntime ;
@@ -26,19 +28,18 @@ public class LlmModule {
2628 public static final int MODEL_TYPE_TEXT_VISION = 2 ;
2729 public static final int MODEL_TYPE_MULTIMODAL = 2 ;
2830
29- private long mNativeHandle ;
31+ private final HybridData mHybridData ;
3032 private static final int DEFAULT_SEQ_LEN = 128 ;
3133 private static final boolean DEFAULT_ECHO = true ;
3234
33- private static native long nativeCreate (
35+ @ DoNotStrip
36+ private static native HybridData initHybrid (
3437 int modelType ,
3538 String modulePath ,
3639 String tokenizerPath ,
3740 float temperature ,
3841 List <String > dataFiles );
3942
40- private static native void nativeDestroy (long nativeHandle );
41-
4243 /**
4344 * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
4445 * dataFiles.
@@ -60,7 +61,7 @@ public LlmModule(
6061 throw new RuntimeException ("Cannot load tokenizer path " + tokenizerPath );
6162 }
6263
63- mNativeHandle = nativeCreate (modelType , modulePath , tokenizerPath , temperature , dataFiles );
64+ mHybridData = initHybrid (modelType , modulePath , tokenizerPath , temperature , dataFiles );
6465 }
6566
6667 /**
@@ -106,16 +107,7 @@ public LlmModule(LlmModuleConfig config) {
106107 }
107108
108109 public void resetNative () {
109- if (mNativeHandle != 0 ) {
110- nativeDestroy (mNativeHandle );
111- mNativeHandle = 0 ;
112- }
113- }
114-
115- @ Override
116- protected void finalize () throws Throwable {
117- resetNative ();
118- super .finalize ();
110+ mHybridData .resetNative ();
119111 }
120112
121113 /**
@@ -158,12 +150,7 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
158150 * @param llmCallback callback object to receive results
159151 * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
160152 */
161- public int generate (String prompt , int seqLen , LlmCallback llmCallback , boolean echo ) {
162- return nativeGenerate (mNativeHandle , prompt , seqLen , llmCallback , echo );
163- }
164-
165- private static native int nativeGenerate (
166- long nativeHandle , String prompt , int seqLen , LlmCallback llmCallback , boolean echo );
153+ public native int generate (String prompt , int seqLen , LlmCallback llmCallback , boolean echo );
167154
168155 /**
169156 * Start generating tokens from the module.
@@ -219,15 +206,14 @@ public int generate(
219206 */
220207 @ Experimental
221208 public long prefillImages (int [] image , int width , int height , int channels ) {
222- int nativeResult = nativeAppendImagesInput ( mNativeHandle , image , width , height , channels );
209+ int nativeResult = appendImagesInput ( image , width , height , channels );
223210 if (nativeResult != 0 ) {
224211 throw new RuntimeException ("Prefill failed with error code: " + nativeResult );
225212 }
226213 return 0 ;
227214 }
228215
229- private static native int nativeAppendImagesInput (
230- long nativeHandle , int [] image , int width , int height , int channels );
216+ private native int appendImagesInput (int [] image , int width , int height , int channels );
231217
232218 /**
233219 * Prefill a multimodal Module with the given images input.
@@ -242,16 +228,15 @@ private static native int nativeAppendImagesInput(
242228 */
243229 @ Experimental
244230 public long prefillImages (float [] image , int width , int height , int channels ) {
245- int nativeResult =
246- nativeAppendNormalizedImagesInput (mNativeHandle , image , width , height , channels );
231+ int nativeResult = appendNormalizedImagesInput (image , width , height , channels );
247232 if (nativeResult != 0 ) {
248233 throw new RuntimeException ("Prefill failed with error code: " + nativeResult );
249234 }
250235 return 0 ;
251236 }
252237
253- private static native int nativeAppendNormalizedImagesInput (
254- long nativeHandle , float [] image , int width , int height , int channels );
238+ private native int appendNormalizedImagesInput (
239+ float [] image , int width , int height , int channels );
255240
256241 /**
257242 * Prefill a multimodal Module with the given audio input.
@@ -266,15 +251,14 @@ private static native int nativeAppendNormalizedImagesInput(
266251 */
267252 @ Experimental
268253 public long prefillAudio (byte [] audio , int batch_size , int n_bins , int n_frames ) {
269- int nativeResult = nativeAppendAudioInput ( mNativeHandle , audio , batch_size , n_bins , n_frames );
254+ int nativeResult = appendAudioInput ( audio , batch_size , n_bins , n_frames );
270255 if (nativeResult != 0 ) {
271256 throw new RuntimeException ("Prefill failed with error code: " + nativeResult );
272257 }
273258 return 0 ;
274259 }
275260
276- private static native int nativeAppendAudioInput (
277- long nativeHandle , byte [] audio , int batch_size , int n_bins , int n_frames );
261+ private native int appendAudioInput (byte [] audio , int batch_size , int n_bins , int n_frames );
278262
279263 /**
280264 * Prefill a multimodal Module with the given audio input.
@@ -289,16 +273,14 @@ private static native int nativeAppendAudioInput(
289273 */
290274 @ Experimental
291275 public long prefillAudio (float [] audio , int batch_size , int n_bins , int n_frames ) {
292- int nativeResult =
293- nativeAppendAudioInputFloat (mNativeHandle , audio , batch_size , n_bins , n_frames );
276+ int nativeResult = appendAudioInputFloat (audio , batch_size , n_bins , n_frames );
294277 if (nativeResult != 0 ) {
295278 throw new RuntimeException ("Prefill failed with error code: " + nativeResult );
296279 }
297280 return 0 ;
298281 }
299282
300- private static native int nativeAppendAudioInputFloat (
301- long nativeHandle , float [] audio , int batch_size , int n_bins , int n_frames );
283+ private native int appendAudioInputFloat (float [] audio , int batch_size , int n_bins , int n_frames );
302284
303285 /**
304286 * Prefill a multimodal Module with the given raw audio input.
@@ -313,16 +295,15 @@ private static native int nativeAppendAudioInputFloat(
313295 */
314296 @ Experimental
315297 public long prefillRawAudio (byte [] audio , int batch_size , int n_channels , int n_samples ) {
316- int nativeResult =
317- nativeAppendRawAudioInput (mNativeHandle , audio , batch_size , n_channels , n_samples );
298+ int nativeResult = appendRawAudioInput (audio , batch_size , n_channels , n_samples );
318299 if (nativeResult != 0 ) {
319300 throw new RuntimeException ("Prefill failed with error code: " + nativeResult );
320301 }
321302 return 0 ;
322303 }
323304
324- private static native int nativeAppendRawAudioInput (
325- long nativeHandle , byte [] audio , int batch_size , int n_channels , int n_samples );
305+ private native int appendRawAudioInput (
306+ byte [] audio , int batch_size , int n_channels , int n_samples );
326307
327308 /**
328309 * Prefill a multimodal Module with the given text input.
@@ -334,38 +315,28 @@ private static native int nativeAppendRawAudioInput(
334315 */
335316 @ Experimental
336317 public long prefillPrompt (String prompt ) {
337- int nativeResult = nativeAppendTextInput ( mNativeHandle , prompt );
318+ int nativeResult = appendTextInput ( prompt );
338319 if (nativeResult != 0 ) {
339320 throw new RuntimeException ("Prefill failed with error code: " + nativeResult );
340321 }
341322 return 0 ;
342323 }
343324
344325 // returns status
345- private static native int nativeAppendTextInput ( long nativeHandle , String prompt );
326+ private native int appendTextInput ( String prompt );
346327
347328 /**
348329 * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
349330 *
350331 * <p>The startPos will be reset to 0.
351332 */
352- public void resetContext () {
353- nativeResetContext (mNativeHandle );
354- }
355-
356- private static native void nativeResetContext (long nativeHandle );
333+ public native void resetContext ();
357334
358335 /** Stop current generate() before it finishes. */
359- public void stop () {
360- nativeStop (mNativeHandle );
361- }
362-
363- private static native void nativeStop (long nativeHandle );
336+ @ DoNotStrip
337+ public native void stop ();
364338
365339 /** Force loading the module. Otherwise the model is loaded during first generate(). */
366- public int load () {
367- return nativeLoad (mNativeHandle );
368- }
369-
370- private static native int nativeLoad (long nativeHandle );
340+ @ DoNotStrip
341+ public native int load ();
371342}
0 commit comments