
Commit 7b55344

Switch between model types based on model name
1 parent 0de4f59 commit 7b55344

3 files changed: +50 −18 lines changed


examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj

Lines changed: 11 additions & 7 deletions
@@ -62,6 +62,7 @@
 306A71502DC1DC3D00936B1F /* regex.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A71492DC1DC3D00936B1F /* regex.cpp */; };
 306A71512DC1DC3D00936B1F /* pre_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A71472DC1DC3D00936B1F /* pre_tokenizer.cpp */; };
 306A71522DC1DC3D00936B1F /* token_decoder.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A714B2DC1DC3D00936B1F /* token_decoder.cpp */; };
+3072D5232DC3EA280083FC83 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3072D5222DC3EA280083FC83 /* Constants.swift */; };
 F292B0752D88B0C200BE6839 /* tiktoken.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06F2D88B0C200BE6839 /* tiktoken.cpp */; };
 F292B0762D88B0C200BE6839 /* llama2c_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06C2D88B0C200BE6839 /* llama2c_tokenizer.cpp */; };
 F292B0772D88B0C200BE6839 /* bpe_tokenizer_base.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */; };
@@ -140,13 +141,14 @@
 306A713A2DC1DC0F00936B1F /* std_regex.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = std_regex.h; sourceTree = "<group>"; };
 306A713B2DC1DC0F00936B1F /* string_integer_map.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = string_integer_map.h; sourceTree = "<group>"; };
 306A713C2DC1DC0F00936B1F /* token_decoder.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = token_decoder.h; sourceTree = "<group>"; };
-306A71452DC1DC3D00936B1F /* hf_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = hf_tokenizer.cpp; sourceTree = "<group>"; };
-306A71462DC1DC3D00936B1F /* pcre2_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = pcre2_regex.cpp; sourceTree = "<group>"; };
-306A71472DC1DC3D00936B1F /* pre_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = pre_tokenizer.cpp; sourceTree = "<group>"; };
-306A71482DC1DC3D00936B1F /* re2_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = re2_regex.cpp; sourceTree = "<group>"; };
-306A71492DC1DC3D00936B1F /* regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = regex.cpp; sourceTree = "<group>"; };
-306A714A2DC1DC3D00936B1F /* std_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = std_regex.cpp; sourceTree = "<group>"; };
-306A714B2DC1DC3D00936B1F /* token_decoder.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = token_decoder.cpp; sourceTree = "<group>"; };
+306A71452DC1DC3D00936B1F /* hf_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = hf_tokenizer.cpp; path = src/hf_tokenizer.cpp; sourceTree = "<group>"; };
+306A71462DC1DC3D00936B1F /* pcre2_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = pcre2_regex.cpp; path = src/pcre2_regex.cpp; sourceTree = "<group>"; };
+306A71472DC1DC3D00936B1F /* pre_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = pre_tokenizer.cpp; path = src/pre_tokenizer.cpp; sourceTree = "<group>"; };
+306A71482DC1DC3D00936B1F /* re2_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = re2_regex.cpp; path = src/re2_regex.cpp; sourceTree = "<group>"; };
+306A71492DC1DC3D00936B1F /* regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = regex.cpp; path = src/regex.cpp; sourceTree = "<group>"; };
+306A714A2DC1DC3D00936B1F /* std_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = std_regex.cpp; path = src/std_regex.cpp; sourceTree = "<group>"; };
+306A714B2DC1DC3D00936B1F /* token_decoder.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = token_decoder.cpp; path = src/token_decoder.cpp; sourceTree = "<group>"; };
+3072D5222DC3EA280083FC83 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = "<group>"; };
 F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = bpe_tokenizer_base.cpp; path = src/bpe_tokenizer_base.cpp; sourceTree = "<group>"; };
 F292B06C2D88B0C200BE6839 /* llama2c_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = llama2c_tokenizer.cpp; path = src/llama2c_tokenizer.cpp; sourceTree = "<group>"; };
 F292B06F2D88B0C200BE6839 /* tiktoken.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = tiktoken.cpp; path = src/tiktoken.cpp; sourceTree = "<group>"; };
@@ -208,6 +210,7 @@
 0324D6892BAACB6900DEF36F /* Application */ = {
 	isa = PBXGroup;
 	children = (
+		3072D5222DC3EA280083FC83 /* Constants.swift */,
 		0324D6802BAACB6900DEF36F /* App.swift */,
 		0324D6812BAACB6900DEF36F /* ContentView.swift */,
 		0324D6822BAACB6900DEF36F /* LogManager.swift */,
@@ -554,6 +557,7 @@
 	buildActionMask = 2147483647;
 	files = (
 		0324D6932BAACB6900DEF36F /* ResourceMonitor.swift in Sources */,
+		3072D5232DC3EA280083FC83 /* Constants.swift in Sources */,
 		0324D68D2BAACB6900DEF36F /* LogManager.swift in Sources */,
 		0324D68E2BAACB6900DEF36F /* LogView.swift in Sources */,
 		0324D68F2BAACB6900DEF36F /* Message.swift in Sources */,

examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift

Lines changed: 38 additions & 10 deletions
@@ -82,6 +82,25 @@ struct ContentView: View {
     case tokenizer
   }

+  enum ModelType {
+    case llama
+    case llava
+    case qwen3
+
+    static func fromPath(_ path: String) -> ModelType {
+      let filename = (path as NSString).lastPathComponent.lowercased()
+      if filename.hasPrefix("llama") {
+        return .llama
+      } else if filename.hasPrefix("llava") {
+        return .llava
+      } else if filename.hasPrefix("qwen3") {
+        return .qwen3
+      }
+      print("Unknown model type in path: \(path). Model filename should start with one of: llama, llava, or qwen3")
+      exit(1)
+    }
+  }
+
   private var placeholder: String {
     resourceManager.isModelValid ? resourceManager.isTokenizerValid ? "Prompt..." : "Select Tokenizer..." : "Select Model..."
   }
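
For reference, a quick sketch of how the new ModelType.fromPath helper behaves. The filenames below are hypothetical examples, not part of this commit; detection is a case-insensitive prefix match on the last path component, and anything unrecognized prints an error and terminates via exit(1).

// Hypothetical usage sketch (filenames are examples, not from this commit).
let qwen = ContentView.ModelType.fromPath("/models/Qwen3-0.6B-8da4w.pte")   // .qwen3
let llava = ContentView.ModelType.fromPath("/models/llava-1.5-7b.pte")      // .llava
let llama = ContentView.ModelType.fromPath("/models/Llama3_2-1B.pte")       // .llama
// Any other prefix prints "Unknown model type in path: ..." and calls exit(1).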
@@ -304,14 +323,14 @@
     let seq_len = 768 // text: 256, vision: 768
     let modelPath = resourceManager.modelPath
     let tokenizerPath = resourceManager.tokenizerPath
-    let useLlama = modelPath.lowercased().contains("llama")
+    let modelType = ModelType.fromPath(modelPath)

     prompt = ""
     hideKeyboard()
     showingSettings = false

     messages.append(Message(text: text))
-    messages.append(Message(type: useLlama ? .llamagenerated : .llavagenerated))
+    messages.append(Message(type: modelType == .llama ? .llamagenerated : .llavagenerated))

     runnerQueue.async {
       defer {
@@ -321,14 +340,16 @@
         }
       }

-      if useLlama {
+      switch modelType {
+      case .llama, .qwen3:
         runnerHolder.runner = runnerHolder.runner ?? Runner(modelPath: modelPath, tokenizerPath: tokenizerPath)
-      } else {
+      case .llava:
         runnerHolder.llavaRunner = runnerHolder.llavaRunner ?? LLaVARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
       }

       guard !shouldStopGenerating else { return }
-      if useLlama {
+      switch modelType {
+      case .llama, .qwen3:
         if let runner = runnerHolder.runner, !runner.isLoaded() {
           var error: Error?
           let startLoadTime = Date()
@@ -358,7 +379,7 @@
             return
           }
         }
-      } else {
+      case .llava:
         if let runner = runnerHolder.llavaRunner, !runner.isLoaded() {
           var error: Error?
           let startLoadTime = Date()
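
Because the new switch statements are split across the two hunks above, the resulting control flow is hard to see in the diff. A condensed sketch of the shape it takes after this change; elided, unchanged lines are summarized as comments, so this is a sketch rather than the verbatim file:

// Sketch of runner selection and lazy loading after this commit.
switch modelType {
case .llama, .qwen3:
  // Text-only models (Llama, Qwen3) share the generic text Runner.
  runnerHolder.runner = runnerHolder.runner ?? Runner(modelPath: modelPath, tokenizerPath: tokenizerPath)
case .llava:
  // Multimodal LLaVA models use the dedicated LLaVARunner.
  runnerHolder.llavaRunner = runnerHolder.llavaRunner ?? LLaVARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
}

guard !shouldStopGenerating else { return }
switch modelType {
case .llama, .qwen3:
  if let runner = runnerHolder.runner, !runner.isLoaded() {
    // ... load the text runner, report load time and any load error (unchanged code) ...
  }
case .llava:
  if let runner = runnerHolder.llavaRunner, !runner.isLoaded() {
    // ... load the LLaVA runner, report load time and any load error (unchanged code) ...
  }
}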
@@ -440,12 +461,19 @@
             }
           }
         } else {
-          let llama3_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\(text)<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
+          let prompt: String
+          switch modelType {
+          case .qwen3:
+            prompt = String(format: Constants.qwen3PromptTemplate, text)
+          case .llama:
+            prompt = String(format: Constants.llama3PromptTemplate, text)
+          case .llava:
+            prompt = String(format: Constants.llama3PromptTemplate, text)
+          }

-          try runnerHolder.runner?.generate(llama3_prompt, sequenceLength: seq_len) { token in
+          try runnerHolder.runner?.generate(prompt, sequenceLength: seq_len) { token in

-            NSLog(">>> token={\(token)}")
-            if token != llama3_prompt {
+            if token != prompt {
               // hack to fix the issue that extension/llm/runner/text_token_generator.h
               // keeps generating after <|eot_id|>
               if token == "<|eot_id|>" {
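
Constants.swift is added to the Xcode project by this commit, but its contents are not among the three file diffs shown here. A plausible sketch of the referenced templates, assuming the Llama 3 format from the deleted llama3_prompt literal and a standard ChatML-style Qwen3 chat format; the exact strings are an assumption, not taken from this diff:

// Sketch only: the real Constants.swift is not shown in this diff.
enum Constants {
  // Derived from the llama3_prompt literal removed above, with %@ for the user text.
  static let llama3PromptTemplate =
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>%@<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
  // Assumed ChatML-style template used by Qwen3 chat models.
  static let qwen3PromptTemplate =
    "<|im_start|>user\n%@<|im_end|>\n<|im_start|>assistant\n"
}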

examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/Message.swift

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ import UIKit

 enum MessageType {
   case prompted
-  case llamagenerated
+  case llamagenerated // TODO: change this to something more general, like "textgenerated".
   case llavagenerated
   case info
 }
