|
5 | 5 | #include "LocalLLM-Swift.h" |
6 | 6 | #include "ort_genai.h" |
7 | 7 | #include "ort_genai_c.h" |
8 | | - |
| 8 | +#include <chrono> |
9 | 9 |
|
10 | 10 | @implementation GenAIGenerator |
11 | 11 |
|
12 | | -+ (void)generate:(nonnull NSString*)input_user_question { |
13 | | - NSString* llmPath = [[NSBundle mainBundle] resourcePath]; |
14 | | - const char* modelPath = llmPath.cString; |
15 | | - |
16 | | - auto model = OgaModel::Create(modelPath); |
17 | | - auto tokenizer = OgaTokenizer::Create(*model); |
18 | | - |
19 | | - NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question]; |
20 | | - const char* prompt = [promptString UTF8String]; |
21 | | - |
22 | | - auto sequences = OgaSequences::Create(); |
23 | | - tokenizer->Encode(prompt, *sequences); |
// Monotonic clock used for measuring token-generation latency.
// steady_clock is guaranteed monotonic; high_resolution_clock may alias the
// wall clock on some platforms, which can yield negative/garbage durations
// if the system time is adjusted mid-run.
typedef std::chrono::steady_clock Clock;
typedef std::chrono::time_point<Clock> TimePoint;
24 | 14 |
|
25 | | - auto params = OgaGeneratorParams::Create(*model); |
26 | | - params->SetSearchOption("max_length", 200); |
27 | | - params->SetInputSequences(*sequences); |
28 | | - |
29 | | - // Streaming Output to generate token by token |
30 | | - auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer); |
31 | | - |
32 | | - auto generator = OgaGenerator::Create(*model, *params); |
// Runs one end-to-end generation pass for the given user question: loads the
// model from the app bundle, encodes a chat-templated prompt, streams decoded
// tokens to SharedTokenUpdater, and posts completion/stat notifications on
// the main queue.
// NOTE(review): the model and tokenizer are re-created on every call —
// presumably acceptable for a demo; cache them if this runs repeatedly.
// NOTE(review): this does blocking work; callers are expected to invoke it
// off the main thread — confirm at the call site.
+ (void)generate:(nonnull NSString*)input_user_question {
    NSLog(@"Starting token generation...");

    NSString* llmPath = [[NSBundle mainBundle] resourcePath];
    // FIX: -cString is deprecated and uses the default C-string encoding;
    // fileSystemRepresentation is the correct form for a path passed to a
    // C/C++ API.
    const char* modelPath = llmPath.fileSystemRepresentation;

    NSLog(@"Creating model ...");
    auto model = OgaModel::Create(modelPath);
    if (!model) {
        NSLog(@"Failed to create model.");
        return;
    }

    NSLog(@"Creating tokenizer...");
    auto tokenizer = OgaTokenizer::Create(*model);
    if (!tokenizer) {
        NSLog(@"Failed to create tokenizer.");
        return;
    }

    // FIX: null-check the stream like the model/tokenizer above, instead of
    // dereferencing an unchecked Create result in the loop.
    auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
    if (!tokenizer_stream) {
        NSLog(@"Failed to create tokenizer stream.");
        return;
    }

    // Phi-3 style chat template wrapped around the user's question.
    NSString* promptString =
        [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
    const char* prompt = [promptString UTF8String];

    NSLog(@"Encoding prompt...");
    auto sequences = OgaSequences::Create();
    tokenizer->Encode(prompt, *sequences);

    NSLog(@"Setting generator parameters...");
    auto params = OgaGeneratorParams::Create(*model);
    params->SetSearchOption("max_length", 200);
    params->SetInputSequences(*sequences);

    NSLog(@"Creating generator...");
    auto generator = OgaGenerator::Create(*model, *params);

    bool isFirstToken = true;
    TimePoint startTime = Clock::now();
    TimePoint firstTokenTime;  // only valid once isFirstToken flips to false
    int tokenCount = 0;

    NSLog(@"Starting token generation loop...");
    while (!generator->IsDone()) {
        generator->ComputeLogits();
        generator->GenerateNextToken();

        if (isFirstToken) {
            NSLog(@"First token generated.");
            firstTokenTime = Clock::now();
            isFirstToken = false;
        }

        // Decode only the newest token of the running sequence.
        const int32_t* seq = generator->GetSequenceData(0);
        size_t seq_len = generator->GetSequenceCount(0);
        const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
        if (!decode_tokens) {
            NSLog(@"Token decoding failed.");
            break;
        }

        NSLog(@"Decoded token: %s", decode_tokens);
        tokenCount++;

        // Forward the token to the UI layer.
        NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
        [SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
    }

    TimePoint endTime = Clock::now();
    long long totalDuration =
        std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
    // FIX: firstTokenTime is assigned only after the first token; if the loop
    // produced nothing, reading it would yield a garbage duration.
    long long firstTokenDuration =
        isFirstToken
            ? 0
            : std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime)
                  .count();

    NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, "
          @"Total tokens: %d",
          totalDuration, firstTokenDuration, tokenCount);

    NSDictionary* stats = @{
        @"totalTime" : @(totalDuration),
        @"firstTokenTime" : @(firstTokenDuration),
        @"tokenCount" : @(tokenCount)
    };

    // Notify the UI on the main thread that token generation is complete.
    dispatch_async(dispatch_get_main_queue(), ^{
        [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted"
                                                            object:nil];
        [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats"
                                                            object:nil
                                                          userInfo:stats];
    });
}
| 112 | + |
49 | 113 | @end |
0 commit comments