Skip to content

Commit 03a30bb

Browse files
committed
add system in swift infer (#508)
1 parent 0945920 commit 03a30bb

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

swift/llm/infer.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ def llm_infer(args: InferArguments) -> None:
242242
input_mode: Literal['S', 'M'] = 'S'
243243
logger.info('Input `exit` or `quit` to exit the conversation.')
244244
logger.info('Input `multi-line` to switch to multi-line input mode.')
245+
logger.info(
246+
'Input `reset-system` to reset the system and clear the history.')
245247
if template.support_multi_round:
246248
logger.info('Input `clear` to clear the history.')
247249
else:
@@ -252,11 +254,19 @@ def llm_infer(args: InferArguments) -> None:
252254
if args.infer_media_type != 'none':
253255
logger.info('Please enter the conversation content first, '
254256
'followed by the path to the multimedia file.')
257+
system = None
258+
read_system = False
255259
while True:
256260
if input_mode == 'S':
257-
query = input('<<< ')
261+
addi_prompt = ''
262+
if read_system:
263+
addi_prompt = '[S]'
264+
query = input(f'<<<{addi_prompt} ')
258265
else:
259-
query = read_multi_line()
266+
addi_prompt = '[M]'
267+
if read_system:
268+
addi_prompt = '[MS]'
269+
query = read_multi_line(addi_prompt)
260270
if query.strip().lower() in {'exit', 'quit'}:
261271
break
262272
elif query.strip().lower() == 'clear':
@@ -265,6 +275,13 @@ def llm_infer(args: InferArguments) -> None:
265275
continue
266276
elif query.strip() == '':
267277
continue
278+
elif query.strip().lower() == 'reset-system':
279+
read_system = True
280+
continue
281+
if read_system:
282+
system = query
283+
read_system = False
284+
continue
268285
if input_mode == 'S' and query.strip().lower() == 'multi-line':
269286
input_mode = 'M'
270287
logger.info('End multi-line input with `#`.')
@@ -279,7 +296,11 @@ def llm_infer(args: InferArguments) -> None:
279296
infer_kwargs = {}
280297
read_media_file(infer_kwargs, args.infer_media_type)
281298
if args.infer_backend == 'vllm':
282-
request_list = [{'query': query, 'history': history}]
299+
request_list = [{
300+
'query': query,
301+
'history': history,
302+
'system': system
303+
}]
283304
if args.stream:
284305
gen = inference_stream_vllm(llm_engine, template,
285306
request_list)
@@ -300,7 +321,7 @@ def llm_infer(args: InferArguments) -> None:
300321
else:
301322
if args.stream:
302323
gen = inference_stream(model, template, query, history,
303-
**infer_kwargs)
324+
system, **infer_kwargs)
304325
print_idx = 0
305326
for response, new_history in gen:
306327
if len(response) > print_idx:
@@ -309,7 +330,8 @@ def llm_infer(args: InferArguments) -> None:
309330
print()
310331
else:
311332
response, new_history = inference(model, template, query,
312-
history, **infer_kwargs)
333+
history, system,
334+
**infer_kwargs)
313335
print(response)
314336
print('-' * 50)
315337
obj = {
@@ -366,6 +388,8 @@ def llm_infer(args: InferArguments) -> None:
366388
history = data.get('history')
367389
system = data.get('system')
368390
images = data.get('images')
391+
if args.verbose and system is not None:
392+
print(f'[SYSTEM]{system}')
369393
if history is not None:
370394
kwargs['history'] = history
371395
if system is not None:
@@ -375,7 +399,7 @@ def llm_infer(args: InferArguments) -> None:
375399
if args.infer_backend == 'vllm':
376400
assert args.stream is True
377401
if args.verbose:
378-
print(f"query: {data['query']}\nresponse: ", end='')
402+
print(f"[QUERY]{data['query']}\n[RESPONSE]", end='')
379403
gen = inference_stream_vllm(llm_engine, template, [kwargs])
380404
print_idx = 0
381405
for resp_list in gen:

swift/utils/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ def test_time(func: Callable[[], _T],
148148
return res
149149

150150

151-
def read_multi_line() -> str:
151+
def read_multi_line(addi_prompt: str = '') -> str:
152152
res = []
153-
prompt = '<<<[M] '
153+
prompt = f'<<<{addi_prompt} '
154154
while True:
155155
text = input(prompt) + '\n'
156156
prompt = ''

0 commit comments

Comments
 (0)