from transformers import LogitsProcessor, LogitsProcessorList
from transformers.pytorch_utils import isin_mps_friendly
import math
import torch

class ParlerTTSLogitsProcessor(LogitsProcessor):
    r"""This processor ensures that the delayed pattern mask constraints are respected.

    <Tip warning={true}>

    This logits processor is exclusively compatible with Parler-TTS.
    See the model documentation for examples.

    </Tip>

    Args:
        eos_token_id (`Union[int, List[int], torch.Tensor]`):
            The id(s) of the *end-of-sequence* token.
        num_codebooks (`int`):
            The number of codebooks generated per batch item at each decoding step.
        batch_size (`int`):
            The batch size used for generation.
        device (`str`, *optional*, defaults to `"cpu"`):
            The device on which the internal index tensors are allocated.
    """

    def __init__(self, eos_token_id, num_codebooks: int, batch_size: int, device: str = "cpu"):
        if not isinstance(eos_token_id, torch.Tensor):
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id, device=device)
        self.eos_token_id = eos_token_id
        self.batch_size = batch_size

        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

        self.num_codebooks = num_codebooks
        self.device = device

        # Flat row index for every (batch item, codebook) pair: row
        # b * num_codebooks + k holds codebook k of batch item b.
        self.codebook_idx = torch.arange(self.batch_size * self.num_codebooks, device=self.device)
        # Per batch item, the row index of the first codebook that has not yet
        # emitted EOS; starts at codebook 0 of each item.
        self.first_codebooks_unfinished = torch.arange(batch_size, device=device) * num_codebooks

        # Per batch item, the row index of its last codebook.
        max_codebooks = torch.arange(self.batch_size, device=self.device) * self.num_codebooks + self.num_codebooks - 1
        self.max_codebooks = max_codebooks

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Count how many EOS tokens each row (= one codebook channel of one
        # batch item) has generated so far.
        is_eos = isin_mps_friendly(input_ids, self.eos_token_id).sum(1)

        # If the first unfinished codebook of a batch item has emitted EOS,
        # advance its pointer to the next codebook (capped at the last one).
        self.first_codebooks_unfinished = torch.where(
            (is_eos[self.first_codebooks_unfinished] > 0) & (self.first_codebooks_unfinished < self.max_codebooks),
            self.first_codebooks_unfinished + 1,
            self.first_codebooks_unfinished,
        )

        # Every codebook higher than the first unfinished one will never be EOS
        # at this step: under the delay pattern, codebook k may only end after
        # codebook k - 1 has ended, so force their EOS logits to -inf.
        eos_token_mask = self.codebook_idx > self.first_codebooks_unfinished.repeat_interleave(self.num_codebooks)
        scores[eos_token_mask, self.eos_token_id] = -math.inf

        return scores
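
# --- Usage sketch ---
# A minimal, self-contained illustration of how the processor is meant to be
# wired into generation via `LogitsProcessorList`, and of the constraint it
# enforces. The shapes and ids below are illustrative assumptions, not values
# taken from a real Parler-TTS checkpoint: 2 batch items, 3 codebooks, a
# 10-token vocabulary, and a single EOS id of 9.
if __name__ == "__main__":
    batch_size, num_codebooks, vocab_size, eos_id = 2, 3, 10, 9

    processors = LogitsProcessorList(
        [ParlerTTSLogitsProcessor(eos_id, num_codebooks, batch_size)]
    )

    # One row per (batch item, codebook) pair, matching the flattened layout
    # the processor expects. No EOS has been generated yet in this history.
    input_ids = torch.zeros(batch_size * num_codebooks, 4, dtype=torch.long)
    scores = torch.zeros(batch_size * num_codebooks, vocab_size)

    out = processors(input_ids, scores)
    # Codebook 0 of each batch item may still emit EOS; every later codebook
    # has its EOS logit forced to -inf until the codebook before it ends:
    # tensor([0., -inf, -inf, 0., -inf, -inf])
    print(out[:, eos_id])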