File tree Expand file tree Collapse file tree 1 file changed +10
-8
lines changed Expand file tree Collapse file tree 1 file changed +10
-8
lines changed Original file line number Diff line number Diff line change @@ -65,23 +65,25 @@ def rome_infer(args: RomeArguments) -> None:
6565 generation_config .save_pretrained (args .ckpt_dir )
6666 model .generation_config = generation_config
6767
68+ # Inference
6869 if args .eval_human :
6970 while True :
7071 query = input ('<<< ' )
71- data = {'query' : query }
72- input_ids = template .encode (data )['input_ids' ]
73- inference (input_ids , model , tokenizer , args .stream )
72+ inference (model , template , query , stream = args .stream )
7473 else :
7574 _ , val_dataset = get_dataset (args .dataset , args .dataset_test_ratio ,
7675 args .dataset_seed )
7776 mini_val_dataset = val_dataset .select (
7877 range (min (args .show_dataset_sample , val_dataset .shape [0 ])))
7978 for data in mini_val_dataset :
80- response = data ['response' ]
81- data ['response' ] = None
82- input_ids = template .encode (data )['input_ids' ]
83- inference (input_ids , model , tokenizer , args .stream )
79+ inference (
80+ model ,
81+ template ,
82+ data .get ('query' ),
83+ data .get ('history' ),
84+ data .get ('system' ),
85+ stream = args .stream )
8486 print ()
85- print (f' [LABELS]{ response } ' )
87+ print (f" [LABELS]{ data . get ( ' response' ) } " )
8688 print ('-' * 80 )
8789 # input('next[ENTER]')
You can’t perform that action at this time.
0 commit comments