6 | 6 | * LICENSE file in the root directory of this source tree. |
7 | 7 | */ |
8 | 8 |
| 9 | +import ExecuTorch |
9 | 10 | import ExecuTorchLLM |
10 | 11 | import XCTest |
11 | 12 |
@@ -98,10 +99,10 @@ extension UIImage { |
98 | 99 |
99 | 100 | class MultimodalRunnerTest: XCTestCase { |
100 | 101 | let systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." |
101 | | - let userPrompt = "What's on the picture?" |
102 | 102 |
103 | 103 | func testGemma() { |
104 | 104 | let chatTemplate = "<start_of_turn>user\n%@<end_of_turn>\n<start_of_turn>model" |
| 105 | + let userPrompt = "What's on the picture?" |
105 | 106 | let sideSize: CGFloat = 896 |
106 | 107 | let sequenceLength = 768 |
107 | 108 | let bundle = Bundle(for: type(of: self)) |
@@ -156,6 +157,7 @@ class MultimodalRunnerTest: XCTestCase { |
156 | 157 |
157 | 158 | func testLLaVA() { |
158 | 159 | let chatTemplate = "USER: %@ ASSISTANT: " |
| 160 | + let userPrompt = "What's on the picture?" |
159 | 161 | let sideSize: CGFloat = 336 |
160 | 162 | let sequenceLength = 768 |
161 | 163 | let bundle = Bundle(for: type(of: self)) |
@@ -201,4 +203,47 @@ class MultimodalRunnerTest: XCTestCase { |
201 | 203 | } |
202 | 204 | XCTAssertTrue(text.lowercased().contains("waterfall")) |
203 | 205 | } |
| 206 | + |
| 207 | + func testVoxtral() throws { |
| 208 | + let chatTemplate = "%@[/INST]" |
| 209 | + let userPrompt = "What is the audio about?" |
| 210 | + let bundle = Bundle(for: type(of: self)) |
| 211 | + guard let modelPath = bundle.path(forResource: "voxtral", ofType: "pte"), |
| 212 | + let tokenizerPath = bundle.path(forResource: "voxtral_tokenizer_tekken", ofType: "json"), |
| 213 | + let audioPath = bundle.path(forResource: "voxtral_input_features", ofType: "bin") else { |
| 214 | + XCTFail("Couldn't find model, tokenizer, or audio files")
| 215 | + return |
| 216 | + } |
| 217 | + let runner = MultimodalRunner(modelPath: modelPath, tokenizerPath: tokenizerPath) |
| 218 | + let audioData = try Data(contentsOf: URL(fileURLWithPath: audioPath), options: .mappedIfSafe)
| 219 | + let floatSize = MemoryLayout<Float>.size |
| 220 | + guard audioData.count % floatSize == 0 else { |
| 221 | + XCTFail("Audio data size is not a multiple of Float size")
| 222 | + return |
| 223 | + } |
| 224 | + let bins = 128  // feature bins per frame expected by the audio encoder
| 225 | + let frames = 3000  // frames per preprocessed audio chunk
| 226 | + let batchSize = audioData.count / floatSize / (bins * frames)  // number of (bins * frames) chunks in the raw float buffer
| 227 | + var text = "" |
| 228 | + |
| 229 | + do { |
| 230 | + try runner.generate([ |
| 231 | + MultimodalInput("<s>[INST][BEGIN_AUDIO]"), |
| 232 | + MultimodalInput(Audio( |
| 233 | + float: audioData, |
| 234 | + batchSize: batchSize, |
| 235 | + bins: bins, |
| 236 | + frames: frames |
| 237 | + )), |
| 238 | + MultimodalInput(String(format: chatTemplate, userPrompt)), |
| 239 | + ], Config { |
| 240 | + $0.maximumNewTokens = 256 |
| 241 | + }) { token in |
| 242 | + text += token |
| 243 | + } |
| 244 | + } catch { |
| 245 | + XCTFail("Failed to generate text with error \(error)") |
| 246 | + } |
| 247 | + XCTAssertTrue(text.lowercased().contains("tattoo")) |
| 248 | + } |
204 | 249 | } |