 import android.app.Dialog;
 import android.content.Context;
 import android.os.Bundle;
+import android.system.ErrnoException;
 import android.text.method.ScrollingMovementMethod;
 import android.util.Log;
 import android.util.Pair;
 import android.widget.Toast;

 import java.io.File;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+import java.nio.LongBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import ai.onnxruntime.genai.Generator;
 import ai.onnxruntime.genai.GeneratorParams;
 import ai.onnxruntime.genai.Sequences;
+import ai.onnxruntime.genai.Tensor;
 import ai.onnxruntime.genai.TokenizerStream;
 import ai.onnxruntime.genai.demo.databinding.ActivityMainBinding;
 import ai.onnxruntime.genai.Model;
@@ -45,7 +51,7 @@ public class MainActivity extends AppCompatActivity implements Consumer<String>
     private TextView progressText;
     private ImageButton settingsButton;
     private static final String TAG = "genai.demo.MainActivity";
-    private int maxLength = 100;
+    private int maxLength = 256;
     private float lengthPenalty = 1.0f;

     private static boolean fileExists(Context context, String fileName) {
@@ -55,6 +61,14 @@ private static boolean fileExists(Context context, String fileName) {

     @Override
     protected void onCreate(Bundle savedInstanceState) {
+        try {
+            // set ADSP_LIBRARY_PATH, QNN-specific
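+            // The QNN HTP backend looks up its Hexagon skel libraries via ADSP_LIBRARY_PATH,
+            // so point it at the app's native library directory before the model is created.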
+            String adspLibraryPath = getApplicationContext().getApplicationInfo().nativeLibraryDir;
+            android.system.Os.setenv("ADSP_LIBRARY_PATH", adspLibraryPath, true);
+        } catch (ErrnoException e) {
+            throw new RuntimeException(e);
+        }
+
         super.onCreate(savedInstanceState);

         binding = ActivityMainBinding.inflate(getLayoutInflater());
@@ -69,8 +83,8 @@ protected void onCreate(Bundle savedInstanceState) {

         // Trigger the download operation when the application is created
         try {
-            downloadModels(
-                    getApplicationContext());
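+            // Load the model from a local on-device folder instead of downloading it; this assumes
+            // the QNN model files were already copied to this path (e.g. via adb push).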
+            createModelFromPath("/data/local/tmp/phi3.5_qnn_qc/phi3.5-split-qnn-qc");
+            //downloadModels(getApplicationContext());
         } catch (GenAIException e) {
             throw new RuntimeException(e);
         }
@@ -135,17 +149,63 @@ public void run() {
                 GeneratorParams generatorParams = null;
                 Generator generator = null;
                 Sequences encodedPrompt = null;
+                Tensor attentionMask = null, positionIds = null;
                 try {
+                    encodedPrompt = tokenizer.encode(promptQuestion_formatted);
+
                     stream = tokenizer.createStream();

+                    int maxSequenceLength = 128;
+                    int contextLength = 4096;
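+                    // Fixed sizes assumed to match the static shapes the QNN model was compiled with:
+                    // prompts are padded to maxSequenceLength and the attention mask covers the full contextLength.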
+
+                    int[] promptTokens = encodedPrompt.getSequence(0);
+                    int numPromptTokens = promptTokens.length;
+
+                    if (numPromptTokens > maxSequenceLength) {
+                        throw new RuntimeException("numPromptTokens is greater than maxSequenceLength");
+                    }
+                    if (numPromptTokens > contextLength) {
+                        throw new RuntimeException("numPromptTokens is greater than contextLength");
+                    }
+
+                    int paddingSize = maxSequenceLength - numPromptTokens;
+
+                    // paddedInputIds
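+                    // Left-pad the prompt to the fixed sequence length; token id 0 is assumed to be the pad token.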
+                    int[] paddedInputIds = new int[maxSequenceLength];
+                    for (int i = 0; i < maxSequenceLength; ++i) {
+                        paddedInputIds[i] = i < paddingSize ? 0 : promptTokens[i - paddingSize];
+                    }
+
+                    ByteOrder nativeOrder = ByteOrder.nativeOrder();
+
+                    // attentionMask
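+                    // float32 mask over the full context (4 bytes per element): 0 for padded positions, 1 for prompt tokens.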
+                    int attentionMaskPaddingSize = contextLength - numPromptTokens;
+                    ByteBuffer attentionMaskBuffer = ByteBuffer.allocateDirect(contextLength * 4);
+                    attentionMaskBuffer.order(nativeOrder);
+                    FloatBuffer attentionMaskFloatBuffer = attentionMaskBuffer.asFloatBuffer();
+                    for (int i = 0; i < contextLength; i++) {
+                        attentionMaskFloatBuffer.put(i < attentionMaskPaddingSize ? 0.0f : 1.0f);
+                    }
+                    attentionMask = new Tensor(attentionMaskBuffer, new long[]{1, contextLength}, Tensor.ElementType.float32);
+
+                    // positionIds
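+                    // int64 position ids (8 bytes each): padded slots get 0, real tokens count up from 0.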
+                    ByteBuffer positionIdsBuffer = ByteBuffer.allocateDirect(maxSequenceLength * 8);
+                    positionIdsBuffer.order(nativeOrder);
+                    LongBuffer positionIdsLongBuffer = positionIdsBuffer.asLongBuffer();
+                    for (int i = 0; i < maxSequenceLength; ++i) {
+                        positionIdsLongBuffer.put(i < paddingSize ? 0 : i - paddingSize);
+                    }
+                    positionIds = new Tensor(positionIdsBuffer, new long[]{1, maxSequenceLength}, Tensor.ElementType.int64);
+
                     generatorParams = model.createGeneratorParams();
                     //examples for optional parameters to format AI response
                     // https://onnxruntime.ai/docs/genai/reference/config.html
                     generatorParams.setSearchOption("length_penalty", lengthPenalty);
                     generatorParams.setSearchOption("max_length", maxLength);
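+                    // Bind the precomputed mask and position ids by name; these input names are specific to this converted QNN model.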
+                    generatorParams.setInput("attention_mask_before_processor", attentionMask);
+                    generatorParams.setInput("position_ids", positionIds);

-                    encodedPrompt = tokenizer.encode(promptQuestion_formatted);
-                    generatorParams.setInput(encodedPrompt);
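+                    // Feed the fixed-length, left-padded token ids in place of the variable-length encoded prompt.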
+                    generatorParams.setInput(paddedInputIds, maxSequenceLength, 1);

                     generator = new Generator(model, generatorParams);

@@ -175,7 +235,7 @@ public void run() {
                     long totalTime = System.currentTimeMillis() - firstTokenTime;

                     float promptProcessingTime = (firstTokenTime - startTime) / 1000.0f;
-                    float tokensPerSecond = (1000 * (numTokens - 1)) / totalTime;
+                    float tokensPerSecond = (1000.0f * (numTokens - 1)) / totalTime;

                     runOnUiThread(() -> {
                         sendMsgIB.setEnabled(true);
@@ -192,6 +252,8 @@ public void run() {
                     Log.e(TAG, "Exception occurred during model query: " + e.getMessage());
                 }
                 finally {
+                    if (positionIds != null) positionIds.close();
+                    if (attentionMask != null) attentionMask.close();
                     if (generator != null) generator.close();
                     if (encodedPrompt != null) encodedPrompt.close();
                     if (stream != null) stream.close();
@@ -217,8 +279,12 @@ protected void onDestroy() {
         super.onDestroy();
     }

-    private void downloadModels(Context context) throws GenAIException {
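+    // Loads the model and tokenizer directly from a folder already on the device,
+    // bypassing the Hugging Face download flow below.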
+    private void createModelFromPath(String path) throws GenAIException {
+        model = new Model(path);
+        tokenizer = model.createTokenizer();
+    }

+    private void downloadModels(Context context) throws GenAIException {
         final String baseUrl = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/";
         List<String> files = Arrays.asList(
                 "added_tokens.json",