|
40 | 40 |
|
41 | 41 | parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model') |
42 | 42 | parser.add_argument('model_path', help='Path to RWKV model in ggml format') |
| 43 | +parser.add_argument('-ngl', '--num_gpu_layers', type=int, default=99, help='Number of layers to run on GPU') |
43 | 44 | add_tokenizer_argument(parser) |
44 | 45 | args = parser.parse_args() |
45 | 46 |
|
|
48 | 49 | with open(script_dir / 'prompt' / f'{LANGUAGE}-{PROMPT_TYPE}.json', 'r', encoding='utf8') as json_file: |
49 | 50 | prompt_data = json.load(json_file) |
50 | 51 |
|
51 | | - user, bot, separator, init_prompt = prompt_data['user'], prompt_data['bot'], prompt_data['separator'], prompt_data['prompt'] |
| 52 | + user, assistant, separator, init_prompt = prompt_data['user'], prompt_data['assistant'], prompt_data['separator'], prompt_data['prompt'] |
52 | 53 |
|
53 | 54 | if init_prompt == '': |
54 | 55 | raise ValueError('Prompt must not be empty') |
|
57 | 58 | print(f'System info: {library.rwkv_get_system_info_string()}') |
58 | 59 |
|
59 | 60 | print('Loading RWKV model') |
60 | | -model = rwkv_cpp_model.RWKVModel(library, args.model_path) |
| 61 | +model = rwkv_cpp_model.RWKVModel(library, args.model_path, gpu_layer_count=args.num_gpu_layers) |
61 | 62 |
|
62 | 63 | tokenizer_decode, tokenizer_encode = get_tokenizer(args.tokenizer, model.n_vocab) |
63 | 64 |
|
@@ -154,7 +155,7 @@ def split_last_end_of_line(tokens: List[int]) -> List[int]: |
154 | 155 | if msg == '+reset': |
155 | 156 | load_thread_state('chat_init') |
156 | 157 | save_thread_state('chat') |
157 | | - print(f'{bot}{separator} Chat reset.\n') |
| 158 | + print(f'{assistant}{separator} Chat reset.\n') |
158 | 159 | continue |
159 | 160 | elif msg[:5].lower() == '+gen ' or msg[:3].lower() == '+i ' or msg[:4].lower() == '+qa ' or msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++': |
160 | 161 |
|
@@ -194,7 +195,7 @@ def split_last_end_of_line(tokens: List[int]) -> List[int]: |
194 | 195 | load_thread_state('chat_init') |
195 | 196 |
|
196 | 197 | real_msg = msg[4:].strip() |
197 | | - new = f'{user}{separator} {real_msg}\n\n{bot}{separator}' |
| 198 | + new = f'{user}{separator} {real_msg}\n\n{assistant}{separator}' |
198 | 199 |
|
199 | 200 | process_tokens(tokenizer_encode(new)) |
200 | 201 | save_thread_state('gen_0') |
@@ -225,17 +226,17 @@ def split_last_end_of_line(tokens: List[int]) -> List[int]: |
225 | 226 | except Exception as e: |
226 | 227 | print(e) |
227 | 228 | continue |
228 | | - # chat with bot |
| 229 | + # chat with assistant |
229 | 230 | else: |
230 | 231 | load_thread_state('chat') |
231 | | - new = f'{user}{separator} {msg}\n\n{bot}{separator}' |
| 232 | + new = f'{user}{separator} {msg}\n\n{assistant}{separator}' |
232 | 233 | process_tokens(tokenizer_encode(new), new_line_logit_bias=-999999999) |
233 | 234 | save_thread_state('chat_pre') |
234 | 235 |
|
235 | 236 | thread = 'chat' |
236 | 237 |
|
237 | | - # Print bot response |
238 | | - print(f'> {bot}{separator}', end='') |
| 238 | + # Print assistant response |
| 239 | + print(f'> {assistant}{separator}', end='') |
239 | 240 |
|
240 | 241 | start_index: int = len(processed_tokens) |
241 | 242 | accumulated_tokens: List[int] = [] |
|
0 commit comments