|
1 | 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
2 | 2 | // Licensed under the MIT License. |
3 | | - |
#import "GenAIGenerator.h"
#include "LocalLLM-Swift.h"
#include "ort_genai.h"
#include "ort_genai_c.h"
#include <chrono>
#include <exception>
#include <memory>
#include <vector>
9 | 9 |
|
@implementation GenAIGenerator

typedef std::chrono::steady_clock Clock;
typedef std::chrono::time_point<Clock> TimePoint;

// Value passed to the generator's "max_length" search option.
static const int kMaxLength = 200;

// The model and tokenizer are created on the first call to +generate: and
// reused by later calls, so the creation cost is paid only once.
// NOTE(review): there is no synchronization around these statics; this assumes
// +generate: is never invoked concurrently -- confirm against the callers.
static std::unique_ptr<OgaModel> model = nullptr;
static std::unique_ptr<OgaTokenizer> tokenizer = nullptr;

// Returns the whole number of milliseconds elapsed between two time points.
static long long ElapsedMilliseconds(TimePoint from, TimePoint to) {
  return static_cast<long long>(
      std::chrono::duration_cast<std::chrono::milliseconds>(to - from).count());
}

// Logs the failure and posts "TokenGenerationError" on the main queue with the
// message stored under the "error" key of userInfo.
static void PostGenerationError(NSString* message) {
  // A nil value in a dictionary literal raises, so always substitute a fallback.
  NSString* errorMessage = message ?: @"Unknown error during generation.";
  NSLog(@"Error during generation: %@", errorMessage);

  NSDictionary* errorInfo = @{@"error" : errorMessage};
  dispatch_async(dispatch_get_main_queue(), ^{
    [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationError"
                                                        object:nil
                                                      userInfo:errorInfo];
  });
}

// Creates the shared model and tokenizer if they do not exist yet.
// Raises NSException when a resource is missing or a factory returns null.
static void EnsureModelAndTokenizer() {
  if (!model) {
    NSLog(@"Creating model...");
    NSString* llmPath = [[NSBundle mainBundle] resourcePath];
    // -cString is deprecated and encoding-lossy; UTF8String is the supported
    // conversion. It yields NULL when llmPath is nil, hence the check.
    const char* modelPath = llmPath.UTF8String;
    if (!modelPath) {
      @throw [NSException exceptionWithName:@"ModelCreationError"
                                     reason:@"Failed to locate model resources."
                                   userInfo:nil];
    }
    model = OgaModel::Create(modelPath);  // throws exception
    if (!model) {
      @throw [NSException exceptionWithName:@"ModelCreationError"
                                     reason:@"Failed to create model."
                                   userInfo:nil];
    }
  }

  if (!tokenizer) {
    NSLog(@"Creating tokenizer...");
    tokenizer = OgaTokenizer::Create(*model);  // throws exception
    if (!tokenizer) {
      @throw [NSException exceptionWithName:@"TokenizerCreationError"
                                     reason:@"Failed to create tokenizer."
                                   userInfo:nil];
    }
  }
}

// Runs one full generation pass for the given question: encodes the prompt,
// streams every decoded token to SharedTokenUpdater, logs timing metrics and
// posts the stats/completion notifications. Failures propagate as exceptions
// (NSException raised here, or whatever the Oga* calls throw).
static void RunGeneration(NSString* input_user_question) {
  std::vector<long long> tokenTimes;  // per-token generation times (ms)

  NSLog(@"Starting token generation...");
  EnsureModelAndTokenizer();

  auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);

  // Construct the prompt
  NSString* promptString =
      [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
  const char* prompt = [promptString UTF8String];

  // Encode the prompt
  auto sequences = OgaSequences::Create();
  tokenizer->Encode(prompt, *sequences);
  const size_t promptTokensCount = sequences->SequenceCount(0);

  NSLog(@"Setting generator parameters...");
  auto params = OgaGeneratorParams::Create(*model);
  params->SetSearchOption("max_length", kMaxLength);
  params->SetInputSequences(*sequences);

  auto generator = OgaGenerator::Create(*model, *params);

  NSLog(@"Starting token generation loop...");
  const TimePoint startTime = Clock::now();
  // Stays equal to startTime when the loop never produces a token, so the
  // first-token duration is 0 instead of a value based on an unset time point.
  TimePoint firstTokenTime = startTime;

  while (!generator->IsDone()) {
    const TimePoint tokenStartTime = Clock::now();

    generator->ComputeLogits();
    generator->GenerateNextToken();

    if (tokenTimes.empty()) {
      firstTokenTime = Clock::now();
    }

    // Get the sequence data and decode the newest token
    const int32_t* seq = generator->GetSequenceData(0);
    const size_t seq_len = generator->GetSequenceCount(0);
    const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
    if (!decode_tokens) {
      @throw [NSException exceptionWithName:@"TokenDecodeError"
                                     reason:@"Token decoding failed."
                                   userInfo:nil];
    }

    // Record the token time before any string conversion or UI hand-off.
    tokenTimes.push_back(ElapsedMilliseconds(tokenStartTime, Clock::now()));

    // stringWithUTF8String: returns nil for bytes that are not valid UTF-8;
    // do not forward nil to the updater.
    NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
    if (decodedTokenString != nil) {
      [SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
    }
  }

  const TimePoint endTime = Clock::now();
  NSLog(@"Per-token generation times: %@", [GenAIGenerator formatTokenTimes:tokenTimes]);

  // Calculate metrics. Durations are in milliseconds, rates in tokens/second.
  const long long totalDuration = ElapsedMilliseconds(startTime, endTime);
  const long long firstTokenDuration = ElapsedMilliseconds(startTime, firstTokenTime);
  const long long generationDuration = totalDuration - firstTokenDuration;
  const size_t generatedTokensCount = tokenTimes.size();

  // Each rate falls back to 0 when its denominator is 0 (or no token exists),
  // which avoids inf/NaN results and unsigned wrap-around of "count - 1".
  double promptProcRate = 0.0;
  if (firstTokenDuration > 0) {
    promptProcRate = (double)promptTokensCount * 1000.0 / (double)firstTokenDuration;
  }
  double tokenGenRate = 0.0;
  if (generatedTokensCount > 1 && generationDuration > 0) {
    // The first token is excluded: its time is counted as prompt processing.
    tokenGenRate = (double)(generatedTokensCount - 1) * 1000.0 / (double)generationDuration;
  }

  NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %zu",
        totalDuration, firstTokenDuration, generatedTokensCount);
  NSLog(@"Prompt tokens: %zu, Prompt Processing Time: %f tokens/s", promptTokensCount, promptProcRate);
  NSLog(@"Generated tokens: %zu, Token Generation Rate: %f tokens/s", generatedTokensCount, tokenGenRate);

  NSDictionary* stats = @{
    @"tokenGenRate" : @(tokenGenRate),
    @"promptProcRate" : @(promptProcRate)
  };
  // notify main thread that token generation is complete
  dispatch_async(dispatch_get_main_queue(), ^{
    [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats"
                                                        object:nil
                                                      userInfo:stats];
    [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted"
                                                        object:nil];
  });

  NSLog(@"Token generation completed.");
}

/// Generates a response for the given user question.
///
/// Each decoded token is passed to SharedTokenUpdater as it is produced. On
/// success, "TokenGenerationStats" (userInfo keys "tokenGenRate" and
/// "promptProcRate", both in tokens/second) and "TokenGenerationCompleted" are
/// posted on the main queue. On failure, "TokenGenerationError" is posted on
/// the main queue with the message under the "error" key.
///
/// @param input_user_question The question text inserted into the chat prompt.
+ (void)generate:(nonnull NSString*)input_user_question {
  @try {
    // @catch (NSException*) does not catch C++ exceptions, so the errors thrown
    // by the Oga* C++ API need their own handler. NSExceptions raised inside
    // pass through this C++ try untouched and reach the @catch below.
    try {
      RunGeneration(input_user_question);
    } catch (const std::exception& e) {
      PostGenerationError([NSString stringWithUTF8String:e.what()]);
    }
  } @catch (NSException* e) {
    PostGenerationError(e.reason);
  }
}

/// Formats per-token times for logging.
///
/// @param tokenTimes Per-token durations in milliseconds.
/// @return A string such as "12 ms, 9 ms"; empty when there are no entries.
+ (NSString*)formatTokenTimes:(const std::vector<long long>&)tokenTimes {
  NSMutableArray<NSString*>* entries = [NSMutableArray arrayWithCapacity:tokenTimes.size()];
  for (long long milliseconds : tokenTimes) {
    [entries addObject:[NSString stringWithFormat:@"%lld ms", milliseconds]];
  }
  // Joining avoids the dangling ", " left by appending a separator per entry.
  return [entries componentsJoinedByString:@", "];
}

@end
0 commit comments