Skip to content

Commit c495df8

Browse files
committed
feat: add isChatTemplateSupported in model info
1 parent 6d9b5d9 commit c495df8

File tree

10 files changed

+89
-70
lines changed

10 files changed

+89
-70
lines changed

android/src/main/jni.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@ static inline void putDouble(JNIEnv *env, jobject map, const char *key, double v
6262
env->CallVoidMethod(map, putDoubleMethod, jKey, value);
6363
}
6464

65+
// Method to put boolean into WritableMap
66+
static inline void putBoolean(JNIEnv *env, jobject map, const char *key, bool value) {
67+
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
68+
jmethodID putBooleanMethod = env->GetMethodID(mapClass, "putBoolean", "(Ljava/lang/String;Z)V");
69+
70+
jstring jKey = env->NewStringUTF(key);
71+
72+
env->CallVoidMethod(map, putBooleanMethod, jKey, value);
73+
}
74+
6575
// Method to put WriteableMap into WritableMap
6676
static inline void putMap(JNIEnv *env, jobject map, const char *key, jobject value) {
6777
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
@@ -208,6 +218,7 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
208218
putString(env, result, "desc", desc);
209219
putDouble(env, result, "size", llama_model_size(llama->model));
210220
putDouble(env, result, "nParams", llama_model_n_params(llama->model));
221+
putBoolean(env, result, "isChatTemplateSupported", llama->validateModelChatTemplate());
211222
putMap(env, result, "metadata", meta);
212223

213224
return reinterpret_cast<jobject>(result);

cpp/rn-llama.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,14 @@ struct llama_rn_context
229229
return true;
230230
}
231231

232+
bool validateModelChatTemplate() const {
233+
llama_chat_message chat[] = {{"user", "test"}};
234+
235+
const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
236+
237+
return res > 0;
238+
}
239+
232240
void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
233241
const int n_left = n_ctx - params.n_keep;
234242
const int n_block_size = n_left / 2;

example/ios/.xcode.env.local

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1722061680584-0.19771203690487615/node
1+
export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1722073570606-0.6759511337227031/node

example/ios/Podfile.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ PODS:
88
- hermes-engine/Pre-built (= 0.72.3)
99
- hermes-engine/Pre-built (0.72.3)
1010
- libevent (2.1.12)
11-
- llama-rn (0.3.4):
11+
- llama-rn (0.3.5):
1212
- RCT-Folly
1313
- RCTRequired
1414
- RCTTypeSafety
@@ -1261,7 +1261,7 @@ SPEC CHECKSUMS:
12611261
glog: 04b94705f318337d7ead9e6d17c019bd9b1f6b1b
12621262
hermes-engine: 10fbd3f62405c41ea07e71973ea61e1878d07322
12631263
libevent: 4049cae6c81cdb3654a443be001fb9bdceff7913
1264-
llama-rn: 1facf2ce116e23e89a526e30439f151eb03f460d
1264+
llama-rn: 1ab4e3bae3136c83dcc2bdcea1ddf0c861335d78
12651265
RCT-Folly: 424b8c9a7a0b9ab2886ffe9c3b041ef628fd4fb1
12661266
RCTRequired: a2faf4bad4e438ca37b2040cb8f7799baa065c18
12671267
RCTTypeSafety: cb09f3e4747b6d18331a15eb05271de7441ca0b3

example/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"dependencies": {
1313
"@flyerhq/react-native-chat-ui": "^1.4.3",
1414
"@react-native-clipboard/clipboard": "^1.13.1",
15+
"json5": "^2.2.3",
1516
"react": "18.2.0",
1617
"react-native": "0.72.3",
1718
"react-native-blob-util": "^0.19.1",

example/src/App.tsx

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import DocumentPicker from 'react-native-document-picker'
66
import type { DocumentPickerResponse } from 'react-native-document-picker'
77
import { Chat, darkTheme } from '@flyerhq/react-native-chat-ui'
88
import type { MessageType } from '@flyerhq/react-native-chat-ui'
9+
import json5 from 'json5'
910
import ReactNativeBlobUtil from 'react-native-blob-util'
1011
// eslint-disable-next-line import/no-unresolved
1112
import { initLlama, LlamaContext, convertJsonSchemaToGrammar } from 'llama.rn'
@@ -73,7 +74,7 @@ export default function App() {
7374
}
7475
}
7576

76-
const addSystemMessage = (text: string, metadata = {} ) => {
77+
const addSystemMessage = (text: string, metadata = {}) => {
7778
const textMessage: MessageType.Text = {
7879
author: system,
7980
createdAt: Date.now(),
@@ -119,7 +120,7 @@ export default function App() {
119120
'- /release: release the context\n' +
120121
'- /stop: stop the current completion\n' +
121122
'- /reset: reset the conversation',
122-
'- /save-session: save the session tokens\n' +
123+
'- /save-session: save the session tokens\n' +
123124
'- /load-session: load the session tokens',
124125
)
125126
})
@@ -166,12 +167,18 @@ export default function App() {
166167
const handleSendPress = async (message: MessageType.PartialText) => {
167168
if (context) {
168169
switch (message.text) {
170+
case '/info':
171+
addSystemMessage(
172+
`// Model Info\n${json5.stringify(context.model, null, 2)}`,
173+
{ copyable: true },
174+
)
175+
return
169176
case '/bench':
170177
addSystemMessage('Heating up the model...')
171178
const t0 = Date.now()
172179
await context.bench(8, 4, 1, 1)
173180
const tHeat = Date.now() - t0
174-
if (tHeat > 1E4) {
181+
if (tHeat > 1e4) {
175182
addSystemMessage('Heat up time is too long, please try again.')
176183
return
177184
}
@@ -186,15 +193,21 @@ export default function App() {
186193
ppStd,
187194
tgAvg,
188195
tgStd,
189-
} = await context.bench(512, 128, 1, 3)
196+
} = await context.bench(512, 128, 1, 3)
190197

191-
const size = `${(modelSize / 1024.0 / 1024.0 / 1024.0).toFixed(2)} GiB`
198+
const size = `${(modelSize / 1024.0 / 1024.0 / 1024.0).toFixed(
199+
2,
200+
)} GiB`
192201
const nParams = `${(modelNParams / 1e9).toFixed(2)}B`
193202
const md =
194203
'| model | size | params | test | t/s |\n' +
195204
'| --- | --- | --- | --- | --- |\n' +
196-
`| ${modelDesc} | ${size} | ${nParams} | pp 512 | ${ppAvg.toFixed(2)} ± ${ppStd.toFixed(2)} |\n` +
197-
`| ${modelDesc} | ${size} | ${nParams} | tg 128 | ${tgAvg.toFixed(2)} ± ${tgStd.toFixed(2)}`
205+
`| ${modelDesc} | ${size} | ${nParams} | pp 512 | ${ppAvg.toFixed(
206+
2,
207+
)} ± ${ppStd.toFixed(2)} |\n` +
208+
`| ${modelDesc} | ${size} | ${nParams} | tg 128 | ${tgAvg.toFixed(
209+
2,
210+
)} ± ${tgStd.toFixed(2)}`
198211
addSystemMessage(md, { copyable: true })
199212
return
200213
case '/release':
@@ -208,22 +221,30 @@ export default function App() {
208221
addSystemMessage('Conversation reset!')
209222
return
210223
case '/save-session':
211-
context.saveSession(`${dirs.DocumentDir}/llama-session.bin`).then(tokensSaved => {
212-
console.log('Session tokens saved:', tokensSaved)
213-
addSystemMessage(`Session saved! ${tokensSaved} tokens saved.`)
214-
}).catch(e => {
215-
console.log('Session save failed:', e)
216-
addSystemMessage(`Session save failed: ${e.message}`)
217-
})
224+
context
225+
.saveSession(`${dirs.DocumentDir}/llama-session.bin`)
226+
.then((tokensSaved) => {
227+
console.log('Session tokens saved:', tokensSaved)
228+
addSystemMessage(`Session saved! ${tokensSaved} tokens saved.`)
229+
})
230+
.catch((e) => {
231+
console.log('Session save failed:', e)
232+
addSystemMessage(`Session save failed: ${e.message}`)
233+
})
218234
return
219235
case '/load-session':
220-
context.loadSession(`${dirs.DocumentDir}/llama-session.bin`).then(details => {
221-
console.log('Session loaded:', details)
222-
addSystemMessage(`Session loaded! ${details.tokens_loaded} tokens loaded.`)
223-
}).catch(e => {
224-
console.log('Session load failed:', e)
225-
addSystemMessage(`Session load failed: ${e.message}`)
226-
})
236+
context
237+
.loadSession(`${dirs.DocumentDir}/llama-session.bin`)
238+
.then((details) => {
239+
console.log('Session loaded:', details)
240+
addSystemMessage(
241+
`Session loaded! ${details.tokens_loaded} tokens loaded.`,
242+
)
243+
})
244+
.catch((e) => {
245+
console.log('Session load failed:', e)
246+
addSystemMessage(`Session load failed: ${e.message}`)
247+
})
227248
return
228249
}
229250
}

example/yarn.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3523,7 +3523,7 @@ json-stable-stringify@^1.0.2:
35233523
dependencies:
35243524
jsonify "^0.0.1"
35253525

3526-
json5@^2.1.1, json5@^2.2.2:
3526+
json5@^2.1.1, json5@^2.2.2, json5@^2.2.3:
35273527
version "2.2.3"
35283528
resolved "https://registry.yarnpkg.com/json5/-/json5-2.2.3.tgz#78cd6f1a19bdc12b73db5ad0c61efd66c1e29283"
35293529
integrity sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==

ios/RNLlama.mm

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,7 @@ @implementation RNLlama
5353
@"contextId": contextIdNumber,
5454
@"gpu": @([context isMetalEnabled]),
5555
@"reasonNoGPU": [context reasonNoMetal],
56-
@"model": @{
57-
@"desc": [context modelDesc],
58-
@"size": @([context modelSize]),
59-
@"nParams": @([context modelNParams]),
60-
@"metadata": [context metadata],
61-
}
56+
@"model": [context modelInfo],
6257
});
6358
}
6459

ios/RNLlamaContext.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,14 @@
88
bool is_metal_enabled;
99
NSString * reason_no_metal;
1010
bool is_model_loaded;
11-
NSString * model_desc;
12-
uint64_t model_size;
13-
uint64_t model_n_params;
14-
NSDictionary * metadata;
1511

1612
rnllama::llama_rn_context * llama;
1713
}
1814

1915
+ (instancetype)initWithParams:(NSDictionary *)params;
2016
- (bool)isMetalEnabled;
2117
- (NSString *)reasonNoMetal;
22-
- (NSDictionary *)metadata;
23-
- (NSString *)modelDesc;
24-
- (uint64_t)modelSize;
25-
- (uint64_t)modelNParams;
18+
- (NSDictionary *)modelInfo;
2619
- (bool)isModelLoaded;
2720
- (bool)isPredicting;
2821
- (NSDictionary *)completion:(NSDictionary *)params onToken:(void (^)(NSMutableDictionary *tokenResult))onToken;

ios/RNLlamaContext.mm

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -82,26 +82,6 @@ + (instancetype)initWithParams:(NSDictionary *)params {
8282
context->is_metal_enabled = isMetalEnabled;
8383
context->reason_no_metal = reasonNoMetal;
8484

85-
int count = llama_model_meta_count(context->llama->model);
86-
NSDictionary *meta = [[NSMutableDictionary alloc] init];
87-
for (int i = 0; i < count; i++) {
88-
char key[256];
89-
llama_model_meta_key_by_index(context->llama->model, i, key, sizeof(key));
90-
char val[256];
91-
llama_model_meta_val_str_by_index(context->llama->model, i, val, sizeof(val));
92-
93-
NSString *keyStr = [NSString stringWithUTF8String:key];
94-
NSString *valStr = [NSString stringWithUTF8String:val];
95-
[meta setValue:valStr forKey:keyStr];
96-
}
97-
context->metadata = meta;
98-
99-
char desc[1024];
100-
llama_model_desc(context->llama->model, desc, sizeof(desc));
101-
context->model_desc = [NSString stringWithUTF8String:desc];
102-
context->model_size = llama_model_size(context->llama->model);
103-
context->model_n_params = llama_model_n_params(context->llama->model);
104-
10585
return context;
10686
}
10787

@@ -113,20 +93,30 @@ - (NSString *)reasonNoMetal {
11393
return reason_no_metal;
11494
}
11595

116-
- (NSDictionary *)metadata {
117-
return metadata;
118-
}
96+
- (NSDictionary *)modelInfo {
97+
char desc[1024];
98+
llama_model_desc(llama->model, desc, sizeof(desc));
11999

120-
- (NSString *)modelDesc {
121-
return model_desc;
122-
}
100+
int count = llama_model_meta_count(llama->model);
101+
NSDictionary *meta = [[NSMutableDictionary alloc] init];
102+
for (int i = 0; i < count; i++) {
103+
char key[256];
104+
llama_model_meta_key_by_index(llama->model, i, key, sizeof(key));
105+
char val[256];
106+
llama_model_meta_val_str_by_index(llama->model, i, val, sizeof(val));
123107

124-
- (uint64_t)modelSize {
125-
return model_size;
126-
}
108+
NSString *keyStr = [NSString stringWithUTF8String:key];
109+
NSString *valStr = [NSString stringWithUTF8String:val];
110+
[meta setValue:valStr forKey:keyStr];
111+
}
127112

128-
- (uint64_t)modelNParams {
129-
return model_n_params;
113+
return @{
114+
@"desc": [NSString stringWithUTF8String:desc],
115+
@"size": @(llama_model_size(llama->model)),
116+
@"nParams": @(llama_model_n_params(llama->model)),
117+
@"isChatTemplateSupported": @(llama->validateModelChatTemplate()),
118+
@"metadata": meta
119+
};
130120
}
131121

132122
- (bool)isModelLoaded {

0 commit comments

Comments
 (0)