project.pbxproj
@@ -62,6 +62,7 @@
306A71502DC1DC3D00936B1F /* regex.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A71492DC1DC3D00936B1F /* regex.cpp */; };
306A71512DC1DC3D00936B1F /* pre_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A71472DC1DC3D00936B1F /* pre_tokenizer.cpp */; };
306A71522DC1DC3D00936B1F /* token_decoder.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A714B2DC1DC3D00936B1F /* token_decoder.cpp */; };
+ 3072D5232DC3EA280083FC83 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3072D5222DC3EA280083FC83 /* Constants.swift */; };
F292B0752D88B0C200BE6839 /* tiktoken.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06F2D88B0C200BE6839 /* tiktoken.cpp */; };
F292B0762D88B0C200BE6839 /* llama2c_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06C2D88B0C200BE6839 /* llama2c_tokenizer.cpp */; };
F292B0772D88B0C200BE6839 /* bpe_tokenizer_base.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */; };
@@ -147,6 +148,7 @@
306A71492DC1DC3D00936B1F /* regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = regex.cpp; path = src/regex.cpp; sourceTree = "<group>"; };
306A714A2DC1DC3D00936B1F /* std_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = std_regex.cpp; path = src/std_regex.cpp; sourceTree = "<group>"; };
306A714B2DC1DC3D00936B1F /* token_decoder.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = token_decoder.cpp; path = src/token_decoder.cpp; sourceTree = "<group>"; };
+ 3072D5222DC3EA280083FC83 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = "<group>"; };
F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = bpe_tokenizer_base.cpp; path = src/bpe_tokenizer_base.cpp; sourceTree = "<group>"; };
F292B06C2D88B0C200BE6839 /* llama2c_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = llama2c_tokenizer.cpp; path = src/llama2c_tokenizer.cpp; sourceTree = "<group>"; };
F292B06F2D88B0C200BE6839 /* tiktoken.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = tiktoken.cpp; path = src/tiktoken.cpp; sourceTree = "<group>"; };
@@ -208,6 +210,7 @@
0324D6892BAACB6900DEF36F /* Application */ = {
isa = PBXGroup;
children = (
+ 3072D5222DC3EA280083FC83 /* Constants.swift */,
0324D6802BAACB6900DEF36F /* App.swift */,
0324D6812BAACB6900DEF36F /* ContentView.swift */,
0324D6822BAACB6900DEF36F /* LogManager.swift */,
@@ -554,6 +557,7 @@
buildActionMask = 2147483647;
files = (
0324D6932BAACB6900DEF36F /* ResourceMonitor.swift in Sources */,
+ 3072D5232DC3EA280083FC83 /* Constants.swift in Sources */,
0324D68D2BAACB6900DEF36F /* LogManager.swift in Sources */,
0324D68E2BAACB6900DEF36F /* LogView.swift in Sources */,
0324D68F2BAACB6900DEF36F /* Message.swift in Sources */,
ContentView.swift
@@ -80,6 +80,25 @@ struct ContentView: View {
case tokenizer
}

+ enum ModelType {
+   case llama
+   case llava
+   case qwen3
+
+   static func fromPath(_ path: String) -> ModelType {
+     let filename = (path as NSString).lastPathComponent.lowercased()
+     if filename.hasPrefix("llama") {
+       return .llama
+     } else if filename.hasPrefix("llava") {
+       return .llava
+     } else if filename.hasPrefix("qwen3") {
+       return .qwen3
+     }
+     print("Unknown model type in path: \(path). Model filename should start with one of: llama, llava, or qwen3")
+     exit(1)
+   }
+ }
+
private var placeholder: String {
resourceManager.isModelValid ? resourceManager.isTokenizerValid ? "Prompt..." : "Select Tokenizer..." : "Select Model..."
}
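
For reference, a quick usage sketch of the new ModelType.fromPath: classification keys off the file-name prefix only, not the directory. The paths below are illustrative, not from this PR:

let a = ModelType.fromPath("/tmp/models/llama3_2-1b.pte")  // .llama
let b = ModelType.fromPath("/tmp/models/qwen3-0_6b.pte")   // .qwen3
// Any other filename prints the expected prefixes and calls exit(1).
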
@@ -275,14 +294,14 @@ struct ContentView: View {
let seq_len = 768 // text: 256, vision: 768
let modelPath = resourceManager.modelPath
let tokenizerPath = resourceManager.tokenizerPath
- let useLlama = modelPath.lowercased().contains("llama")
+ let modelType = ModelType.fromPath(modelPath)

prompt = ""
hideKeyboard()
showingSettings = false

messages.append(Message(text: text))
- messages.append(Message(type: useLlama ? .llamagenerated : .llavagenerated))
+ messages.append(Message(type: modelType == .llama ? .llamagenerated : .llavagenerated))

runnerQueue.async {
defer {
@@ -292,14 +311,16 @@
}
}

- if useLlama {
+ switch modelType {
+ case .llama, .qwen3:
runnerHolder.runner = runnerHolder.runner ?? Runner(modelPath: modelPath, tokenizerPath: tokenizerPath)
- } else {
+ case .llava:
runnerHolder.llavaRunner = runnerHolder.llavaRunner ?? LLaVARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
}

guard !shouldStopGenerating else { return }
- if useLlama {
+ switch modelType {
+ case .llama, .qwen3:
if let runner = runnerHolder.runner, !runner.isLoaded() {
var error: Error?
let startLoadTime = Date()
@@ -329,7 +350,7 @@
return
}
}
- } else {
+ case .llava:
if let runner = runnerHolder.llavaRunner, !runner.isLoaded() {
var error: Error?
let startLoadTime = Date()
@@ -411,12 +432,19 @@
}
}
} else {
- let llama3_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\(text)<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
+ let prompt: String
+ switch modelType {
+ case .qwen3:
+   prompt = String(format: Constants.qwen3PromptTemplate, text)
+ case .llama:
+   prompt = String(format: Constants.llama3PromptTemplate, text)
+ case .llava:
+   prompt = String(format: Constants.llama3PromptTemplate, text)
+ }

- try runnerHolder.runner?.generate(llama3_prompt, sequenceLength: seq_len) { token in
+ try runnerHolder.runner?.generate(prompt, sequenceLength: seq_len) { token in

NSLog(">>> token={\(token)}")
- if token != llama3_prompt {
+ if token != prompt {
// hack to fix the issue that extension/llm/runner/text_token_generator.h
// keeps generating after <|eot_id|>
if token == "<|eot_id|>" {
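
The new Constants.swift is added to the project above but its body is not shown in this diff. A plausible sketch, inferred from the String(format:) call sites: the llama3 template mirrors the removed inline llama3_prompt literal with %@ as the user-text slot, and the qwen3 template is assumed to follow Qwen's standard chat format — verify both against the shipped file:

// Constants.swift — hypothetical reconstruction; the actual contents are not part of this diff.
public enum Constants {
  // Mirrors the removed inline literal, with %@ where \(text) was interpolated.
  public static let llama3PromptTemplate = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>%@<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
  // Assumed from Qwen's standard chat template.
  public static let qwen3PromptTemplate = "<|im_start|>user\n%@<|im_end|>\n<|im_start|>assistant\n"
}

Centralizing the templates means supporting a new model family only requires a new template constant plus a ModelType case, instead of another inline string in ContentView.
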
Message.swift
@@ -10,7 +10,7 @@ import UIKit

enum MessageType {
case prompted
- case llamagenerated
+ case llamagenerated // TODO: change this to something more general, like "textgenerated".
case llavagenerated
case info
}
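
The rename the TODO hints at might look like this (hypothetical, not part of this PR):

enum MessageType {
  case prompted
  case textgenerated // one case for any text-only model: llama, qwen3, ...
  case llavagenerated
  case info
}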