From b496c8ba1664aaa4ccd746e1236899c1348fa02b Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Thu, 1 May 2025 10:46:36 -0700
Subject: [PATCH] Switch between model types based on model name

---
 .../LLaMA/LLaMA.xcodeproj/project.pbxproj     | 18 ++++---
 .../LLaMA/LLaMA/Application/ContentView.swift | 48 +++++++++++++++----
 .../LLaMA/LLaMA/Application/Message.swift     |  2 +-
 3 files changed, 50 insertions(+), 18 deletions(-)

diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj b/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj
index dc25940c8ad..ef9c5ebc495 100644
--- a/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj
@@ -62,6 +62,7 @@
     306A71502DC1DC3D00936B1F /* regex.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A71492DC1DC3D00936B1F /* regex.cpp */; };
     306A71512DC1DC3D00936B1F /* pre_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A71472DC1DC3D00936B1F /* pre_tokenizer.cpp */; };
     306A71522DC1DC3D00936B1F /* token_decoder.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A714B2DC1DC3D00936B1F /* token_decoder.cpp */; };
+    3072D5232DC3EA280083FC83 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3072D5222DC3EA280083FC83 /* Constants.swift */; };
     F292B0752D88B0C200BE6839 /* tiktoken.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06F2D88B0C200BE6839 /* tiktoken.cpp */; };
     F292B0762D88B0C200BE6839 /* llama2c_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06C2D88B0C200BE6839 /* llama2c_tokenizer.cpp */; };
     F292B0772D88B0C200BE6839 /* bpe_tokenizer_base.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */; };
@@ -140,13 +141,14 @@
     306A713A2DC1DC0F00936B1F /* std_regex.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = std_regex.h; sourceTree = "<group>"; };
     306A713B2DC1DC0F00936B1F /* string_integer_map.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = string_integer_map.h; sourceTree = "<group>"; };
     306A713C2DC1DC0F00936B1F /* token_decoder.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = token_decoder.h; sourceTree = "<group>"; };
-    306A71452DC1DC3D00936B1F /* hf_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = hf_tokenizer.cpp; sourceTree = "<group>"; };
-    306A71462DC1DC3D00936B1F /* pcre2_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = pcre2_regex.cpp; sourceTree = "<group>"; };
-    306A71472DC1DC3D00936B1F /* pre_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = pre_tokenizer.cpp; sourceTree = "<group>"; };
-    306A71482DC1DC3D00936B1F /* re2_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = re2_regex.cpp; sourceTree = "<group>"; };
-    306A71492DC1DC3D00936B1F /* regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = regex.cpp; sourceTree = "<group>"; };
-    306A714A2DC1DC3D00936B1F /* std_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = std_regex.cpp; sourceTree = "<group>"; };
-    306A714B2DC1DC3D00936B1F /* token_decoder.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = token_decoder.cpp; sourceTree = "<group>"; };
+    306A71452DC1DC3D00936B1F /* hf_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = hf_tokenizer.cpp; path = src/hf_tokenizer.cpp; sourceTree = "<group>"; };
+    306A71462DC1DC3D00936B1F /* pcre2_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = pcre2_regex.cpp; path = src/pcre2_regex.cpp; sourceTree = "<group>"; };
+    306A71472DC1DC3D00936B1F /* pre_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = pre_tokenizer.cpp; path = src/pre_tokenizer.cpp; sourceTree = "<group>"; };
+    306A71482DC1DC3D00936B1F /* re2_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = re2_regex.cpp; path = src/re2_regex.cpp; sourceTree = "<group>"; };
+    306A71492DC1DC3D00936B1F /* regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = regex.cpp; path = src/regex.cpp; sourceTree = "<group>"; };
+    306A714A2DC1DC3D00936B1F /* std_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = std_regex.cpp; path = src/std_regex.cpp; sourceTree = "<group>"; };
+    306A714B2DC1DC3D00936B1F /* token_decoder.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = token_decoder.cpp; path = src/token_decoder.cpp; sourceTree = "<group>"; };
+    3072D5222DC3EA280083FC83 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = "<group>"; };
     F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = bpe_tokenizer_base.cpp; path = src/bpe_tokenizer_base.cpp; sourceTree = "<group>"; };
     F292B06C2D88B0C200BE6839 /* llama2c_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = llama2c_tokenizer.cpp; path = src/llama2c_tokenizer.cpp; sourceTree = "<group>"; };
     F292B06F2D88B0C200BE6839 /* tiktoken.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = tiktoken.cpp; path = src/tiktoken.cpp; sourceTree = "<group>"; };
@@ -208,6 +210,7 @@
     0324D6892BAACB6900DEF36F /* Application */ = {
       isa = PBXGroup;
       children = (
+        3072D5222DC3EA280083FC83 /* Constants.swift */,
         0324D6802BAACB6900DEF36F /* App.swift */,
         0324D6812BAACB6900DEF36F /* ContentView.swift */,
         0324D6822BAACB6900DEF36F /* LogManager.swift */,
@@ -554,6 +557,7 @@
       buildActionMask = 2147483647;
       files = (
         0324D6932BAACB6900DEF36F /* ResourceMonitor.swift in Sources */,
+        3072D5232DC3EA280083FC83 /* Constants.swift in Sources */,
         0324D68D2BAACB6900DEF36F /* LogManager.swift in Sources */,
         0324D68E2BAACB6900DEF36F /* LogView.swift in Sources */,
         0324D68F2BAACB6900DEF36F /* Message.swift in Sources */,
diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift
index f315c891337..84ea05731c5 100644
--- a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift
@@ -80,6 +80,25 @@ struct ContentView: View {
     case tokenizer
   }
 
+  enum ModelType {
+    case llama
+    case llava
+    case qwen3
+
+    static func fromPath(_ path: String) -> ModelType {
+      let filename = (path as NSString).lastPathComponent.lowercased()
+      if filename.hasPrefix("llama") {
+        return .llama
+      } else if filename.hasPrefix("llava") {
+        return .llava
+      } else if filename.hasPrefix("qwen3") {
+        return .qwen3
+      }
+      print("Unknown model type in path: \(path). Model filename should start with one of: llama, llava, or qwen3")
+      exit(1)
+    }
+  }
+
   private var placeholder: String {
     resourceManager.isModelValid ? resourceManager.isTokenizerValid ? "Prompt..." : "Select Tokenizer..." : "Select Model..."
   }
@@ -275,14 +294,14 @@
     let seq_len = 768 // text: 256, vision: 768
     let modelPath = resourceManager.modelPath
     let tokenizerPath = resourceManager.tokenizerPath
-    let useLlama = modelPath.lowercased().contains("llama")
+    let modelType = ModelType.fromPath(modelPath)
 
     prompt = ""
     hideKeyboard()
     showingSettings = false
 
     messages.append(Message(text: text))
-    messages.append(Message(type: useLlama ? .llamagenerated : .llavagenerated))
+    messages.append(Message(type: modelType == .llama ? .llamagenerated : .llavagenerated))
 
     runnerQueue.async {
       defer {
@@ -292,14 +311,16 @@
        }
      }
 
-      if useLlama {
+      switch modelType {
+      case .llama, .qwen3:
        runnerHolder.runner = runnerHolder.runner ?? Runner(modelPath: modelPath, tokenizerPath: tokenizerPath)
-      } else {
+      case .llava:
        runnerHolder.llavaRunner = runnerHolder.llavaRunner ?? LLaVARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
      }
 
      guard !shouldStopGenerating else { return }
-      if useLlama {
+      switch modelType {
+      case .llama, .qwen3:
        if let runner = runnerHolder.runner, !runner.isLoaded() {
          var error: Error?
          let startLoadTime = Date()
@@ -329,7 +350,7 @@
            return
          }
        }
-      } else {
+      case .llava:
        if let runner = runnerHolder.llavaRunner, !runner.isLoaded() {
          var error: Error?
          let startLoadTime = Date()
@@ -411,12 +432,19 @@
          }
        }
      } else {
-        let llama3_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\(text)<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
+        let prompt: String
+        switch modelType {
+        case .qwen3:
+          prompt = String(format: Constants.qwen3PromptTemplate, text)
+        case .llama:
+          prompt = String(format: Constants.llama3PromptTemplate, text)
+        case .llava:
+          prompt = String(format: Constants.llama3PromptTemplate, text)
+        }
 
-        try runnerHolder.runner?.generate(llama3_prompt, sequenceLength: seq_len) { token in
+        try runnerHolder.runner?.generate(prompt, sequenceLength: seq_len) { token in
 
-          NSLog(">>> token={\(token)}")
-          if token != llama3_prompt {
+          if token != prompt {
            // hack to fix the issue that extension/llm/runner/text_token_generator.h
            // keeps generating after <|eot_id|>
            if token == "<|eot_id|>" {
diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/Message.swift b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/Message.swift
index 400941f496a..34ed0d7e933 100644
--- a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/Message.swift
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/Message.swift
@@ -10,7 +10,7 @@ import UIKit
 
 enum MessageType {
   case prompted
-  case llamagenerated
+  case llamagenerated // TODO: change this to something more general, like "textgenerated".
   case llavagenerated
   case info
 }
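
Note: the project changes above register a new Constants.swift in the Xcode target, but the file's contents are not part of this patch. Below is a minimal sketch of what it could contain, inferred from how ContentView.swift uses it. The llama3 template mirrors the string literal removed from ContentView.swift (with "%@" standing in for the interpolated user text); the qwen3 template is an assumption based on Qwen's ChatML-style chat format and should be verified against the actual file.

    // Constants.swift (sketch -- not included in this patch)
    import Foundation

    public enum Constants {
      // Each template is filled in with the user's text via String(format:).
      public static let llama3PromptTemplate =
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>%@<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

      // Assumed ChatML-style template for Qwen3; check against the shipped Constants.swift.
      public static let qwen3PromptTemplate =
        "<|im_start|>user\n%@<|im_end|>\n<|im_start|>assistant\n"
    }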