Skip to content

Commit 0bfc662

Browse files
fix rome (#133)
1 parent 3cb8827 commit 0bfc662

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

swift/llm/rome.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,23 +65,25 @@ def rome_infer(args: RomeArguments) -> None:
6565
generation_config.save_pretrained(args.ckpt_dir)
6666
model.generation_config = generation_config
6767

68+
# Inference
6869
if args.eval_human:
6970
while True:
7071
query = input('<<< ')
71-
data = {'query': query}
72-
input_ids = template.encode(data)['input_ids']
73-
inference(input_ids, model, tokenizer, args.stream)
72+
inference(model, template, query, stream=args.stream)
7473
else:
7574
_, val_dataset = get_dataset(args.dataset, args.dataset_test_ratio,
7675
args.dataset_seed)
7776
mini_val_dataset = val_dataset.select(
7877
range(min(args.show_dataset_sample, val_dataset.shape[0])))
7978
for data in mini_val_dataset:
80-
response = data['response']
81-
data['response'] = None
82-
input_ids = template.encode(data)['input_ids']
83-
inference(input_ids, model, tokenizer, args.stream)
79+
inference(
80+
model,
81+
template,
82+
data.get('query'),
83+
data.get('history'),
84+
data.get('system'),
85+
stream=args.stream)
8486
print()
85-
print(f'[LABELS]{response}')
87+
print(f"[LABELS]{data.get('response')}")
8688
print('-' * 80)
8789
# input('next[ENTER]')

0 commit comments

Comments
 (0)