Skip to content

Commit 5f8a086

Browse files
authored
whisper.swiftui : add model download list & bench methods (#2546)
* swift : fix resources & exclude build
* whisper : impl whisper_timings struct & api
* whisper.swiftui : model list & bench methods
* whisper : return ptr for whisper_get_timings
* revert unnecessary change
* whisper : avoid designated initializer
* whisper.swiftui: code style changes
* whisper.swiftui : get device name / os from UIDevice
* whisper.swiftui : fix UIDevice usage
* whisper.swiftui : add memcpy and ggml_mul_mat (commented)
1 parent a28d82e commit 5f8a086

File tree

9 files changed

+403
-19
lines changed

9 files changed

+403
-19
lines changed

Package.swift

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,17 @@ let package = Package(
1818
name: "whisper",
1919
path: ".",
2020
exclude: [
21+
"build",
2122
"bindings",
2223
"cmake",
23-
"coreml",
2424
"examples",
25-
"extra",
25+
"scripts",
2626
"models",
2727
"samples",
2828
"tests",
2929
"CMakeLists.txt",
30-
"Makefile"
30+
"Makefile",
31+
"ggml/src/ggml-metal-embed.metal"
3132
],
3233
sources: [
3334
"ggml/src/ggml.c",
@@ -38,7 +39,7 @@ let package = Package(
3839
"ggml/src/ggml-quants.c",
3940
"ggml/src/ggml-metal.m"
4041
],
41-
resources: [.process("ggml-metal.metal")],
42+
resources: [.process("ggml/src/ggml-metal.metal")],
4243
publicHeadersPath: "spm-headers",
4344
cSettings: [
4445
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),

examples/whisper.swiftui/whisper.cpp.swift/LibWhisper.swift

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import Foundation
2+
import UIKit
23
import whisper
34

45
enum WhisperError: Error {
@@ -55,11 +56,93 @@ actor WhisperContext {
5556
return transcription
5657
}
5758

59+
// Runs the whisper.cpp memcpy benchmark and returns its textual report.
// - Parameter nThreads: number of worker threads for the benchmark.
static func benchMemcpy(nThreads: Int32) async -> String {
    String(cString: whisper_bench_memcpy_str(nThreads))
}
62+
63+
// Runs the ggml matrix-multiplication benchmark and returns its textual report.
// - Parameter nThreads: number of worker threads for the benchmark.
static func benchGgmlMulMat(nThreads: Int32) async -> String {
    String(cString: whisper_bench_ggml_mul_mat_str(nThreads))
}
66+
67+
// Space-separated list of the CPU/GPU features ggml was built with
// (e.g. "NEON METAL BLAS"); empty string when none are available.
private func systemInfo() -> String {
    var features: [String] = []
    if ggml_cpu_has_neon()  != 0 { features.append("NEON") }
    if ggml_cpu_has_metal() != 0 { features.append("METAL") }
    if ggml_cpu_has_blas()  != 0 { features.append("BLAS") }
    return features.joined(separator: " ")
}
74+
75+
// Runs a full encoder/decoder benchmark against the loaded model and
// returns a single Markdown table row with the measured timings.
// - Parameters:
//   - modelName: label used for the "Model" column of the row.
//   - nThreads: number of threads passed to every encode/decode call.
func benchFull(modelName: String, nThreads: Int32) async -> String {
    let melCount = whisper_model_n_mels(context)
    if whisper_set_mel(context, nil, 0, melCount) != 0 {
        return "error: failed to set mel"
    }

    // Warm up the encoder once before measuring anything.
    if whisper_encode(context, 0, nThreads) != 0 {
        return "error: failed to encode"
    }

    var tokenBuffer = [whisper_token](repeating: 0, count: 512)

    // Warm up prompt processing (256 tokens at offset 0).
    if whisper_decode(context, &tokenBuffer, 256, 0, nThreads) != 0 {
        return "error: failed to decode"
    }

    // Warm up single-token text generation.
    if whisper_decode(context, &tokenBuffer, 1, 256, nThreads) != 0 {
        return "error: failed to decode"
    }

    whisper_reset_timings(context)

    // Timed encoder run.
    if whisper_encode(context, 0, nThreads) != 0 {
        return "error: failed to encode"
    }

    // Timed text generation: 256 single-token decode steps.
    for step in 0..<256 {
        if whisper_decode(context, &tokenBuffer, 1, Int32(step), nThreads) != 0 {
            return "error: failed to decode"
        }
    }

    // Timed batched decoding: 64 runs of 5 tokens each.
    for _ in 0..<64 {
        if whisper_decode(context, &tokenBuffer, 5, 0, nThreads) != 0 {
            return "error: failed to decode"
        }
    }

    // Timed prompt processing: 16 runs of 256 tokens each.
    for _ in 0..<16 {
        if whisper_decode(context, &tokenBuffer, 256, 0, nThreads) != 0 {
            return "error: failed to decode"
        }
    }

    whisper_print_timings(context)

    let device = await UIDevice.current.model
    let os = await UIDevice.current.systemName
    let config = self.systemInfo()
    let timings: whisper_timings = whisper_get_timings(context).pointee
    let enc = String(format: "%.2f", timings.encode_ms)
    let dec = String(format: "%.2f", timings.decode_ms)
    let bch = String(format: "%.2f", timings.batchd_ms)
    let pp  = String(format: "%.2f", timings.prompt_ms)
    return "| \(device) | \(os) | \(config) | \(modelName) | \(nThreads) | 1 | \(enc) | \(dec) | \(bch) | \(pp) | <todo> |"
}
138+
58139
static func createContext(path: String) throws -> WhisperContext {
59140
var params = whisper_context_default_params()
60141
#if targetEnvironment(simulator)
61142
params.use_gpu = false
62143
print("Running on the simulator, using CPU")
144+
#else
145+
params.flash_attn = true // Enabled by default for Metal
63146
#endif
64147
let context = whisper_init_from_file_with_params(path, params)
65148
if let context {
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import Foundation
2+
3+
// A downloadable whisper model; `Identifiable` via a freshly generated UUID.
struct Model: Identifiable {
    var id = UUID()
    // Display name, human-readable size info, and remote download URL.
    var name: String
    var info: String
    var url: String

    // File name used for the local copy inside the Documents directory.
    var filename: String

    // Destination of the downloaded model on disk.
    var fileURL: URL {
        let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0]
        return documents.appendingPathComponent(filename)
    }

    // True when the model file has already been downloaded.
    func fileExists() -> Bool {
        return FileManager.default.fileExists(atPath: fileURL.path)
    }
}

examples/whisper.swiftui/whisper.swiftui.demo/Models/WhisperState.swift

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
1414
private var recordedFile: URL? = nil
1515
private var audioPlayer: AVAudioPlayer?
1616

17-
private var modelUrl: URL? {
17+
// URL of the "ggml-base.en" model bundled with the app, if it was packaged.
private var builtInModelUrl: URL? {
    Bundle.main.url(forResource: "ggml-base.en", withExtension: "bin", subdirectory: "models")
}
2020

@@ -28,23 +28,59 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
2828

2929
// Loads the built-in model as soon as the state object is created.
override init() {
    super.init()
    loadModel()
}
33+
34+
// Replaces the current whisper context with the model at `path`, or with the
// bundled model when `path` is nil. Logging can be silenced (`log: false`)
// for batch operations such as benchmarking.
func loadModel(path: URL? = nil, log: Bool = true) {
    do {
        whisperContext = nil
        if log { messageLog += "Loading model...\n" }
        let modelUrl = path ?? builtInModelUrl
        if let modelUrl {
            whisperContext = try WhisperContext.createContext(path: modelUrl.path())
            if log { messageLog += "Loaded model \(modelUrl.lastPathComponent)\n" }
        } else {
            if log { messageLog += "Could not locate model\n" }
        }
        // NOTE(review): canTranscribe becomes true even when no model was
        // located (whisperContext stays nil) — confirm downstream code
        // tolerates a nil context before changing this.
        canTranscribe = true
    } catch {
        print(error.localizedDescription)
        if log { messageLog += "\(error.localizedDescription)\n" }
    }
}
39-
40-
private func loadModel() throws {
41-
messageLog += "Loading model...\n"
42-
if let modelUrl {
43-
whisperContext = try WhisperContext.createContext(path: modelUrl.path())
44-
messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
45-
} else {
46-
messageLog += "Could not locate model\n"
51+
52+
// Benchmarks the currently loaded model, appending the result row to the log.
// Logs an error and returns early when no model is loaded.
func benchCurrentModel() async {
    if whisperContext == nil {
        messageLog += "Cannot bench without loaded model\n"
        return
    }
    messageLog += "Running benchmark for loaded model\n"
    // `if let` instead of force-unwrapping the optional result.
    if let result = await whisperContext?.benchFull(modelName: "<current>", nThreads: Int32(min(4, cpuCount()))) {
        messageLog += result + "\n"
    }
}
61+
62+
// Benchmarks every model in `models`, appending a Markdown results table to
// the log. Stops early if a model fails to load.
func bench(models: [Model]) async {
    let nThreads = Int32(min(4, cpuCount()))

    // Raw memcpy / ggml_mul_mat throughput benchmarks (intentionally disabled):
    // messageLog += "Running memcpy benchmark\n"
    // messageLog += await WhisperContext.benchMemcpy(nThreads: nThreads) + "\n"
    //
    // messageLog += "Running ggml_mul_mat benchmark with \(nThreads) threads\n"
    // messageLog += await WhisperContext.benchGgmlMulMat(nThreads: nThreads) + "\n"

    messageLog += "Running benchmark for all downloaded models\n"
    messageLog += "| CPU | OS | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |\n"
    messageLog += "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n"
    for model in models {
        loadModel(path: model.fileURL, log: false)
        if whisperContext == nil {
            messageLog += "Cannot bench without loaded model\n"
            break
        }
        // `if let` instead of force-unwrapping the optional result.
        if let result = await whisperContext?.benchFull(modelName: model.name, nThreads: nThreads) {
            messageLog += result + "\n"
        }
    }
    messageLog += "Benchmarking completed\n"
}
4985

5086
func transcribeSample() async {
@@ -160,3 +196,8 @@ class WhisperState: NSObject, ObservableObject, AVAudioRecorderDelegate {
160196
isRecording = false
161197
}
162198
}
199+
200+
201+
// Number of logical CPU cores available to the process.
fileprivate func cpuCount() -> Int {
    return ProcessInfo.processInfo.processorCount
}

examples/whisper.swiftui/whisper.swiftui.demo/UI/ContentView.swift

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import SwiftUI
22
import AVFoundation
3+
import Foundation
34

45
struct ContentView: View {
56
@StateObject var whisperState = WhisperState()
@@ -29,15 +30,125 @@ struct ContentView: View {
2930
Text(verbatim: whisperState.messageLog)
3031
.frame(maxWidth: .infinity, alignment: .leading)
3132
}
33+
.font(.footnote)
34+
.padding()
35+
.background(Color.gray.opacity(0.1))
36+
.cornerRadius(10)
37+
38+
HStack {
39+
Button("Clear Logs", action: {
40+
whisperState.messageLog = ""
41+
})
42+
.font(.footnote)
43+
.buttonStyle(.bordered)
44+
45+
Button("Copy Logs", action: {
46+
UIPasteboard.general.string = whisperState.messageLog
47+
})
48+
.font(.footnote)
49+
.buttonStyle(.bordered)
50+
51+
Button("Bench", action: {
52+
Task {
53+
await whisperState.benchCurrentModel()
54+
}
55+
})
56+
.font(.footnote)
57+
.buttonStyle(.bordered)
58+
.disabled(!whisperState.canTranscribe)
59+
60+
Button("Bench All", action: {
61+
Task {
62+
await whisperState.bench(models: ModelsView.getDownloadedModels())
63+
}
64+
})
65+
.font(.footnote)
66+
.buttonStyle(.bordered)
67+
.disabled(!whisperState.canTranscribe)
68+
}
69+
70+
NavigationLink(destination: ModelsView(whisperState: whisperState)) {
71+
Text("View Models")
72+
}
73+
.font(.footnote)
74+
.padding()
3275
}
3376
.navigationTitle("Whisper SwiftUI Demo")
3477
.padding()
3578
}
3679
}
37-
}
3880

39-
struct ContentView_Previews: PreviewProvider {
40-
static var previews: some View {
41-
ContentView()
81+
struct ModelsView: View {
82+
@ObservedObject var whisperState: WhisperState
83+
@Environment(\.dismiss) var dismiss
84+
85+
private static let models: [Model] = [
86+
Model(name: "tiny", info: "(F16, 75 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin", filename: "tiny.bin"),
87+
Model(name: "tiny-q5_1", info: "(31 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny-q5_1.bin", filename: "tiny-q5_1.bin"),
88+
Model(name: "tiny-q8_0", info: "(42 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny-q8_0.bin", filename: "tiny-q8_0.bin"),
89+
Model(name: "tiny.en", info: "(F16, 75 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin", filename: "tiny.en.bin"),
90+
Model(name: "tiny.en-q5_1", info: "(31 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q5_1.bin", filename: "tiny.en-q5_1.bin"),
91+
Model(name: "tiny.en-q8_0", info: "(42 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q8_0.bin", filename: "tiny.en-q8_0.bin"),
92+
Model(name: "base", info: "(F16, 142 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin", filename: "base.bin"),
93+
Model(name: "base-q5_1", info: "(57 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base-q5_1.bin", filename: "base-q5_1.bin"),
94+
Model(name: "base-q8_0", info: "(78 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base-q8_0.bin", filename: "base-q8_0.bin"),
95+
Model(name: "base.en", info: "(F16, 142 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin", filename: "base.en.bin"),
96+
Model(name: "base.en-q5_1", info: "(57 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q5_1.bin", filename: "base.en-q5_1.bin"),
97+
Model(name: "base.en-q8_0", info: "(78 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q8_0.bin", filename: "base.en-q8_0.bin"),
98+
Model(name: "small", info: "(F16, 466 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin", filename: "small.bin"),
99+
Model(name: "small-q5_1", info: "(181 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small-q5_1.bin", filename: "small-q5_1.bin"),
100+
Model(name: "small-q8_0", info: "(252 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small-q8_0.bin", filename: "small-q8_0.bin"),
101+
Model(name: "small.en", info: "(F16, 466 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin", filename: "small.en.bin"),
102+
Model(name: "small.en-q5_1", info: "(181 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en-q5_1.bin", filename: "small.en-q5_1.bin"),
103+
Model(name: "small.en-q8_0", info: "(252 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en-q8_0.bin", filename: "small.en-q8_0.bin"),
104+
Model(name: "medium", info: "(F16, 1.5 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.bin", filename: "medium.bin"),
105+
Model(name: "medium-q5_0", info: "(514 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium-q5_0.bin", filename: "medium-q5_0.bin"),
106+
Model(name: "medium-q8_0", info: "(785 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium-q8_0.bin", filename: "medium-q8_0.bin"),
107+
Model(name: "medium.en", info: "(F16, 1.5 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en.bin", filename: "medium.en.bin"),
108+
Model(name: "medium.en-q5_0", info: "(514 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en-q5_0.bin", filename: "medium.en-q5_0.bin"),
109+
Model(name: "medium.en-q8_0", info: "(785 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en-q8_0.bin", filename: "medium.en-q8_0.bin"),
110+
Model(name: "large-v1", info: "(F16, 2.9 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large.bin", filename: "large.bin"),
111+
Model(name: "large-v2", info: "(F16, 2.9 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v2.bin", filename: "large-v2.bin"),
112+
Model(name: "large-v2-q5_0", info: "(1.1 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v2-q5_0.bin", filename: "large-v2-q5_0.bin"),
113+
Model(name: "large-v2-q8_0", info: "(1.5 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v2-q8_0.bin", filename: "large-v2-q8_0.bin"),
114+
Model(name: "large-v3", info: "(F16, 2.9 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3.bin", filename: "large-v3.bin"),
115+
Model(name: "large-v3-q5_0", info: "(1.1 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-q5_0.bin", filename: "large-v3-q5_0.bin"),
116+
Model(name: "large-v3-turbo", info: "(F16, 1.5 GiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo.bin", filename: "large-v3-turbo.bin"),
117+
Model(name: "large-v3-turbo-q5_0", info: "(547 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo-q5_0.bin", filename: "large-v3-turbo-q5_0.bin"),
118+
Model(name: "large-v3-turbo-q8_0", info: "(834 MiB)", url: "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo-q8_0.bin", filename: "large-v3-turbo-q8_0.bin"),
119+
]
120+
121+
// Returns only the models whose files are already present on disk.
static func getDownloadedModels() -> [Model] {
    return models.filter { model in
        FileManager.default.fileExists(atPath: model.fileURL.path())
    }
}
127+
128+
// Dismisses the model list and loads the selected model into the shared state.
func loadModel(model: Model) {
    Task {
        dismiss()
        whisperState.loadModel(path: model.fileURL)
    }
}
134+
135+
// Grouped list of all known models; each row offers download + load.
var body: some View {
    List {
        Section(header: Text("Models")) {
            ForEach(ModelsView.models) { model in
                DownloadButton(model: model)
                    .onLoad(perform: loadModel)
            }
        }
    }
    .listStyle(GroupedListStyle())
    .navigationBarTitle("Models", displayMode: .inline)
    .toolbar {}
}
42147
}
43148
}
149+
150+
//struct ContentView_Previews: PreviewProvider {
151+
// static var previews: some View {
152+
// ContentView()
153+
// }
154+
//}

0 commit comments

Comments (0)