@@ -567,7 +567,8 @@ def _run_loglikelihood_tokens(
567567 for field_name , tensors in unpadded_batch .items ()
568568 }
569569
570- batch_logits = log_softmax (model (** padded_batch )[0 ], dim = - 1 )
570+ with torch .no_grad ():
571+ batch_logits = log_softmax (model (** padded_batch )[0 ], dim = - 1 )
571572 z = zip (
572573 batch_of_indices ,
573574 batch_logits ,
@@ -642,8 +643,8 @@ def _run_greedy_until(
642643 if isinstance (untils , str ):
643644 untils = [untils ]
644645 # if any of the stop phrases are single tokens we can use that for early termination
645- primary_until = None
646- for tokenized_until in tokenizer (untils )["input_ids" ]:
646+ primary_until = tokenizer . eos_token_id
647+ for tokenized_until in tokenizer (untils , add_special_tokens = False )["input_ids" ]:
647648 if len (tokenized_until ) == 1 :
648649 primary_until = tokenized_until [0 ]
649650
@@ -652,13 +653,14 @@ def _run_greedy_until(
652653 [tokenized_context [max_gen_toks - model_max_length :]]
653654 ).to (model .device )
654655
655- full_text_tensor = model .generate (
656- context_tensor ,
657- max_length = context_tensor .shape [1 ] + max_gen_toks ,
658- eos_token_id = primary_until ,
659- do_sample = False ,
660- pad_token_id = primary_until , # temporary hack to suppress irrelevant warning until batch processing is added
661- )
656+ with torch .no_grad ():
657+ full_text_tensor = model .generate (
658+ context_tensor ,
659+ max_length = context_tensor .shape [1 ] + max_gen_toks ,
660+ eos_token_id = primary_until ,
661+ do_sample = False ,
662+ pad_token_id = primary_until , # temporary hack to suppress irrelevant warning until batch processing is added
663+ )
662664 continuation_tensor = full_text_tensor [0 , context_tensor .shape [1 ] :]
663665 continuation = tokenizer .decode (continuation_tensor .tolist ())
664666 raw_continuation = continuation