diff --git a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h index 250241b9c9d..b2e36e0a1f2 100644 --- a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h +++ b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h @@ -145,7 +145,7 @@ __attribute__((objc_subclassing_restricted)) @return A retained ExecuTorchLLMMultimodalInput instance of type Audio. */ + (instancetype)inputWithAudio:(ExecuTorchLLMAudio *)audio - NS_SWIFT_NAME(init(audio:)) + NS_SWIFT_NAME(init(_:)) NS_RETURNS_RETAINED; @property(nonatomic, readonly) ExecuTorchLLMMultimodalInputType type; diff --git a/extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift b/extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift index 7281740c3af..3617245b8f8 100644 --- a/extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift +++ b/extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +import ExecuTorch import ExecuTorchLLM import XCTest @@ -98,10 +99,10 @@ extension UIImage { class MultimodalRunnerTest: XCTestCase { let systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." - let userPrompt = "What's on the picture?" func testGemma() { let chatTemplate = "user\n%@\nmodel" + let userPrompt = "What's on the picture?" let sideSize: CGFloat = 896 let sequenceLength = 768 let bundle = Bundle(for: type(of: self)) @@ -156,6 +157,7 @@ class MultimodalRunnerTest: XCTestCase { func testLLaVA() { let chatTemplate = "USER: %@ ASSISTANT: " + let userPrompt = "What's on the picture?" 
let sideSize: CGFloat = 336 let sequenceLength = 768 let bundle = Bundle(for: type(of: self)) @@ -201,4 +203,47 @@ class MultimodalRunnerTest: XCTestCase { } XCTAssertTrue(text.lowercased().contains("waterfall")) } + + func testVoxtral() throws { + let chatTemplate = "%@[/INST]" + let userPrompt = "What is the audio about?" + let bundle = Bundle(for: type(of: self)) + guard let modelPath = bundle.path(forResource: "voxtral", ofType: "pte"), + let tokenizerPath = bundle.path(forResource: "voxtral_tokenizer_tekken", ofType: "json"), + let audioPath = bundle.path(forResource: "voxtral_input_features", ofType: "bin") else { + XCTFail("Couldn't find model or tokenizer files") + return + } + let runner = MultimodalRunner(modelPath: modelPath, tokenizerPath: tokenizerPath) + var audioData = try Data(contentsOf: URL(fileURLWithPath: audioPath), options: .mappedIfSafe) + let floatSize = MemoryLayout<Float>.size + guard audioData.count % floatSize == 0 else { + XCTFail("Invalid audio data") + return + } + let bins = 128 + let frames = 3000 + let batchSize = audioData.count / floatSize / (bins * frames) + var text = "" + + do { + try runner.generate([ + MultimodalInput("[INST][BEGIN_AUDIO]"), + MultimodalInput(Audio( + float: audioData, + batchSize: batchSize, + bins: bins, + frames: frames + )), + MultimodalInput(String(format: chatTemplate, userPrompt)), + ], Config { + $0.maximumNewTokens = 256 + }) { token in + text += token + } + } catch { + XCTFail("Failed to generate text with error \(error)") + } + XCTAssertTrue(text.lowercased().contains("tattoo")) + } }