Skip to content

Commit 1b1d999

Browse files
committed
Refactor chat UI and token generation logic
1 parent 05f7a72 commit 1b1d999

File tree

5 files changed

+220
-46
lines changed

5 files changed

+220
-46
lines changed
Lines changed: 115 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,124 @@
1-
// Copyright (c) Microsoft Corporation. All rights reserved.
2-
// Licensed under the MIT License.
3-
41
import SwiftUI
52

3+
4+
struct Message: Identifiable {
5+
let id = UUID()
6+
var text: String
7+
let isUser: Bool
8+
}
9+
610
struct ContentView: View {
7-
@ObservedObject var tokenUpdater = SharedTokenUpdater.shared
11+
@State private var userInput: String = ""
12+
@State private var messages: [Message] = [] // Store chat messages locally
13+
@State private var isGenerating: Bool = false // Track token generation state
14+
@State private var stats: String = "" // token genetation stats
815

916
var body: some View {
1017
VStack {
18+
// ChatBubbles
1119
ScrollView {
12-
VStack(alignment: .leading) {
13-
ForEach(tokenUpdater.decodedTokens, id: \.self) { token in
14-
Text(token)
15-
.padding(.horizontal, 5)
20+
VStack(alignment: .leading, spacing: 20) {
21+
ForEach(messages) { message in
22+
ChatBubble(text: message.text, isUser: message.isUser)
23+
.padding(.horizontal, 20)
24+
}
25+
if !stats.isEmpty {
26+
Text(stats)
27+
.font(.footnote)
28+
.foregroundColor(.gray)
29+
.padding(.horizontal, 20)
30+
.padding(.top, 5)
31+
.multilineTextAlignment(.center)
1632
}
1733
}
18-
.padding()
34+
.padding(.top, 20)
1935
}
20-
Button("Generate Tokens") {
21-
DispatchQueue.global(qos: .background).async {
22-
// TODO: add user prompt question UI
23-
GenAIGenerator.generate("Who is the current US president?");
36+
37+
38+
// User input
39+
HStack {
40+
TextField("Type your message...", text: $userInput)
41+
.padding()
42+
.background(Color(.systemGray6))
43+
.cornerRadius(20)
44+
.padding(.horizontal)
45+
46+
Button(action: {
47+
// Check for non-empty input
48+
guard !userInput.trimmingCharacters(in: .whitespaces).isEmpty else { return }
49+
50+
messages.append(Message(text: userInput, isUser: true))
51+
messages.append(Message(text: "", isUser: false)) // Placeholder for AI response
52+
53+
54+
// clear previously generated tokens
55+
SharedTokenUpdater.shared.clearTokens()
56+
57+
let prompt = userInput
58+
userInput = ""
59+
isGenerating = true
60+
61+
62+
DispatchQueue.global(qos: .background).async {
63+
GenAIGenerator.generate(prompt)
64+
}
65+
}) {
66+
Image(systemName: "paperplane.fill")
67+
.foregroundColor(.white)
68+
.padding()
69+
.background(isGenerating ? Color.gray : Color.pastelGreen)
70+
.clipShape(Circle())
71+
.padding(.trailing, 10)
2472
}
73+
.disabled(isGenerating)
74+
}
75+
.padding(.bottom, 20)
76+
}
77+
.background(Color(.systemGroupedBackground))
78+
.edgesIgnoringSafeArea(.bottom)
79+
.onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationCompleted"))) { _ in
80+
isGenerating = false // Re-enable the button when token generation is complete
81+
}
82+
.onReceive(SharedTokenUpdater.shared.$decodedTokens) { tokens in
83+
// update model response
84+
if let lastIndex = messages.lastIndex(where: { !$0.isUser }) {
85+
let combinedText = tokens.joined(separator: "")
86+
messages[lastIndex].text = combinedText
87+
}
88+
}
89+
.onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationStats"))) { notification in
90+
if let userInfo = notification.userInfo,
91+
let totalTime = userInfo["totalTime"] as? Int,
92+
let firstTokenTime = userInfo["firstTokenTime"] as? Int,
93+
let tokenCount = userInfo["tokenCount"] as? Int {
94+
stats = "Generated \(tokenCount) tokens in \(totalTime) ms. First token in \(firstTokenTime) ms."
95+
}
96+
}
97+
}
98+
}
99+
100+
struct ChatBubble: View {
101+
var text: String
102+
var isUser: Bool
103+
104+
var body: some View {
105+
HStack {
106+
if isUser {
107+
Spacer()
108+
Text(text)
109+
.padding()
110+
.background(Color.pastelGreen)
111+
.foregroundColor(.white)
112+
.cornerRadius(25)
113+
.padding(.horizontal, 10)
114+
} else {
115+
Text(text)
116+
.padding()
117+
.background(Color(.systemGray5))
118+
.foregroundColor(.black)
119+
.cornerRadius(25)
120+
.padding(.horizontal, 20)
121+
Spacer()
25122
}
26123
}
27124
}
@@ -32,3 +129,8 @@ struct ContentView_Previews: PreviewProvider {
32129
ContentView()
33130
}
34131
}
132+
133+
// Extension for a pastel green color
134+
extension Color {
135+
static let pastelGreen = Color(red: 0.6, green: 0.9, blue: 0.6)
136+
}

mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.mm

Lines changed: 96 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,109 @@
55
#include "LocalLLM-Swift.h"
66
#include "ort_genai.h"
77
#include "ort_genai_c.h"
8-
8+
#include <chrono>
99

1010
@implementation GenAIGenerator
1111

12-
+ (void)generate:(nonnull NSString*)input_user_question {
13-
NSString* llmPath = [[NSBundle mainBundle] resourcePath];
14-
const char* modelPath = llmPath.cString;
15-
16-
auto model = OgaModel::Create(modelPath);
17-
auto tokenizer = OgaTokenizer::Create(*model);
18-
19-
NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
20-
const char* prompt = [promptString UTF8String];
21-
22-
auto sequences = OgaSequences::Create();
23-
tokenizer->Encode(prompt, *sequences);
12+
typedef std::chrono::high_resolution_clock Clock;
13+
typedef std::chrono::time_point<Clock> TimePoint;
2414

25-
auto params = OgaGeneratorParams::Create(*model);
26-
params->SetSearchOption("max_length", 200);
27-
params->SetInputSequences(*sequences);
28-
29-
// Streaming Output to generate token by token
30-
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
31-
32-
auto generator = OgaGenerator::Create(*model, *params);
15+
+ (void)generate:(nonnull NSString*)input_user_question {
16+
NSLog(@"Starting token generation...");
17+
18+
NSString* llmPath = [[NSBundle mainBundle] resourcePath];
19+
const char* modelPath = llmPath.cString;
20+
21+
// Log model creation
22+
NSLog(@"Creating model ...");
23+
auto model = OgaModel::Create(modelPath);
24+
if (!model) {
25+
NSLog(@"Failed to create model.");
26+
return;
27+
}
28+
29+
NSLog(@"Creating tokenizer...");
30+
auto tokenizer = OgaTokenizer::Create(*model);
31+
if (!tokenizer) {
32+
NSLog(@"Failed to create tokenizer.");
33+
return;
34+
}
35+
36+
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
37+
38+
// Construct the prompt
39+
NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
40+
const char* prompt = [promptString UTF8String];
41+
42+
NSLog(@"Encoding prompt...");
43+
auto sequences = OgaSequences::Create();
44+
tokenizer->Encode(prompt, *sequences);
45+
46+
// Log parameters
47+
NSLog(@"Setting generator parameters...");
48+
auto params = OgaGeneratorParams::Create(*model);
49+
params->SetSearchOption("max_length", 200);
50+
params->SetInputSequences(*sequences);
51+
52+
NSLog(@"Creating generator...");
53+
auto generator = OgaGenerator::Create(*model, *params);
54+
55+
bool isFirstToken = true;
56+
TimePoint startTime = Clock::now();
57+
TimePoint firstTokenTime;
58+
int tokenCount = 0;
59+
60+
NSLog(@"Starting token generation loop...");
61+
while (!generator->IsDone()) {
62+
generator->ComputeLogits();
63+
generator->GenerateNextToken();
64+
65+
if (isFirstToken) {
66+
NSLog(@"First token generated.");
67+
firstTokenTime = Clock::now();
68+
isFirstToken = false;
69+
}
70+
71+
// Get the sequence data
72+
const int32_t* seq = generator->GetSequenceData(0);
73+
size_t seq_len = generator->GetSequenceCount(0);
74+
75+
// Decode the new token
76+
const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
77+
78+
// Check for decoding failure
79+
if (!decode_tokens) {
80+
NSLog(@"Token decoding failed.");
81+
break;
82+
}
83+
84+
NSLog(@"Decoded token: %s", decode_tokens);
85+
tokenCount++;
86+
87+
// Convert token to NSString and update UI on the main thread
88+
NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
89+
[SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
90+
}
3391

34-
while (!generator->IsDone()) {
35-
generator->ComputeLogits();
36-
generator->GenerateNextToken();
3792

38-
const int32_t* seq = generator->GetSequenceData(0);
39-
size_t seq_len = generator->GetSequenceCount(0);
40-
const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
93+
TimePoint endTime = Clock::now();
94+
auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
95+
auto firstTokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count();
96+
97+
NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %d", totalDuration, firstTokenDuration, tokenCount);
4198

42-
NSLog(@"Decoded tokens: %s", decode_tokens);
99+
NSDictionary *stats = @{
100+
@"totalTime": @(totalDuration),
101+
@"firstTokenTime": @(firstTokenDuration),
102+
@"tokenCount": @(tokenCount)
103+
};
43104

44-
// Add decoded token to SharedTokenUpdater
45-
NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
46-
[SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
47-
}
105+
// notify main thread that token generation is complete
106+
dispatch_async(dispatch_get_main_queue(), ^{
107+
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted" object:nil];
108+
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats" object:nil userInfo:stats];
109+
});
110+
NSLog(@"Token generation completed.");
48111
}
112+
49113
@end

mobile/examples/phi-3/ios/LocalLLM/LocalLLM/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,6 @@ Upon app launching, Xcode will automatically copy and install the model files fr
106106

107107
**Note**: The current app only sets up with a simple initial prompt question, you can adjust/try your own or refine the UI based on requirements.
108108

109-
***Notice:*** The current Xcode project runs on iOS 16.6, feel free to adjust latest iOS/build for lates iOS versions accordingly.
109+
***Notice:*** The current Xcode project runs on iOS 16.6, feel free to adjust latest iOS/build for lates iOS versions accordingly.
110+
111+
![alt text](<Simulator Screenshot - iPhone 16.png>)

mobile/examples/phi-3/ios/LocalLLM/LocalLLM/SharedTokenUpdater.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,10 @@ import Foundation
1414
self.decodedTokens.append(token)
1515
}
1616
}
17+
18+
@objc func clearTokens() {
19+
DispatchQueue.main.async {
20+
self.decodedTokens.removeAll()
21+
}
22+
}
1723
}
137 KB
Loading

0 commit comments

Comments
 (0)