from transformers import LogitsProcessor, LogitsProcessorList
from transformers.pytorch_utils import isin_mps_friendly
import math
import torch

class ParlerTTSLogitsProcessor(LogitsProcessor):
    r"""This processor ensures that the delayed pattern mask constraints are respected.

    <Tip warning={true}>

    This logits processor is exclusively compatible with Parler-TTS.
    See the model documentation for examples.

    </Tip>

    Args:
        eos_token_id (`Union[int, List[int], torch.Tensor]`):
            The id(s) of the *end-of-sequence* token.
        num_codebooks (`int`):
            The number of codebooks, i.e. parallel token streams, generated at each decoding step.
        batch_size (`int`):
            The batch size used at generation time. Together with `num_codebooks` it fixes the
            number of rows (`batch_size * num_codebooks`) in `scores`.
        device (`str`, *optional*, defaults to `"cpu"`):
            The device on which the internal index tensors are allocated.
    """

    def __init__(self, eos_token_id, num_codebooks: int, batch_size: int, device: str = "cpu"):
        # Normalize `eos_token_id` to a 1D tensor on the target device.
        if not isinstance(eos_token_id, torch.Tensor):
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id, device=device)
        self.eos_token_id = eos_token_id
        self.batch_size = batch_size

        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

        self.num_codebooks = num_codebooks
        self.device = device

        # Rows of `scores` are flattened as (batch, codebook):
        # row b * num_codebooks + k holds codebook k of batch item b.
        self.codebook_idx = torch.arange(self.batch_size * self.num_codebooks, device=self.device)
        # For each batch item, the row index of the first codebook that has not emitted EOS yet.
        self.first_codebooks_unfinished = torch.arange(batch_size, device=device) * num_codebooks
        # For each batch item, the row index of its last codebook.
        self.max_codebooks = torch.arange(self.batch_size, device=self.device) * self.num_codebooks + self.num_codebooks - 1

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Number of EOS tokens generated so far in each (batch, codebook) row.
        is_eos = isin_mps_friendly(input_ids, self.eos_token_id).sum(1)

        # Once the first unfinished codebook of a batch item has emitted EOS, advance
        # the pointer to the next codebook (unless we are already at the last one).
        self.first_codebooks_unfinished = torch.where(
            (is_eos[self.first_codebooks_unfinished] > 0) & (self.first_codebooks_unfinished < self.max_codebooks),
            self.first_codebooks_unfinished + 1,
            self.first_codebooks_unfinished,
        )

        # Every codebook higher than the first unfinished one will never be EOS:
        # force its EOS logit to -inf so it cannot be sampled.
        eos_token_mask = self.codebook_idx > self.first_codebooks_unfinished.repeat_interleave(self.num_codebooks)
        scores[eos_token_mask, self.eos_token_id] = -math.inf

        return scores
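
# Usage sketch (illustrative addition, not part of the original file): drives the
# processor directly on dummy tensors to show the masking behavior. The batch
# size, codebook count, vocabulary size, and EOS id below are arbitrary
# assumptions, not values from any Parler-TTS checkpoint. In real generation the
# processor would instead be wrapped in a `LogitsProcessorList` and passed to
# `generate(..., logits_processor=...)`.
if __name__ == "__main__":
    batch_size, num_codebooks, vocab_size, eos_token_id = 2, 4, 10, 9

    processor = ParlerTTSLogitsProcessor(
        eos_token_id=eos_token_id,
        num_codebooks=num_codebooks,
        batch_size=batch_size,
    )

    # Rows are interleaved as (batch, codebook): batch 0 codebooks 0..3, then batch 1.
    input_ids = torch.randint(0, eos_token_id, (batch_size * num_codebooks, 5))
    # Pretend the first codebook of batch item 0 has already emitted EOS.
    input_ids[0, -1] = eos_token_id

    scores = torch.zeros(batch_size * num_codebooks, vocab_size)
    scores = processor(input_ids, scores)

    # EOS is now -inf for every codebook beyond each batch item's first
    # unfinished one: rows 2-3 (batch 0) and rows 5-7 (batch 1).
    print(scores[:, eos_token_id])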