66 * LICENSE file in the root directory of this source tree.
77 */
88
9+ import ExecuTorchLLM
910import SwiftUI
1011import UniformTypeIdentifiers
1112
12- import LLaMARunner
13-
1413class RunnerHolder : ObservableObject {
15- var llamaRunner : LLaMARunner ?
16- var llavaRunner : LLaVARunner ?
14+ var textRunner : TextRunner ?
15+ var multimodalRunner : MultimodalRunner ?
1716}
1817
1918extension UIImage {
@@ -347,15 +346,34 @@ struct ContentView: View {
347346
348347 switch modelType {
349348 case . llama, . qwen3, . phi4:
350- runnerHolder. llamaRunner = runnerHolder. llamaRunner ?? LLaMARunner ( modelPath: modelPath, tokenizerPath: tokenizerPath)
349+ runnerHolder. textRunner = runnerHolder. textRunner ?? TextRunner (
350+ modelPath: modelPath,
351+ tokenizerPath: tokenizerPath,
352+ specialTokens: [
353+ " <|begin_of_text|> " ,
354+ " <|end_of_text|> " ,
355+ " <|reserved_special_token_0|> " ,
356+ " <|reserved_special_token_1|> " ,
357+ " <|finetune_right_pad_id|> " ,
358+ " <|step_id|> " ,
359+ " <|start_header_id|> " ,
360+ " <|end_header_id|> " ,
361+ " <|eom_id|> " ,
362+ " <|eot_id|> " ,
363+ " <|python_tag|> "
364+ ] + ( 2 ..< 256 ) . map { " <|reserved_special_token_ \( $0) |> " }
365+ )
351366 case . llava:
352- runnerHolder. llavaRunner = runnerHolder. llavaRunner ?? LLaVARunner ( modelPath: modelPath, tokenizerPath: tokenizerPath)
367+ runnerHolder. multimodalRunner = runnerHolder. multimodalRunner ?? MultimodalRunner (
368+ modelPath: modelPath,
369+ tokenizerPath: tokenizerPath
370+ )
353371 }
354372
355373 guard !shouldStopGenerating else { return }
356374 switch modelType {
357375 case . llama, . qwen3, . phi4:
358- if let runner = runnerHolder. llamaRunner , !runner. isLoaded ( ) {
376+ if let runner = runnerHolder. textRunner , !runner. isLoaded ( ) {
359377 var error : Error ?
360378 let startLoadTime = Date ( )
361379 do {
@@ -385,7 +403,7 @@ struct ContentView: View {
385403 }
386404 }
387405 case . llava:
388- if let runner = runnerHolder. llavaRunner , !runner. isLoaded ( ) {
406+ if let runner = runnerHolder. multimodalRunner , !runner. isLoaded ( ) {
389407 var error : Error ?
390408 let startLoadTime = Date ( )
391409 do {
@@ -426,25 +444,21 @@ struct ContentView: View {
426444 }
427445 do {
428446 var tokens : [ String ] = [ ]
429- var rgbArray : [ UInt8 ] ?
430- let MAX_WIDTH = 336.0
431- var newHeight = 0.0
432- var imageBuffer : UnsafeMutableRawPointer ?
433447
434448 if let img = selectedImage {
435449 let llava_prompt = " \( text) ASSISTANT "
436-
437- newHeight = MAX_WIDTH * img. size. height / img. size. width
450+ let MAX_WIDTH = 336.0
451+ let newHeight = MAX_WIDTH * img. size. height / img. size. width
438452 let resizedImage = img. resized ( to: CGSize ( width: MAX_WIDTH, height: newHeight) )
439- rgbArray = resizedImage. toRGBArray ( )
440- imageBuffer = UnsafeMutableRawPointer ( mutating: rgbArray)
441-
442- try runnerHolder. llavaRunner? . generate ( imageBuffer!, width: MAX_WIDTH, height: newHeight, prompt: llava_prompt, sequenceLength: seq_len) { token in
443453
454+ try runnerHolder. multimodalRunner? . generate ( [
455+ MultimodalInput ( Image ( data: Data ( resizedImage. toRGBArray ( ) ?? [ ] ) , width: Int ( MAX_WIDTH) , height: Int ( newHeight. rounded ( ) ) , channels: 3 ) ) ,
456+ MultimodalInput ( llava_prompt) ,
457+ ] , sequenceLength: seq_len) { token in
444458 if token != llava_prompt {
445459 if token == " </s> " {
446460 shouldStopGenerating = true
447- runnerHolder. llavaRunner ? . stop ( )
461+ runnerHolder. multimodalRunner ? . stop ( )
448462 } else {
449463 tokens. append ( token)
450464 if tokens. count > 2 {
@@ -460,7 +474,7 @@ struct ContentView: View {
460474 }
461475 }
462476 if shouldStopGenerating {
463- runnerHolder. llavaRunner ? . stop ( )
477+ runnerHolder. multimodalRunner ? . stop ( )
464478 }
465479 }
466480 }
@@ -481,7 +495,7 @@ struct ContentView: View {
481495 prompt = String ( format: Constants . phi4PromptTemplate, text)
482496 }
483497
484- try runnerHolder. llamaRunner ? . generate ( prompt, sequenceLength: seq_len) { token in
498+ try runnerHolder. textRunner ? . generate ( prompt, sequenceLength: seq_len) { token in
485499
486500 if token != prompt {
487501 if token == " <|eot_id|> " {
@@ -534,7 +548,7 @@ struct ContentView: View {
534548 }
535549 }
536550 if shouldStopGenerating {
537- runnerHolder. llamaRunner ? . stop ( )
551+ runnerHolder. textRunner ? . stop ( )
538552 }
539553 }
540554 }
@@ -577,8 +591,8 @@ struct ContentView: View {
577591 return
578592 }
579593 runnerQueue. async {
580- runnerHolder. llamaRunner = nil
581- runnerHolder. llavaRunner = nil
594+ runnerHolder. textRunner = nil
595+ runnerHolder. multimodalRunner = nil
582596 }
583597 switch pickerType {
584598 case . model:
0 commit comments