
Commit 1eaace8

aniljava and Anil Pathak authored
Fix low_level_api_chat_cpp example to match current API (#1086)
* Fix low_level_api_chat_cpp to match current API
* Using None instead of an empty string so that the default prompt template can be used if no prompt is provided

Co-authored-by: Anil Pathak <[email protected]>
1 parent c689ccc commit 1eaace8

File tree

2 files changed, +37 -15 lines

examples/low_level_api/common.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ def gpt_params_parse(argv = None):
     parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")
 
     parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
-    parser.add_argument("-p", "--prompt", type=str, default="", help="initial prompt",dest="prompt")
+    parser.add_argument("-p", "--prompt", type=str, default=None, help="initial prompt",dest="prompt")
     parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
     parser.add_argument("--session", type=str, default=None, help="file to cache model state in (may be large!)",dest="path_session")
     parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")

examples/low_level_api/low_level_api_chat_cpp.py

Lines changed: 36 additions & 14 deletions
@@ -62,7 +62,7 @@ def __init__(self, params: GptParams) -> None:
         self.multibyte_fix = []
 
         # model load
-        self.lparams = llama_cpp.llama_context_default_params()
+        self.lparams = llama_cpp.llama_model_default_params()
         self.lparams.n_ctx = self.params.n_ctx
         self.lparams.n_parts = self.params.n_parts
         self.lparams.seed = self.params.seed
@@ -72,7 +72,11 @@ def __init__(self, params: GptParams) -> None:
 
         self.model = llama_cpp.llama_load_model_from_file(
             self.params.model.encode("utf8"), self.lparams)
-        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.lparams)
+
+        # Context Params.
+        self.cparams = llama_cpp.llama_context_default_params()
+
+        self.ctx = llama_cpp.llama_new_context_with_model(self.model, self.cparams)
         if (not self.ctx):
             raise RuntimeError(f"error: failed to load model '{self.params.model}'")
 
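This is the core of the API migration: after llama.cpp's GGUF refactor, model loading and context creation take separate parameter structs. llama_model_default_params() configures how the weights are loaded, while llama_context_default_params() configures the runtime state (in the current bindings, fields such as n_ctx and seed live on the context params, so the assignments to self.lparams above are likely inert leftovers). A hedged sketch of the new initialization sequence, with an illustrative model path:

import llama_cpp

# Model params govern loading the weights from disk.
lparams = llama_cpp.llama_model_default_params()

# Context params govern the runtime state created on top of the model.
cparams = llama_cpp.llama_context_default_params()
cparams.n_ctx = 2048  # illustrative; the example copies this from GptParams

model = llama_cpp.llama_load_model_from_file(
    b"./models/llama-7B/ggml-model.bin", lparams)  # path from the parser default
if not model:
    raise RuntimeError("failed to load model")

ctx = llama_cpp.llama_new_context_with_model(model, cparams)
if not ctx:
    raise RuntimeError("failed to create context")
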
@@ -244,7 +248,7 @@ def __init__(self, params: GptParams) -> None:
     # tokenize a prompt
     def _tokenize(self, prompt, bos=True):
         _arr = (llama_cpp.llama_token * ((len(prompt) + 1) * 4))()
-        _n = llama_cpp.llama_tokenize(self.ctx, prompt.encode("utf8", errors="ignore"), _arr, len(_arr), bos)
+        _n = llama_cpp.llama_tokenize(self.model, prompt.encode("utf8", errors="ignore"), len(prompt), _arr, len(_arr), bos, False)
         return _arr[:_n]
 
     def set_color(self, c):
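
Tokenization is now a model-level operation, and the binding takes an explicit input length plus a trailing special flag (whether to parse special/control tokens). A hedged sketch of a wrapper using the byte length of the encoded text, which is what the C API counts:

import llama_cpp

def tokenize(model, text: str, add_bos: bool = True):
    data = text.encode("utf8", errors="ignore")
    # Worst case is roughly one token per byte, plus room for the BOS token.
    buf = (llama_cpp.llama_token * (len(data) + 1))()
    n = llama_cpp.llama_tokenize(
        model,      # tokenization hangs off the model, not the context
        data,       # UTF-8 bytes of the prompt
        len(data),  # explicit length of the input
        buf,        # output token buffer
        len(buf),   # capacity of the buffer
        add_bos,    # prepend a beginning-of-sequence token
        False,      # 'special': do not parse special/control tokens
    )
    if n < 0:
        raise RuntimeError("token buffer too small")
    return list(buf[:n])
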
@@ -304,7 +308,7 @@ def generate(self):
             self.n_past += n_eval"""
 
         if (llama_cpp.llama_eval(
-            self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past, self.params.n_threads
+            self.ctx, (llama_cpp.llama_token * len(self.embd))(*self.embd), len(self.embd), self.n_past
         ) != 0):
             raise Exception("Failed to llama_eval!")
 
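llama_eval lost its trailing n_threads argument; thread counts are now configured once on the context params (n_threads / n_threads_batch in the bindings of this era, an assumption based on upstream llama.cpp at the time rather than on this diff). A small wrapper showing the updated call:

import llama_cpp

def eval_tokens(ctx, token_ids, n_past):
    # Pass the batch of token ids as a C array; note there is no
    # n_threads argument any more, since threading is a context-level setting.
    arr = (llama_cpp.llama_token * len(token_ids))(*token_ids)
    if llama_cpp.llama_eval(ctx, arr, len(arr), n_past) != 0:
        raise RuntimeError("llama_eval failed")
    return n_past + len(token_ids)
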
@@ -332,7 +336,7 @@ def generate(self):
             id = 0
 
             logits = llama_cpp.llama_get_logits(self.ctx)
-            n_vocab = llama_cpp.llama_n_vocab(self.ctx)
+            n_vocab = llama_cpp.llama_n_vocab(self.model)
 
             # Apply params.logit_bias map
             for key, value in self.params.logit_bias.items():
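
The same model/context split shows up in the getters: vocabulary size is a property of the weights, so llama_n_vocab now takes the model, while logits remain per-context state. For illustration, a hedged greedy-argmax sketch over the raw logits pointer:

import llama_cpp

def greedy_token(ctx, model):
    n_vocab = llama_cpp.llama_n_vocab(model)  # model-level: vocabulary size
    logits = llama_cpp.llama_get_logits(ctx)  # context-level: last eval's logits
    # llama_get_logits returns a C float pointer; index the first n_vocab entries.
    best_id, best_logit = 0, logits[0]
    for i in range(1, n_vocab):
        if logits[i] > best_logit:
            best_id, best_logit = i, logits[i]
    return best_id
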
@@ -349,12 +353,20 @@ def generate(self):
                 last_n_repeat = min(len(self.last_n_tokens), repeat_last_n, self.n_ctx)
 
                 _arr = (llama_cpp.llama_token * last_n_repeat)(*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat:])
-                llama_cpp.llama_sample_repetition_penalty(self.ctx, candidates_p,
-                    _arr,
-                    last_n_repeat, llama_cpp.c_float(self.params.repeat_penalty))
-                llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
-                    _arr,
-                    last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
+                llama_cpp.llama_sample_repetition_penalties(
+                    ctx=self.ctx,
+                    candidates=candidates_p,
+                    last_tokens_data=_arr,
+                    penalty_last_n=last_n_repeat,
+                    penalty_repeat=llama_cpp.c_float(self.params.repeat_penalty),
+                    penalty_freq=llama_cpp.c_float(self.params.frequency_penalty),
+                    penalty_present=llama_cpp.c_float(self.params.presence_penalty),
+                )
+
+                # Not present in the current API; superseded by the call above:
+                # llama_cpp.llama_sample_frequency_and_presence_penalties(self.ctx, candidates_p,
+                #     _arr,
+                #     last_n_repeat, llama_cpp.c_float(self.params.frequency_penalty), llama_cpp.c_float(self.params.presence_penalty))
 
                 if not self.params.penalize_nl:
                     logits[llama_cpp.llama_token_nl()] = nl_logit
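
Upstream merged the separate repetition and frequency/presence penalty samplers into a single llama_sample_repetition_penalties call, which is why the old second call is commented out rather than migrated. A standalone sketch of the merged call (the default penalty values are illustrative):

import llama_cpp

def apply_penalties(ctx, candidates_p, last_tokens, repeat_penalty=1.1,
                    frequency_penalty=0.0, presence_penalty=0.0):
    arr = (llama_cpp.llama_token * len(last_tokens))(*last_tokens)
    # One call now covers the repeat, frequency, and presence penalties.
    llama_cpp.llama_sample_repetition_penalties(
        ctx=ctx,
        candidates=candidates_p,
        last_tokens_data=arr,
        penalty_last_n=len(last_tokens),
        penalty_repeat=llama_cpp.c_float(repeat_penalty),
        penalty_freq=llama_cpp.c_float(frequency_penalty),
        penalty_present=llama_cpp.c_float(presence_penalty),
    )
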
@@ -473,7 +485,7 @@ def exit(self):
     def token_to_str(self, token_id: int) -> bytes:
         size = 32
         buffer = (ctypes.c_char * size)()
-        n = llama_cpp.llama_token_to_piece_with_model(
+        n = llama_cpp.llama_token_to_piece(
             self.model, llama_cpp.llama_token(token_id), buffer, size)
         assert n <= size
         return bytes(buffer[:n])
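
llama_token_to_piece_with_model was renamed to llama_token_to_piece; it still writes into a caller-provided buffer and returns the piece length. A sketch that also retries when the buffer is too small (a negative return giving the required size is the upstream llama.cpp convention, assumed here to carry over to the bindings):

import ctypes
import llama_cpp

def token_to_bytes(model, token_id: int, size: int = 32) -> bytes:
    buf = (ctypes.c_char * size)()
    n = llama_cpp.llama_token_to_piece(
        model, llama_cpp.llama_token(token_id), buf, size)
    if n < 0:
        # Assumed upstream convention: -n is the required buffer size.
        size = -n
        buf = (ctypes.c_char * size)()
        n = llama_cpp.llama_token_to_piece(
            model, llama_cpp.llama_token(token_id), buf, size)
    return bytes(buf[:n])
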
@@ -532,6 +544,9 @@ def interact(self):
             print(i,end="",flush=True)
         self.params.input_echo = False
 
+        # Use a string rather than tokens to check for the antiprompt;
+        # it is more reliable than tokens in interactive mode.
+        generated_str = ""
         while self.params.interactive:
             self.set_color(util.CONSOLE_COLOR_USER_INPUT)
             if (self.params.instruct):
@@ -546,6 +561,10 @@ def interact(self):
             try:
                 for i in self.output():
                     print(i,end="",flush=True)
+                    generated_str += i
+                    for ap in self.params.antiprompt:
+                        if generated_str.endswith(ap):
+                            raise KeyboardInterrupt
             except KeyboardInterrupt:
                 self.set_color(util.CONSOLE_COLOR_DEFAULT)
                 if not self.params.instruct:
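
Matching the antiprompt against the accumulated output string sidesteps tokenizer ambiguity: the same reverse-prompt text can be produced by different token sequences depending on surrounding context, so comparing decoded text is more robust than comparing token ids. A standalone sketch of the detection loop, mirroring the KeyboardInterrupt flow above:

def stream_until_antiprompt(chunks, antiprompts):
    """Yield text chunks, stopping once the output ends with an antiprompt."""
    generated = ""
    for chunk in chunks:
        yield chunk  # the chunk is emitted first, matching the example's print-then-check order
        generated += chunk
        if any(generated.endswith(ap) for ap in antiprompts):
            return  # hand control back to the user

for piece in stream_until_antiprompt(["Hello", " User", ":"], ["User:"]):
    print(piece, end="", flush=True)
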
@@ -561,7 +580,7 @@ def interact(self):
     time_now = datetime.now()
     prompt = f"""Text transcript of a never ending dialog, where {USER_NAME} interacts with an AI assistant named {AI_NAME}.
 {AI_NAME} is helpful, kind, honest, friendly, good at writing and never fails to answer {USER_NAME}’s requests immediately and with details and precision.
-There are no annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
+The transcript below contains only the recorded dialog between the two, without any annotations like (30 seconds passed...) or (to himself), just what {USER_NAME} and {AI_NAME} say aloud to each other.
 The dialog lasts for years, the entirety of it is shared below. It's 10000 pages long.
 The transcript only includes text, it does not include markup like HTML and Markdown.
@@ -575,8 +594,11 @@ def interact(self):
 {AI_NAME}: A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
 {USER_NAME}: Name a color.
 {AI_NAME}: Blue
-{USER_NAME}:"""
+{USER_NAME}: """
+
     params = gpt_params_parse()
+    if params.prompt is None and params.file is None:
+        params.prompt = prompt
 
     with LLaMAInteract(params) as m:
         m.interact()
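
With this final hunk in place, launching the example without -p or -f picks up the built-in dialog template, while an explicit prompt or prompt file still takes precedence. For example, using the model path from the parser default:

python examples/low_level_api/low_level_api_chat_cpp.py -m ./models/llama-7B/ggml-model.bin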
