@@ -119,7 +119,7 @@ def put(self, value):
119119 self .first_token .put ((value , self .response_ids [0 ]))
120120
121121 self .is_first_token = False
122- return
122+
123123
124124 self .tokens_cache .append (value )
125125
@@ -356,6 +356,7 @@ def __init__(
356356 total_sample_count = 24576 ,
357357 dataset_path = None ,
358358 workers = 1 ,
359+ ** kwargs ,
359360 ):
360361
361362 super ().__init__ (
@@ -408,9 +409,13 @@ def process_queries(self):
408409 if qitem is None :
409410 break
410411
411- input_ids_tensor = self .data_object .input_ids [qitem .index ]
412- input_masks_tensor = self .data_object .attention_masks [qitem .index ]
413- dataset = self .data_object .dataset_names [qitem .index ]
412+ input_dataset = [self .data_object .dataset_names [qitem .index ]]
413+
414+ batch_texts = [self .data_object .input_texts [qitem .index ]]
415+ batch_ids = self .tokenizer .batch_encode_plus (
416+ batch_texts , return_tensors = "pt" , padding = True )
417+ batch_ids = batch_ids .to (self .device )
418+ _ , length = batch_ids .input_ids .shape
414419
415420 # TODO: This PoC is super slow with significant overhead. Best to
416421 # create a patch to `generate`
@@ -422,32 +427,24 @@ def process_queries(self):
422427 response_ids = [qitem .id ],
423428 )
424429
425- logits_processor = LogitsProcessorList (
426- [StopAfterSequence (
427- self .tokenizer .eos_token_id , device = self .device )]
430+
431+ _ = self .model .generate (
432+ ** batch_ids ,
433+ num_return_sequences = 1 ,
434+ streamer = tokens_streamer ,
435+ ** gen_kwargs ,
428436 )
429- if dataset == "MBXP" :
430- _ = self .model .generate (
431- input_ids = input_ids_tensor ,
432- attention_mask = input_masks_tensor ,
433- pad_token_id = self .tokenizer .pad_token_id ,
434- streamer = tokens_streamer ,
435- logits_processor = logits_processor ,
436- ** gen_kwargs ,
437- )
438- else :
439- _ = self .model .generate (
440- input_ids = input_ids_tensor ,
441- attention_mask = input_masks_tensor ,
442- pad_token_id = self .tokenizer .pad_token_id ,
443- streamer = tokens_streamer ,
444- ** gen_kwargs ,
445- )
446437
447438 output_tokens = tokens_streamer .get_out_tokens ()
448- n_tokens = len (output_tokens )
439+ processed_output = self .data_object .postProcess (
440+ torch .tensor ([output_tokens ], dtype = torch .int64 ),
441+ length = 0 ,
442+ query_id_list = [qitem .index ],
443+ dataset_list = input_dataset ,
444+ )
445+ n_tokens = len (processed_output [0 ])
449446 response_array = array .array (
450- "B" , np .array (output_tokens , np .int32 ).tobytes ()
447+ "B" , np .array (processed_output [ 0 ] , np .int32 ).tobytes ()
451448 )
452449 bi = response_array .buffer_info ()
453450 response = [