|
6 | 6 | * LICENSE file in the root directory of this source tree. |
7 | 7 | */ |
8 | 8 |
|
| 9 | +import ExecuTorch |
9 | 10 | import ExecuTorchLLM |
10 | 11 | import XCTest |
11 | 12 |
|
@@ -55,12 +56,68 @@ extension UIImage { |
55 | 56 | } |
56 | 57 |
|
57 | 58 | class MultimodalRunnerTest: XCTestCase { |
58 | | - let systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: " |
59 | | - let assistantPrompt = "ASSISTANT: " |
60 | | - let userPrompt = "What's on the picture?" |
61 | | - let sequenceLength = 768 |
| 59 | + let systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." |
| 60 | + |
// Verifies Gemma 3 multimodal generation end-to-end: the model must describe
// the bundled waterfall photo, both on a fresh runner and again after reset().
func testGemma() {
  let chatTemplate = "<start_of_turn>user\n%@<end_of_turn>\n<start_of_turn>model"
  let userPrompt = "What's on the picture?"
  let sideSize: CGFloat = 896
  let sequenceLength = 768
  let bundle = Bundle(for: type(of: self))
  guard let modelPath = bundle.path(forResource: "gemma3", ofType: "pte"),
        let tokenizerPath = bundle.path(forResource: "gemma3_tokenizer", ofType: "model"),
        let imagePath = bundle.path(forResource: "IMG_0005", ofType: "jpg"),
        let uiImage = UIImage(contentsOfFile: imagePath) else {
    XCTFail("Couldn't find model or tokenizer files")
    return
  }
  let runner = MultimodalRunner(modelPath: modelPath, tokenizerPath: tokenizerPath)

  // Runs one full generation pass and returns the accumulated token text.
  // Stops the runner as soon as the model emits its end-of-turn marker.
  // Extracted to remove the verbatim copy-paste of this block in the original.
  func generateText() -> String {
    var text = ""
    do {
      try runner.generate([
        MultimodalInput(systemPrompt),
        MultimodalInput(uiImage.asNormalizedImage(sideSize)),
        MultimodalInput(String(format: chatTemplate, userPrompt)),
      ], Config {
        $0.sequenceLength = sequenceLength
      }) { token in
        text += token
        if token == "<end_of_turn>" {
          runner.stop()
        }
      }
    } catch {
      XCTFail("Failed to generate text with error \(error)")
    }
    return text
  }

  // First pass: fresh runner state.
  XCTAssertTrue(generateText().lowercased().contains("waterfall"))

  // Second pass after reset(): proves the runner is reusable for a new turn.
  runner.reset()
  XCTAssertTrue(generateText().lowercased().contains("waterfall"))
}
62 | 115 |
|
63 | 116 | func testLLaVA() { |
| 117 | + let chatTemplate = "USER: %@ ASSISTANT: " |
| 118 | + let userPrompt = "What's on the picture?" |
| 119 | + let sideSize: CGFloat = 336 |
| 120 | + let sequenceLength = 768 |
64 | 121 | let bundle = Bundle(for: type(of: self)) |
65 | 122 | guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"), |
66 | 123 | let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin"), |
@@ -104,4 +161,47 @@ class MultimodalRunnerTest: XCTestCase { |
104 | 161 | } |
105 | 162 | XCTAssertTrue(text.lowercased().contains("waterfall")) |
106 | 163 | } |
| 164 | + |
// Verifies Voxtral audio understanding end-to-end: feeds precomputed audio
// features from the bundle and expects the transcript summary to mention
// the word "tattoo".
func testVoxtral() throws {
  let chatTemplate = "%@[/INST]"
  let userPrompt = "What is the audio about?"
  let bundle = Bundle(for: type(of: self))
  guard let modelPath = bundle.path(forResource: "voxtral", ofType: "pte"),
        let tokenizerPath = bundle.path(forResource: "voxtral_tokenizer_tekken", ofType: "json"),
        let audioPath = bundle.path(forResource: "voxtral_input_features", ofType: "bin") else {
    XCTFail("Couldn't find model or tokenizer files")
    return
  }
  let runner = MultimodalRunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
  // Memory-map the feature file rather than copying it; `let` because the
  // data is never mutated (the original `var` drew a compiler warning).
  let audioData = try Data(contentsOf: URL(fileURLWithPath: audioPath), options: .mappedIfSafe)
  let floatSize = MemoryLayout<Float>.size
  guard audioData.count % floatSize == 0 else {
    XCTFail("Invalid audio data")
    return
  }
  // NOTE(review): bins/frames presumably match the export-time mel-spectrogram
  // shape of the .pte — confirm against the model export script.
  let bins = 128
  let frames = 3000
  let batchSize = audioData.count / floatSize / (bins * frames)
  // Fail loudly if the file holds less than one full batch; integer division
  // would otherwise yield batchSize == 0 and feed a degenerate Audio input.
  guard batchSize > 0 else {
    XCTFail("Audio data smaller than one batch of \(bins)x\(frames) floats")
    return
  }
  var text = ""

  do {
    try runner.generate([
      MultimodalInput("<s>[INST][BEGIN_AUDIO]"),
      MultimodalInput(Audio(
        float: audioData,
        batchSize: batchSize,
        bins: bins,
        frames: frames
      )),
      MultimodalInput(String(format: chatTemplate, userPrompt)),
    ], Config {
      $0.maximumNewTokens = 256
    }) { token in
      text += token
    }
  } catch {
    XCTFail("Failed to generate text with error \(error)")
  }
  XCTAssertTrue(text.lowercased().contains("tattoo"))
}
107 | 207 | } |
0 commit comments