Skip to content

Commit 46581da

Browse files
committed
Refactor chat UI and token generation logic, and handle token generation errors
1 parent 1b1d999 commit 46581da

File tree

3 files changed

+144
-96
lines changed

3 files changed

+144
-96
lines changed

mobile/examples/phi-3/ios/LocalLLM/LocalLLM/ContentView.swift

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
14
import SwiftUI
25

36

@@ -12,6 +15,8 @@ struct ContentView: View {
1215
@State private var messages: [Message] = [] // Store chat messages locally
1316
@State private var isGenerating: Bool = false // Track token generation state
1417
@State private var stats: String = "" // token generation stats
18+
@State private var showAlert: Bool = false
19+
@State private var errorMessage: String = ""
1520

1621
var body: some View {
1722
VStack {
@@ -88,12 +93,25 @@ struct ContentView: View {
8893
}
8994
.onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationStats"))) { notification in
9095
if let userInfo = notification.userInfo,
91-
let totalTime = userInfo["totalTime"] as? Int,
92-
let firstTokenTime = userInfo["firstTokenTime"] as? Int,
93-
let tokenCount = userInfo["tokenCount"] as? Int {
94-
stats = "Generated \(tokenCount) tokens in \(totalTime) ms. First token in \(firstTokenTime) ms."
96+
let promptProcRate = userInfo["promptProcRate"] as? Double,
97+
let tokenGenRate = userInfo["tokenGenRate"] as? Double {
98+
stats = String(format: "Token generation rate: %.2f tokens/s. Prompt processing rate: %.2f tokens/s", tokenGenRate, promptProcRate)
99+
}
100+
}
101+
.onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationError"))) { notification in
102+
if let userInfo = notification.userInfo, let error = userInfo["error"] as? String {
103+
errorMessage = error
104+
showAlert = true
95105
}
96106
}
107+
.alert(isPresented: $showAlert) {
108+
Alert(
109+
title: Text("Error"),
110+
message: Text(errorMessage),
111+
dismissButton: .default(Text("OK"))
112+
)
113+
}
114+
97115
}
98116
}
99117

@@ -117,7 +135,7 @@ struct ChatBubble: View {
117135
.background(Color(.systemGray5))
118136
.foregroundColor(.black)
119137
.cornerRadius(25)
120-
.padding(.horizontal, 20)
138+
.padding(.horizontal, 10)
121139
Spacer()
122140
}
123141
}
Lines changed: 121 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,113 +1,143 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
3-
43
#import "GenAIGenerator.h"
54
#include "LocalLLM-Swift.h"
65
#include "ort_genai.h"
76
#include "ort_genai_c.h"
87
#include <chrono>
8+
#include <vector>
99

1010
@implementation GenAIGenerator
1111

12-
typedef std::chrono::high_resolution_clock Clock;
12+
typedef std::chrono::steady_clock Clock;
1313
typedef std::chrono::time_point<Clock> TimePoint;
14+
static std::unique_ptr<OgaModel> model = nullptr;
15+
static std::unique_ptr<OgaTokenizer> tokenizer = nullptr;
1416

1517
+ (void)generate:(nonnull NSString*)input_user_question {
16-
NSLog(@"Starting token generation...");
17-
18-
NSString* llmPath = [[NSBundle mainBundle] resourcePath];
19-
const char* modelPath = llmPath.cString;
20-
21-
// Log model creation
22-
NSLog(@"Creating model ...");
23-
auto model = OgaModel::Create(modelPath);
24-
if (!model) {
25-
NSLog(@"Failed to create model.");
26-
return;
27-
}
28-
29-
NSLog(@"Creating tokenizer...");
30-
auto tokenizer = OgaTokenizer::Create(*model);
31-
if (!tokenizer) {
32-
NSLog(@"Failed to create tokenizer.");
33-
return;
34-
}
35-
36-
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
37-
38-
// Construct the prompt
39-
NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
40-
const char* prompt = [promptString UTF8String];
41-
42-
NSLog(@"Encoding prompt...");
43-
auto sequences = OgaSequences::Create();
44-
tokenizer->Encode(prompt, *sequences);
45-
46-
// Log parameters
47-
NSLog(@"Setting generator parameters...");
48-
auto params = OgaGeneratorParams::Create(*model);
49-
params->SetSearchOption("max_length", 200);
50-
params->SetInputSequences(*sequences);
51-
52-
NSLog(@"Creating generator...");
53-
auto generator = OgaGenerator::Create(*model, *params);
54-
55-
bool isFirstToken = true;
56-
TimePoint startTime = Clock::now();
57-
TimePoint firstTokenTime;
58-
int tokenCount = 0;
59-
60-
NSLog(@"Starting token generation loop...");
61-
while (!generator->IsDone()) {
62-
generator->ComputeLogits();
63-
generator->GenerateNextToken();
64-
65-
if (isFirstToken) {
66-
NSLog(@"First token generated.");
67-
firstTokenTime = Clock::now();
68-
isFirstToken = false;
18+
std::vector<long long> tokenTimes; // per-token generation times
19+
TimePoint startTime, firstTokenTime, tokenStartTime;
20+
21+
@try {
22+
NSLog(@"Starting token generation...");
23+
24+
if (!model) {
25+
NSLog(@"Creating model...");
26+
NSString* llmPath = [[NSBundle mainBundle] resourcePath];
27+
const char* modelPath = llmPath.cString;
28+
model = OgaModel::Create(modelPath); // throws exception
29+
30+
if (!model) {
31+
@throw [NSException exceptionWithName:@"ModelCreationError" reason:@"Failed to create model." userInfo:nil];
32+
}
6933
}
70-
71-
// Get the sequence data
72-
const int32_t* seq = generator->GetSequenceData(0);
73-
size_t seq_len = generator->GetSequenceCount(0);
74-
75-
// Decode the new token
76-
const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
77-
78-
// Check for decoding failure
79-
if (!decode_tokens) {
80-
NSLog(@"Token decoding failed.");
81-
break;
34+
35+
if (!tokenizer) {
36+
NSLog(@"Creating tokenizer...");
37+
tokenizer = OgaTokenizer::Create(*model); // throws exception
38+
if (!tokenizer) {
39+
@throw [NSException exceptionWithName:@"TokenizerCreationError" reason:@"Failed to create tokenizer." userInfo:nil];
40+
}
8241
}
42+
43+
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
44+
45+
// Construct the prompt
46+
NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
47+
const char* prompt = [promptString UTF8String];
48+
49+
// Encode the prompt
50+
auto sequences = OgaSequences::Create();
51+
tokenizer->Encode(prompt, *sequences);
52+
53+
size_t promptTokensCount = sequences->SequenceCount(0);
54+
55+
NSLog(@"Setting generator parameters...");
56+
auto params = OgaGeneratorParams::Create(*model);
57+
params->SetSearchOption("max_length", 200);
58+
params->SetInputSequences(*sequences);
59+
60+
auto generator = OgaGenerator::Create(*model, *params);
61+
62+
bool isFirstToken = true;
63+
NSLog(@"Starting token generation loop...");
8364

84-
NSLog(@"Decoded token: %s", decode_tokens);
85-
tokenCount++;
65+
startTime = Clock::now();
66+
while (!generator->IsDone()) {
67+
tokenStartTime = Clock::now();
68+
69+
generator->ComputeLogits();
70+
generator->GenerateNextToken();
71+
72+
if (isFirstToken) {
73+
firstTokenTime = Clock::now();
74+
isFirstToken = false;
75+
}
76+
77+
// Get the sequence data and decode the token
78+
const int32_t* seq = generator->GetSequenceData(0);
79+
size_t seq_len = generator->GetSequenceCount(0);
80+
const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
81+
82+
if (!decode_tokens) {
83+
@throw [NSException exceptionWithName:@"TokenDecodeError" reason:@"Token decoding failed." userInfo:nil];
84+
}
85+
86+
// Measure token generation time excluding logging
87+
TimePoint tokenEndTime = Clock::now();
88+
auto tokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(tokenEndTime - tokenStartTime).count();
89+
tokenTimes.push_back(tokenDuration);
90+
NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
91+
[SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
92+
}
93+
94+
TimePoint endTime = Clock::now();
95+
// Log token times
96+
NSLog(@"Per-token generation times: %@", [self formatTokenTimes:tokenTimes]);
97+
98+
// Calculate metrics
99+
auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
100+
auto firstTokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count();
101+
102+
double promtProcTime = (double)promptTokensCount / firstTokenDuration;
103+
double tokenGenRate = (double)(tokenTimes.size() - 1) * 1000.0 / (totalDuration - firstTokenDuration);
104+
105+
NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %zu", totalDuration, firstTokenDuration, tokenTimes.size());
106+
NSLog(@"Prompt tokens: %zu, Prompt Processing Time: %f tokens/s", promptTokensCount, promtProcTime);
107+
NSLog(@"Generated tokens: %zu, Token Generation Rate: %f tokens/s", tokenTimes.size(), tokenGenRate);
108+
86109

87-
// Convert token to NSString and update UI on the main thread
88-
NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
89-
[SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
90-
}
110+
NSDictionary *stats = @{
111+
@"tokenGenRate" : @(tokenGenRate),
112+
@"promptProcRate": @(promtProcTime)
113+
};
114+
// notify main thread that token generation is complete
115+
dispatch_async(dispatch_get_main_queue(), ^{
116+
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats" object:nil userInfo:stats];
117+
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted" object:nil];
118+
});
119+
120+
NSLog(@"Token generation completed.");
91121

122+
} @catch (NSException* e) {
123+
NSString* errorMessage = e.reason;
124+
NSLog(@"Error during generation: %@", errorMessage);
92125

93-
TimePoint endTime = Clock::now();
94-
auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
95-
auto firstTokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count();
96-
97-
NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %d", totalDuration, firstTokenDuration, tokenCount);
98-
99-
NSDictionary *stats = @{
100-
@"totalTime": @(totalDuration),
101-
@"firstTokenTime": @(firstTokenDuration),
102-
@"tokenCount": @(tokenCount)
103-
};
104-
105-
// notify main thread that token generation is complete
106-
dispatch_async(dispatch_get_main_queue(), ^{
107-
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted" object:nil];
108-
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats" object:nil userInfo:stats];
109-
});
110-
NSLog(@"Token generation completed.");
126+
// Send error to the UI
127+
NSDictionary *errorInfo = @{@"error": errorMessage};
128+
dispatch_async(dispatch_get_main_queue(), ^{
129+
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationError" object:nil userInfo:errorInfo];
130+
});
131+
}
132+
}
133+
134+
// Utility function to format token times for logging
135+
+ (NSString*)formatTokenTimes:(const std::vector<long long>&)tokenTimes {
136+
NSMutableString *formattedTimes = [NSMutableString string];
137+
for (size_t i = 0; i < tokenTimes.size(); i++) {
138+
[formattedTimes appendFormat:@"%lld ms, ", tokenTimes[i]];
139+
}
140+
return [formattedTimes copy];
111141
}
112142

113143
@end
5.3 KB
Loading

0 commit comments

Comments (0)