Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ struct ContentView: View {
@State private var isGenerating = false
@State private var shouldStopGenerating = false
@State private var shouldStopShowingToken = false
@State private var thinkingMode = false
@State private var showThinkingModeNotification = false
private let runnerQueue = DispatchQueue(label: "org.pytorch.executorch.llama")
@StateObject private var runnerHolder = RunnerHolder()
@StateObject private var resourceManager = ResourceManager()
Expand Down Expand Up @@ -119,107 +121,136 @@ struct ContentView: View {

var body: some View {
NavigationView {
VStack {
if showingSettings {
VStack(spacing: 20) {
HStack {
VStack(spacing: 10) {
Button(action: { pickerType = .model }) {
Label(modelTitle, systemImage: "doc")
.lineLimit(1)
.truncationMode(.middle)
.frame(maxWidth: 300, alignment: .leading)
}
Button(action: { pickerType = .tokenizer }) {
Label(tokenizerTitle, systemImage: "doc")
.lineLimit(1)
.truncationMode(.middle)
.frame(maxWidth: 300, alignment: .leading)
ZStack {
VStack {
if showingSettings {
VStack(spacing: 20) {
HStack {
VStack(spacing: 10) {
Button(action: { pickerType = .model }) {
Label(modelTitle, systemImage: "doc")
.lineLimit(1)
.truncationMode(.middle)
.frame(maxWidth: 300, alignment: .leading)
}
Button(action: { pickerType = .tokenizer }) {
Label(tokenizerTitle, systemImage: "doc")
.lineLimit(1)
.truncationMode(.middle)
.frame(maxWidth: 300, alignment: .leading)
}
}
.padding()
.background(Color.gray.opacity(0.1))
.cornerRadius(10)
.fixedSize(horizontal: true, vertical: false)
Spacer()
}
.padding()
.background(Color.gray.opacity(0.1))
.cornerRadius(10)
.fixedSize(horizontal: true, vertical: false)
Spacer()
}
.padding()
}
}

MessageListView(messages: $messages)
.simultaneousGesture(
DragGesture().onChanged { value in
if value.translation.height > 10 {
hideKeyboard()
MessageListView(messages: $messages)
.simultaneousGesture(
DragGesture().onChanged { value in
if value.translation.height > 10 {
hideKeyboard()
}
showingSettings = false
textFieldFocused = false
}
)
.onTapGesture {
showingSettings = false
textFieldFocused = false
}
)
.onTapGesture {
showingSettings = false
textFieldFocused = false
}

HStack {
Button(action: {
imagePickerSourceType = .photoLibrary
isImagePickerPresented = true
}) {
Image(systemName: "photo.on.rectangle")
.resizable()
.scaledToFit()
.frame(width: 24, height: 24)
}
.background(Color.clear)
.cornerRadius(8)

Button(action: {
if UIImagePickerController.isSourceTypeAvailable(.camera) {
imagePickerSourceType = .camera
HStack {
Button(action: {
imagePickerSourceType = .photoLibrary
isImagePickerPresented = true
} else {
print("Camera not available")
}) {
Image(systemName: "photo.on.rectangle")
.resizable()
.scaledToFit()
.frame(width: 24, height: 24)
}
}) {
Image(systemName: "camera")
.resizable()
.scaledToFit()
.frame(width: 24, height: 24)
}
.background(Color.clear)
.cornerRadius(8)
.background(Color.clear)
.cornerRadius(8)

TextField(placeholder, text: $prompt, axis: .vertical)
.padding(8)
.background(Color.gray.opacity(0.1))
.cornerRadius(20)
.lineLimit(1...10)
.overlay(
RoundedRectangle(cornerRadius: 20)
.stroke(isInputEnabled ? Color.blue : Color.gray, lineWidth: 1)
)
.disabled(!isInputEnabled)
.focused($textFieldFocused)
.onAppear { textFieldFocused = false }
.onTapGesture {
showingSettings = false
Button(action: {
if UIImagePickerController.isSourceTypeAvailable(.camera) {
imagePickerSourceType = .camera
isImagePickerPresented = true
} else {
print("Camera not available")
}
}) {
Image(systemName: "camera")
.resizable()
.scaledToFit()
.frame(width: 24, height: 24)
}
.background(Color.clear)
.cornerRadius(8)

if resourceManager.isModelValid && ModelType.fromPath(resourceManager.modelPath) == .qwen3 {
Button(action: {
thinkingMode.toggle()
showThinkingModeNotification = true
DispatchQueue.main.asyncAfter(deadline: .now() + 3) {
showThinkingModeNotification = false
}
}) {
Image(systemName: "brain")
.resizable()
.scaledToFit()
.frame(width: 24, height: 24)
.foregroundColor(thinkingMode ? .blue : .gray)
}
.background(Color.clear)
.cornerRadius(8)
}

Button(action: isGenerating ? stop : generate) {
Image(systemName: isGenerating ? "stop.circle" : "arrowshape.up.circle.fill")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(height: 28)
TextField(placeholder, text: $prompt, axis: .vertical)
.padding(8)
.background(Color.gray.opacity(0.1))
.cornerRadius(20)
.lineLimit(1...10)
.overlay(
RoundedRectangle(cornerRadius: 20)
.stroke(isInputEnabled ? Color.blue : Color.gray, lineWidth: 1)
)
.disabled(!isInputEnabled)
.focused($textFieldFocused)
.onAppear { textFieldFocused = false }
.onTapGesture {
showingSettings = false
}

Button(action: isGenerating ? stop : generate) {
Image(systemName: isGenerating ? "stop.circle" : "arrowshape.up.circle.fill")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(height: 28)
}
.disabled(isGenerating ? shouldStopGenerating : (!isInputEnabled || prompt.isEmpty))
}
.disabled(isGenerating ? shouldStopGenerating : (!isInputEnabled || prompt.isEmpty))
.padding([.leading, .trailing, .bottom], 10)
}
.padding([.leading, .trailing, .bottom], 10)
.sheet(isPresented: $isImagePickerPresented, onDismiss: addSelectedImageMessage) {
ImagePicker(selectedImage: $selectedImage, sourceType: imagePickerSourceType)
.id(imagePickerSourceType.rawValue)
}

if showThinkingModeNotification {
Text(thinkingMode ? "Thinking mode enabled" : "Thinking mode disabled")
.padding(8)
.background(Color(UIColor.secondarySystemBackground))
.cornerRadius(8)
.transition(.opacity)
.animation(.easeInOut(duration: 0.2), value: showThinkingModeNotification)
}
}
.navigationBarTitle(title, displayMode: .inline)
.navigationBarItems(
Expand Down Expand Up @@ -435,7 +466,10 @@ struct ContentView: View {
let prompt: String
switch modelType {
case .qwen3:
prompt = String(format: Constants.qwen3PromptTemplate, text)
let basePrompt = String(format: Constants.qwen3PromptTemplate, text)
// If thinking mode is enabled for Qwen, don't skip the <think></think> special tokens
// and have them be generated.
prompt = thinkingMode ? basePrompt.replacingOccurrences(of: "<think>\n\n</think>\n\n\n", with: "") : basePrompt
case .llama:
prompt = String(format: Constants.llama3PromptTemplate, text)
case .llava:
Expand All @@ -445,12 +479,45 @@ struct ContentView: View {
try runnerHolder.runner?.generate(prompt, sequenceLength: seq_len) { token in

if token != prompt {
// hack to fix the issue that extension/llm/runner/text_token_generator.h
// keeps generating after <|eot_id|>
if token == "<|eot_id|>" {
// hack to fix the issue that extension/llm/runner/text_token_generator.h
// keeps generating after <|eot_id|>
shouldStopShowingToken = true
} else if token == "<|im_end|>" {
// Qwen3 specific token.
// Skip.
} else if token == "<think>" {
// Qwen3 specific token.
let textToFlush = tokens.joined()
let flushedTokenCount = tokens.count
tokens = []
DispatchQueue.main.async {
var message = messages.removeLast()
message.text += textToFlush
message.text += message.text.isEmpty ? "Thinking...\n\n" : "\n\nThinking...\n\n"
message.format = .italic
message.tokenCount += flushedTokenCount + 1 // + 1 for the start thinking token.
message.dateUpdated = Date()
messages.append(message)
}
} else if token == "</think>" {
// Qwen3 specific token.
let textToFlush = tokens.joined()
let flushedTokenCount = tokens.count
tokens = []
DispatchQueue.main.async {
var message = messages.removeLast()
message.text += textToFlush
message.text += "\n\nFinished thinking.\n\n"
message.format = .italic
message.tokenCount += flushedTokenCount + 1 // + 1 for the end thinking token.
message.dateUpdated = Date()
messages.append(message)
}
} else {
tokens.append(token.trimmingCharacters(in: .newlines))
// Flush tokens in groups of 3 so that it's closer to whole words being generated
// rather than parts of words (tokens).
if tokens.count > 2 {
let text = tokens.joined()
let count = tokens.count
Expand Down
Loading