diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift index 84ea05731c5..cc1567b3cde 100644 --- a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift +++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift @@ -62,6 +62,8 @@ struct ContentView: View { @State private var isGenerating = false @State private var shouldStopGenerating = false @State private var shouldStopShowingToken = false + @State private var thinkingMode = false + @State private var showThinkingModeNotification = false private let runnerQueue = DispatchQueue(label: "org.pytorch.executorch.llama") @StateObject private var runnerHolder = RunnerHolder() @StateObject private var resourceManager = ResourceManager() @@ -119,107 +121,136 @@ struct ContentView: View { var body: some View { NavigationView { - VStack { - if showingSettings { - VStack(spacing: 20) { - HStack { - VStack(spacing: 10) { - Button(action: { pickerType = .model }) { - Label(modelTitle, systemImage: "doc") - .lineLimit(1) - .truncationMode(.middle) - .frame(maxWidth: 300, alignment: .leading) - } - Button(action: { pickerType = .tokenizer }) { - Label(tokenizerTitle, systemImage: "doc") - .lineLimit(1) - .truncationMode(.middle) - .frame(maxWidth: 300, alignment: .leading) + ZStack { + VStack { + if showingSettings { + VStack(spacing: 20) { + HStack { + VStack(spacing: 10) { + Button(action: { pickerType = .model }) { + Label(modelTitle, systemImage: "doc") + .lineLimit(1) + .truncationMode(.middle) + .frame(maxWidth: 300, alignment: .leading) + } + Button(action: { pickerType = .tokenizer }) { + Label(tokenizerTitle, systemImage: "doc") + .lineLimit(1) + .truncationMode(.middle) + .frame(maxWidth: 300, alignment: .leading) + } } + .padding() + .background(Color.gray.opacity(0.1)) + .cornerRadius(10) + .fixedSize(horizontal: true, vertical: false) + Spacer() } .padding() - .background(Color.gray.opacity(0.1)) - .cornerRadius(10) - .fixedSize(horizontal: true, vertical: false) - Spacer() } - .padding() } - } - MessageListView(messages: $messages) - .simultaneousGesture( - DragGesture().onChanged { value in - if value.translation.height > 10 { - hideKeyboard() + MessageListView(messages: $messages) + .simultaneousGesture( + DragGesture().onChanged { value in + if value.translation.height > 10 { + hideKeyboard() + } + showingSettings = false + textFieldFocused = false } + ) + .onTapGesture { showingSettings = false textFieldFocused = false } - ) - .onTapGesture { - showingSettings = false - textFieldFocused = false - } - HStack { - Button(action: { - imagePickerSourceType = .photoLibrary - isImagePickerPresented = true - }) { - Image(systemName: "photo.on.rectangle") - .resizable() - .scaledToFit() - .frame(width: 24, height: 24) - } - .background(Color.clear) - .cornerRadius(8) - - Button(action: { - if UIImagePickerController.isSourceTypeAvailable(.camera) { - imagePickerSourceType = .camera + HStack { + Button(action: { + imagePickerSourceType = .photoLibrary isImagePickerPresented = true - } else { - print("Camera not available") + }) { + Image(systemName: "photo.on.rectangle") + .resizable() + .scaledToFit() + .frame(width: 24, height: 24) } - }) { - Image(systemName: "camera") - .resizable() - .scaledToFit() - .frame(width: 24, height: 24) - } - .background(Color.clear) - .cornerRadius(8) + .background(Color.clear) + .cornerRadius(8) - TextField(placeholder, text: $prompt, axis: .vertical) - .padding(8) - .background(Color.gray.opacity(0.1)) - .cornerRadius(20) - .lineLimit(1...10) - .overlay( - RoundedRectangle(cornerRadius: 20) - .stroke(isInputEnabled ? Color.blue : Color.gray, lineWidth: 1) - ) - .disabled(!isInputEnabled) - .focused($textFieldFocused) - .onAppear { textFieldFocused = false } - .onTapGesture { - showingSettings = false + Button(action: { + if UIImagePickerController.isSourceTypeAvailable(.camera) { + imagePickerSourceType = .camera + isImagePickerPresented = true + } else { + print("Camera not available") + } + }) { + Image(systemName: "camera") + .resizable() + .scaledToFit() + .frame(width: 24, height: 24) + } + .background(Color.clear) + .cornerRadius(8) + + if resourceManager.isModelValid && ModelType.fromPath(resourceManager.modelPath) == .qwen3 { + Button(action: { + thinkingMode.toggle() + showThinkingModeNotification = true + DispatchQueue.main.asyncAfter(deadline: .now() + 3) { + showThinkingModeNotification = false + } + }) { + Image(systemName: "brain") + .resizable() + .scaledToFit() + .frame(width: 24, height: 24) + .foregroundColor(thinkingMode ? .blue : .gray) + } + .background(Color.clear) + .cornerRadius(8) } - Button(action: isGenerating ? stop : generate) { - Image(systemName: isGenerating ? "stop.circle" : "arrowshape.up.circle.fill") - .resizable() - .aspectRatio(contentMode: .fit) - .frame(height: 28) + TextField(placeholder, text: $prompt, axis: .vertical) + .padding(8) + .background(Color.gray.opacity(0.1)) + .cornerRadius(20) + .lineLimit(1...10) + .overlay( + RoundedRectangle(cornerRadius: 20) + .stroke(isInputEnabled ? Color.blue : Color.gray, lineWidth: 1) + ) + .disabled(!isInputEnabled) + .focused($textFieldFocused) + .onAppear { textFieldFocused = false } + .onTapGesture { + showingSettings = false + } + + Button(action: isGenerating ? stop : generate) { + Image(systemName: isGenerating ? "stop.circle" : "arrowshape.up.circle.fill") + .resizable() + .aspectRatio(contentMode: .fit) + .frame(height: 28) + } + .disabled(isGenerating ? shouldStopGenerating : (!isInputEnabled || prompt.isEmpty)) } - .disabled(isGenerating ? shouldStopGenerating : (!isInputEnabled || prompt.isEmpty)) + .padding([.leading, .trailing, .bottom], 10) } - .padding([.leading, .trailing, .bottom], 10) .sheet(isPresented: $isImagePickerPresented, onDismiss: addSelectedImageMessage) { ImagePicker(selectedImage: $selectedImage, sourceType: imagePickerSourceType) .id(imagePickerSourceType.rawValue) } + + if showThinkingModeNotification { + Text(thinkingMode ? "Thinking mode enabled" : "Thinking mode disabled") + .padding(8) + .background(Color(UIColor.secondarySystemBackground)) + .cornerRadius(8) + .transition(.opacity) + .animation(.easeInOut(duration: 0.2), value: showThinkingModeNotification) + } } .navigationBarTitle(title, displayMode: .inline) .navigationBarItems( @@ -435,7 +466,10 @@ struct ContentView: View { let prompt: String switch modelType { case .qwen3: - prompt = String(format: Constants.qwen3PromptTemplate, text) + let basePrompt = String(format: Constants.qwen3PromptTemplate, text) + // If thinking mode is enabled for Qwen, don't skip the special tokens + // and have them be generated. + prompt = thinkingMode ? basePrompt.replacingOccurrences(of: "\n\n\n\n\n", with: "") : basePrompt case .llama: prompt = String(format: Constants.llama3PromptTemplate, text) case .llava: @@ -445,12 +479,45 @@ struct ContentView: View { try runnerHolder.runner?.generate(prompt, sequenceLength: seq_len) { token in if token != prompt { - // hack to fix the issue that extension/llm/runner/text_token_generator.h - // keeps generating after <|eot_id|> if token == "<|eot_id|>" { + // hack to fix the issue that extension/llm/runner/text_token_generator.h + // keeps generating after <|eot_id|> shouldStopShowingToken = true + } else if token == "<|im_end|>" { + // Qwen3 specific token. + // Skip. + } else if token == "" { + // Qwen3 specific token. + let textToFlush = tokens.joined() + let flushedTokenCount = tokens.count + tokens = [] + DispatchQueue.main.async { + var message = messages.removeLast() + message.text += textToFlush + message.text += message.text.isEmpty ? "Thinking...\n\n" : "\n\nThinking...\n\n" + message.format = .italic + message.tokenCount += flushedTokenCount + 1 // + 1 for the start thinking token. + message.dateUpdated = Date() + messages.append(message) + } + } else if token == "" { + // Qwen3 specific token. + let textToFlush = tokens.joined() + let flushedTokenCount = tokens.count + tokens = [] + DispatchQueue.main.async { + var message = messages.removeLast() + message.text += textToFlush + message.text += "\n\nFinished thinking.\n\n" + message.format = .italic + message.tokenCount += flushedTokenCount + 1 // + 1 for the end thinking token. + message.dateUpdated = Date() + messages.append(message) + } } else { tokens.append(token.trimmingCharacters(in: .newlines)) + // Flush tokens in groups of 3 so that it's closer to whole words being generated + // rather than parts of words (tokens). if tokens.count > 2 { let text = tokens.joined() let count = tokens.count