9
9
#import " LLaMARunner.h"
10
10
11
11
#import < ExecuTorch/ExecuTorchLog.h>
12
- #import < executorch/examples/models/llama/runner/runner.h>
12
+ #import < executorch/extension/llm/runner/text_llm_runner.h>
13
+ #import < executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
13
14
14
- using executorch::extension::llm::GenerationConfig;
15
- using executorch::extension::llm::TextLLMRunner;
16
- using executorch::runtime::Error;
15
+ using namespace executorch ::extension;
16
+ using namespace executorch ::runtime;
17
17
18
18
// Error domain attached to the NSError objects this file creates when the
// underlying runner reports a non-Ok status.
NSErrorDomain const LLaMARunnerErrorDomain = @" LLaMARunnerErrorDomain" ;
19
19
20
20
// Class extension: privately adopts ExecuTorchLogSink. The initializer
// registers each instance as a sink on ExecuTorchLog.sharedLog.
@interface LLaMARunner ()<ExecuTorchLogSink>
21
21
@end
22
22
23
23
@implementation LLaMARunner {
24
- std::unique_ptr<TextLLMRunner> _runner;
24
+ std::unique_ptr<llm:: TextLLMRunner> _runner;
25
25
}
26
26
27
27
/// Creates a runner for the model at `modelPath` using the tokenizer at
/// `tokenizerPath`.
///
/// The new instance is registered as a sink on the shared ExecuTorch log, and
/// the underlying text LLM runner is built from the model path plus a tokenizer
/// loaded with the default special-token set.
///
/// @param modelPath     Path to the model file; passed on as a UTF-8 C string.
/// @param tokenizerPath Path to the tokenizer file; passed on as a UTF-8 C string.
/// @return The initialized instance, or nil if `[super init]` fails.
- (instancetype)initWithModelPath:(NSString*)modelPath
                    tokenizerPath:(NSString*)tokenizerPath {
  self = [super init];
  if (self) {
    // NOTE(review): the matching removeSink: is not visible in this excerpt —
    // confirm it happens elsewhere (e.g. in dealloc).
    [ExecuTorchLog.sharedLog addSink:self];
    // NOTE(review): neither factory result is checked here; confirm whether
    // load_tokenizer / create_text_llm_runner can return null on bad paths.
    _runner = executorch::extension::llm::create_text_llm_runner(
        modelPath.UTF8String,
        executorch::extension::llm::load_tokenizer(
            tokenizerPath.UTF8String,
            example::get_special_tokens(example::Version::Default)));
  }
  return self;
}
/// Generates text for `prompt`, handing each produced token to `callback`.
///
/// @param prompt   Prompt text; passed to the runner as a UTF-8 C string.
/// @param seq_len  Forwarded as `GenerationConfig::seq_len` after narrowing to
///                 `int32_t`.
/// @param callback Optional block invoked once per token string. May be nil, in
///                 which case tokens are discarded.
/// @param error    Optional out-parameter. Written only on failure, with domain
///                 `LLaMARunnerErrorDomain` and the runner status as the code.
/// @return YES when the runner reports `Error::Ok`, otherwise NO.
- (BOOL)generate:(NSString*)prompt
       sequenceLength:(NSInteger)seq_len
    withTokenCallback:(nullable void (^)(NSString*))callback
                error:(NSError**)error {
  const auto status = _runner->generate(
      prompt.UTF8String,
      executorch::extension::llm::GenerationConfig{
          .seq_len = static_cast<int32_t>(seq_len)},
      [callback](const std::string& token) {
        // The parameter is declared nullable, and invoking a nil block
        // crashes, so guard the call.
        if (callback) {
          callback(@(token.c_str()));
        }
      });
  if (status != executorch::runtime::Error::Ok) {
    if (error) {
      *error = [NSError errorWithDomain:LLaMARunnerErrorDomain
                                   code:(NSInteger)status
                               userInfo:nil];
    }
    // Report failure even when the caller supplied no error out-parameter.
    return NO;
  }
  return YES;
}
@@ -95,15 +99,16 @@ - (void)logWithLevel:(ExecuTorchLogLevel)level
95
99
NSUInteger seconds = totalSeconds % 60 ;
96
100
NSUInteger microseconds = (timestamp - totalSeconds) * 1000000 ;
97
101
NSLog (
98
- @" %c %02lu :%02lu :%02lu .%06lu executorch:%s :%zu ] %s " ,
99
- (char )level,
100
- hours,
101
- minutes,
102
- seconds,
103
- microseconds,
104
- filename.UTF8String,
105
- line,
106
- message.UTF8String);
102
+ @" %c %02lu :%02lu :%02lu .%06lu executorch:%s :%zu ] %s " ,
103
+ (char )level,
104
+ hours,
105
+ minutes,
106
+ seconds,
107
+ microseconds,
108
+ filename.UTF8String,
109
+ line,
110
+ message.UTF8String
111
+ );
107
112
}
108
113
109
114
@end
0 commit comments