
Commit 3d1b82a

add logits processor (#173)
* add logits processor
* bump version
1 parent 77d10df commit 3d1b82a

File tree: 3 files changed, +59 -4 lines changed

parler_tts/__init__.py
parler_tts/logits_processors.py
parler_tts/modeling_parler_tts.py

parler_tts/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-__version__ = "0.2.1"
+__version__ = "0.2.2"
 
 
 from transformers import AutoConfig, AutoModel

parler_tts/logits_processors.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+from transformers import LogitsProcessor, LogitsProcessorList
+from transformers.pytorch_utils import isin_mps_friendly
+import math
+import torch
+
+class ParlerTTSLogitsProcessor(LogitsProcessor):
+    r"""This processor ensures that the delayed pattern mask constraints are respected.
+
+    <Tip warning={true}>
+
+    This logits processor is exclusively compatible with Parler-TTS.
+    See the model documentation for examples.
+
+    </Tip>
+
+    Args:
+        eos_token_id (`Union[int, List[int], torch.Tensor]`):
+            The id(s) of the *end-of-sequence* token.
+        num_codebooks (`int`):
+            The number of codebooks predicted by the decoder.
+        batch_size (`int`):
+            The batch size used for generation.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            The device on which the internal index tensors are allocated.
+    """
+
+    def __init__(self, eos_token_id, num_codebooks: int, batch_size: int, device: str = "cpu"):
+        if not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id, device=device)
+        self.eos_token_id = eos_token_id
+        self.batch_size = batch_size
+
+        if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
+            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+
+        self.num_codebooks = num_codebooks
+        self.device = device
+
+        # flat row index over the (batch_size * num_codebooks) rows the decoder generates
+        self.codebook_idx = torch.arange(self.batch_size * self.num_codebooks, device=self.device)
+        # for each batch item, the flat index of the first codebook that has not emitted EOS yet
+        self.first_codebooks_unfinished = torch.arange(batch_size, device=device) * num_codebooks
+
+        # flat index of the last codebook of each batch item
+        max_codebooks = torch.arange(self.batch_size, device=self.device) * self.num_codebooks + self.num_codebooks - 1
+        self.max_codebooks = max_codebooks
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        is_eos = isin_mps_friendly(input_ids, self.eos_token_id).sum(1)
+
+        # advance the pointer whenever the current first-unfinished codebook has emitted EOS
+        self.first_codebooks_unfinished = torch.where((is_eos[self.first_codebooks_unfinished] > 0) & (self.first_codebooks_unfinished < self.max_codebooks), self.first_codebooks_unfinished + 1, self.first_codebooks_unfinished)
+
+        # every codebook higher than the first one unfinished will never be eos
+        eos_token_mask = self.codebook_idx > self.first_codebooks_unfinished.repeat_interleave(self.num_codebooks)
+        scores[eos_token_mask, self.eos_token_id] = -math.inf
+
+        return scores
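
For orientation, here is a minimal standalone sketch of how the new processor behaves on toy tensors. The batch size, codebook count, vocabulary size, and EOS id below are illustrative assumptions, not values taken from this commit; it assumes a parler_tts install that includes this file.

import torch

from parler_tts.logits_processors import ParlerTTSLogitsProcessor

# Toy dimensions (illustrative only): 2 items, 3 codebooks, EOS id 1024, vocab of 1088.
batch_size, num_codebooks, vocab_size, eos_token_id = 2, 3, 1088, 1024

processor = ParlerTTSLogitsProcessor(eos_token_id, num_codebooks, batch_size, device="cpu")

# The processor indexes rows as (batch_size * num_codebooks,), matching the flattened codebook layout.
input_ids = torch.zeros((batch_size * num_codebooks, 4), dtype=torch.long)
input_ids[0, -1] = eos_token_id  # first codebook of the first item has just produced EOS

scores = torch.zeros((batch_size * num_codebooks, vocab_size))
scores = processor(input_ids, scores)

# Rows for codebooks beyond each item's first unfinished codebook have EOS masked to -inf.
print(torch.isinf(scores[:, eos_token_id]))
# expected: tensor([False, False,  True, False, False,  True])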

parler_tts/modeling_parler_tts.py

Lines changed: 4 additions & 3 deletions
@@ -60,6 +60,7 @@
 
 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
 from .dac_wrapper import DACConfig, DACModel
+from .logits_processors import ParlerTTSLogitsProcessor
 
 from importlib.metadata import version
 from packaging.version import Version
@@ -3401,9 +3402,6 @@ def generate(
             model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0])
 
         # 2. Set generation parameters if not already defined
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
 
@@ -3414,6 +3412,9 @@ def generate(
         batch_size = inputs_tensor.shape[0]
         self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
 
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList([ParlerTTSLogitsProcessor(generation_config.eos_token_id, self.decoder.num_codebooks, batch_size, inputs_tensor.device)])
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
         # 4. Define other model kwargs
         model_kwargs["use_cache"] = generation_config.use_cache
 
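No caller changes are needed downstream: when generate() is invoked without an explicit logits_processor, the default list now carries the ParlerTTSLogitsProcessor built from the generation config. A hedged usage sketch follows; the checkpoint name and argument names are taken from the Parler-TTS README and are assumptions here, not part of this commit.

import torch
from transformers import AutoTokenizer

from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"

model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")

description = tokenizer("A female speaker delivers her words expressively.", return_tensors="pt").to(device)
prompt = tokenizer("Hey, how are you doing today?", return_tensors="pt").to(device)

# No explicit `logits_processor` is passed, so the patched generate() installs a default
# LogitsProcessorList containing ParlerTTSLogitsProcessor behind the scenes.
audio = model.generate(input_ids=description.input_ids, prompt_input_ids=prompt.input_ids)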