Skip to content

Commit 030ebaf

Browse files
authored
feat: support chat format (#72)
* feat: add isChatTemplateSupported in model info * feat(ts): add formatChat util * feat(ts): add getFormattedChat native method * feat(ts): completion: add messages * feat(example): use messages * feat(docs): update
1 parent 2f70192 commit 030ebaf

File tree

16 files changed

+422
-172
lines changed

16 files changed

+422
-172
lines changed

README.md

Lines changed: 75 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ You can search HuggingFace for available models (Keyword: [`GGUF`](https://huggi
3535
To create a GGUF model manually, for example for Llama 2:
3636

3737
Download the Llama 2 model
38+
3839
1. Request access from [here](https://ai.meta.com/llama)
3940
2. Download the model from HuggingFace [here](https://huggingface.co/meta-llama/Llama-2-7b-chat) (`Llama-2-7b-chat`)
4041

4142
Convert the model to ggml format
43+
4244
```bash
4345
# Start with submodule in this repo (or you can clone the repo https://github.com/ggerganov/llama.cpp.git)
4446
yarn && yarn bootstrap
@@ -80,26 +82,53 @@ const context = await initLlama({
8082
// embedding: true, // use embedding
8183
})
8284

83-
// Do completion
84-
const { text, timings } = await context.completion(
85+
const stopWords = ['</s>', '<|end|>', '<|eot_id|>', '<|end_of_text|>', '<|im_end|>', '<|EOT|>', '<|END_OF_TURN_TOKEN|>', '<|end_of_turn|>', '<|endoftext|>']
86+
87+
// Do chat completion
88+
const msgResult = await context.completion(
89+
{
90+
messages: [
91+
{
92+
role: 'system',
93+
content: 'This is a conversation between user and assistant, a friendly chatbot.',
94+
},
95+
{
96+
role: 'user',
97+
content: 'Hello!',
98+
},
99+
],
100+
n_predict: 100,
101+
stop: stopWords,
102+
// ...other params
103+
},
104+
(data) => {
105+
// This is a partial completion callback
106+
const { token } = data
107+
},
108+
)
109+
console.log('Result:', msgResult.text)
110+
console.log('Timings:', msgResult.timings)
111+
112+
// Or do text completion
113+
const textResult = await context.completion(
85114
{
86115
prompt: 'This is a conversation between user and llama, a friendly chatbot. respond in simple markdown.\n\nUser: Hello!\nLlama:',
87116
n_predict: 100,
88-
stop: ['</s>', 'Llama:', 'User:'],
89-
// n_threads: 4,
117+
stop: [...stopWords, 'Llama:', 'User:'],
118+
// ...other params
90119
},
91120
(data) => {
92121
// This is a partial completion callback
93122
const { token } = data
94123
},
95124
)
96-
console.log('Result:', text)
97-
console.log('Timings:', timings)
125+
console.log('Result:', textResult.text)
126+
console.log('Timings:', textResult.timings)
98127
```
99128

100129
The binding’s design is inspired by the [server.cpp](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) example in llama.cpp, so you can map its API to LlamaContext:
101130

102-
- `/completion`: `context.completion(params, partialCompletionCallback)`
131+
- `/completion` and `/chat/completions`: `context.completion(params, partialCompletionCallback)`
103132
- `/tokenize`: `context.tokenize(content)`
104133
- `/detokenize`: `context.detokenize(tokens)`
105134
- `/embedding`: `context.embedding(content)`
@@ -114,6 +143,7 @@ Please visit the [Documentation](docs/API) for more details.
114143
You can also visit the [example](example) to see how to use it.
115144

116145
Run the example:
146+
117147
```bash
118148
yarn && yarn bootstrap
119149

@@ -146,7 +176,9 @@ You can see [GBNF Guide](https://github.com/ggerganov/llama.cpp/tree/master/gram
146176
```js
147177
import { initLlama, convertJsonSchemaToGrammar } from 'llama.rn'
148178

149-
const schema = { /* JSON Schema, see below */ }
179+
const schema = {
180+
/* JSON Schema, see below */
181+
}
150182

151183
const context = await initLlama({
152184
model: 'file://<path to gguf model>',
@@ -157,7 +189,7 @@ const context = await initLlama({
157189
grammar: convertJsonSchemaToGrammar({
158190
schema,
159191
propOrder: { function: 0, arguments: 1 },
160-
})
192+
}),
161193
})
162194

163195
const { text } = await context.completion({
@@ -175,80 +207,81 @@ console.log('Result:', text)
175207
{
176208
oneOf: [
177209
{
178-
type: "object",
179-
name: "get_current_weather",
180-
description: "Get the current weather in a given location",
210+
type: 'object',
211+
name: 'get_current_weather',
212+
description: 'Get the current weather in a given location',
181213
properties: {
182214
function: {
183-
const: "get_current_weather",
215+
const: 'get_current_weather',
184216
},
185217
arguments: {
186-
type: "object",
218+
type: 'object',
187219
properties: {
188220
location: {
189-
type: "string",
190-
description: "The city and state, e.g. San Francisco, CA",
221+
type: 'string',
222+
description: 'The city and state, e.g. San Francisco, CA',
191223
},
192224
unit: {
193-
type: "string",
194-
enum: ["celsius", "fahrenheit"],
225+
type: 'string',
226+
enum: ['celsius', 'fahrenheit'],
195227
},
196228
},
197-
required: ["location"],
229+
required: ['location'],
198230
},
199231
},
200232
},
201233
{
202-
type: "object",
203-
name: "create_event",
204-
description: "Create a calendar event",
234+
type: 'object',
235+
name: 'create_event',
236+
description: 'Create a calendar event',
205237
properties: {
206238
function: {
207-
const: "create_event",
239+
const: 'create_event',
208240
},
209241
arguments: {
210-
type: "object",
242+
type: 'object',
211243
properties: {
212244
title: {
213-
type: "string",
214-
description: "The title of the event",
245+
type: 'string',
246+
description: 'The title of the event',
215247
},
216248
date: {
217-
type: "string",
218-
description: "The date of the event",
249+
type: 'string',
250+
description: 'The date of the event',
219251
},
220252
time: {
221-
type: "string",
222-
description: "The time of the event",
253+
type: 'string',
254+
description: 'The time of the event',
223255
},
224256
},
225-
required: ["title", "date", "time"],
257+
required: ['title', 'date', 'time'],
226258
},
227259
},
228260
},
229261
{
230-
type: "object",
231-
name: "image_search",
232-
description: "Search for an image",
262+
type: 'object',
263+
name: 'image_search',
264+
description: 'Search for an image',
233265
properties: {
234266
function: {
235-
const: "image_search",
267+
const: 'image_search',
236268
},
237269
arguments: {
238-
type: "object",
270+
type: 'object',
239271
properties: {
240272
query: {
241-
type: "string",
242-
description: "The search query",
273+
type: 'string',
274+
description: 'The search query',
243275
},
244276
},
245-
required: ["query"],
277+
required: ['query'],
246278
},
247279
},
248280
},
249281
],
250282
}
251283
```
284+
252285
</details>
253286

254287
<details>
@@ -272,6 +305,7 @@ string ::= "\"" (
272305
2 ::= "{" space "\"function\"" space ":" space 2-function "," space "\"arguments\"" space ":" space 2-arguments "}" space
273306
root ::= 0 | 1 | 2
274307
```
308+
275309
</details>
276310

277311
## Mock `llama.rn`
@@ -285,12 +319,14 @@ jest.mock('llama.rn', () => require('llama.rn/jest/mock'))
285319
## NOTE
286320

287321
iOS:
322+
288323
- The [Extended Virtual Addressing](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_extended-virtual-addressing) capability is recommended to enable on iOS project.
289324
- Metal:
290325
- We have found through testing that some devices are not able to use Metal (`params.n_gpu_layers > 0`) because llama.cpp uses SIMD-scoped operations. You can check whether your device is supported in the [Metal feature set tables](https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf); an Apple7 GPU is the minimum requirement.
291326
- It's also not supported in the iOS simulator due to [this limitation](https://developer.apple.com/documentation/metal/developing_metal_apps_that_run_in_simulator#3241609): we use more than 14 constant buffers.
292327

293328
Android:
329+
294330
- Currently only the arm64-v8a / x86_64 platforms are supported, which means you can't initialize a context on other platforms. The 64-bit platforms are recommended because they can allocate more memory for the model.
295331
- No integrated any GPU backend yet.
296332

android/src/main/java/com/rnllama/LlamaContext.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,14 @@ public WritableMap getModelDetails() {
7575
return modelDetails;
7676
}
7777

78+
public String getFormattedChat(ReadableArray messages, String chatTemplate) {
79+
ReadableMap[] msgs = new ReadableMap[messages.size()];
80+
for (int i = 0; i < messages.size(); i++) {
81+
msgs[i] = messages.getMap(i);
82+
}
83+
return getFormattedChat(this.context, msgs, chatTemplate == null ? "" : chatTemplate);
84+
}
85+
7886
private void emitPartialCompletion(WritableMap tokenResult) {
7987
WritableMap event = Arguments.createMap();
8088
event.putInt("contextId", LlamaContext.this.id);
@@ -316,6 +324,11 @@ protected static native long initContext(
316324
protected static native WritableMap loadModelDetails(
317325
long contextPtr
318326
);
327+
protected static native String getFormattedChat(
328+
long contextPtr,
329+
ReadableMap[] messages,
330+
String chatTemplate
331+
);
319332
protected static native WritableMap loadSession(
320333
long contextPtr,
321334
String path

android/src/main/java/com/rnllama/RNLlama.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,38 @@ protected void onPostExecute(WritableMap result) {
8080
tasks.put(task, "initContext");
8181
}
8282

83+
public void getFormattedChat(double id, final ReadableArray messages, final String chatTemplate, Promise promise) {
84+
final int contextId = (int) id;
85+
AsyncTask task = new AsyncTask<Void, Void, String>() {
86+
private Exception exception;
87+
88+
@Override
89+
protected String doInBackground(Void... voids) {
90+
try {
91+
LlamaContext context = contexts.get(contextId);
92+
if (context == null) {
93+
throw new Exception("Context not found");
94+
}
95+
return context.getFormattedChat(messages, chatTemplate);
96+
} catch (Exception e) {
97+
exception = e;
98+
return null;
99+
}
100+
}
101+
102+
@Override
103+
protected void onPostExecute(String result) {
104+
if (exception != null) {
105+
promise.reject(exception);
106+
return;
107+
}
108+
promise.resolve(result);
109+
tasks.remove(this);
110+
}
111+
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
112+
tasks.put(task, "getFormattedChat-" + contextId);
113+
}
114+
83115
public void loadSession(double id, final String path, Promise promise) {
84116
final int contextId = (int) id;
85117
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {

android/src/main/jni.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,46 @@ Java_com_rnllama_LlamaContext_loadModelDetails(
224224
return reinterpret_cast<jobject>(result);
225225
}
226226

227+
JNIEXPORT jobject JNICALL
228+
Java_com_rnllama_LlamaContext_getFormattedChat(
229+
JNIEnv *env,
230+
jobject thiz,
231+
jlong context_ptr,
232+
jobjectArray messages,
233+
jstring chat_template
234+
) {
235+
UNUSED(thiz);
236+
auto llama = context_map[(long) context_ptr];
237+
238+
std::vector<llama_chat_msg> chat;
239+
240+
int messages_len = env->GetArrayLength(messages);
241+
for (int i = 0; i < messages_len; i++) {
242+
jobject msg = env->GetObjectArrayElement(messages, i);
243+
jclass msgClass = env->GetObjectClass(msg);
244+
245+
jmethodID getRoleMethod = env->GetMethodID(msgClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
246+
jstring roleKey = env->NewStringUTF("role");
247+
jstring contentKey = env->NewStringUTF("content");
248+
249+
jstring role_str = (jstring) env->CallObjectMethod(msg, getRoleMethod, roleKey);
250+
jstring content_str = (jstring) env->CallObjectMethod(msg, getRoleMethod, contentKey);
251+
252+
const char *role = env->GetStringUTFChars(role_str, nullptr);
253+
const char *content = env->GetStringUTFChars(content_str, nullptr);
254+
255+
chat.push_back({ role, content });
256+
257+
env->ReleaseStringUTFChars(role_str, role);
258+
env->ReleaseStringUTFChars(content_str, content);
259+
}
260+
261+
const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
262+
std::string formatted_chat = llama_chat_apply_template(llama->model, tmpl_chars, chat, true);
263+
264+
return env->NewStringUTF(formatted_chat.c_str());
265+
}
266+
227267
JNIEXPORT jobject JNICALL
228268
Java_com_rnllama_LlamaContext_loadSession(
229269
JNIEnv *env,

android/src/newarch/java/com/rnllama/RNLlamaModule.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ public void initContext(final ReadableMap params, final Promise promise) {
4242
rnllama.initContext(params, promise);
4343
}
4444

45+
@ReactMethod
46+
public void getFormattedChat(double id, ReadableArray messages, String chatTemplate, Promise promise) {
47+
rnllama.getFormattedChat(id, messages, chatTemplate, promise);
48+
}
49+
4550
@ReactMethod
4651
public void loadSession(double id, String path, Promise promise) {
4752
rnllama.loadSession(id, path, promise);

android/src/oldarch/java/com/rnllama/RNLlamaModule.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ public void initContext(final ReadableMap params, final Promise promise) {
4343
rnllama.initContext(params, promise);
4444
}
4545

46+
@ReactMethod
47+
public void getFormattedChat(double id, ReadableArray messages, String chatTemplate, Promise promise) {
48+
rnllama.getFormattedChat(id, messages, chatTemplate, promise);
49+
}
50+
4651
@ReactMethod
4752
public void loadSession(double id, String path, Promise promise) {
4853
rnllama.loadSession(id, path, promise);

0 commit comments

Comments
 (0)