Commit 11b209e

Authored by ylacombe, sanchit-gandhi, and sang-nguyen-ts

Architecture improvements (#65)

* add RoPE
* don't include padding in RoPE
* possibly use cross-attention for prompt
* fix RoPE
* fix cross-attention
* fix self-attention
* fix dummy model
* clean up RoPE
* first GQA implementation
* fix WER eval
* feat: add flash attention and SDPA
* chore: add README for flash attention
* chore: add benchmark script
* chore: add benchmark attention approach
* multi-node and fix WER and fix compile
* Update modeling_parler_tts.py
* fix FA2, SDPA and add cross-attention MHA and attention type forcing
* better default for number of cross-attention key/value heads + add training arguments for attention implementation
* fix audio padding when torch compile or pad_to_max_length=True
* correct multi-node
* make RoPE faster
* fix encoder SDPA
* fix training with cross-attention + with FA2
* use fp32 as default model dtype + fix generation when using FA2 with autocast
* remove redundant passes in generate + clean and fix attentions
* fix edge case in WER evaluation with long-form generation
* better multi-node mapping and saving / add eval dataloader num workers
* remove old benchmarks
* faster audio encoding + checkpointing + fix generation step
* better eval + add right padding + fix eval loss compute
* correct README
* correct config docstrings
* remove comment
* make style

---------

Co-authored-by: sanchit-gandhi <[email protected]>
Co-authored-by: sang-nguyen-ts <[email protected]>
Co-authored-by: Yoach Lacombe <[email protected]>

1 parent 8b8c576 commit 11b209e
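
The items above boil down to three architectural levers: rotary position embeddings, grouped-query attention, and selectable attention backends (eager, SDPA, flash attention 2). Below is a minimal, hedged sketch of how a user would opt into a backend after this commit; `attn_implementation` is the standard transformers loading switch, and the exact behavior depends on the installed transformers version:

import torch
from parler_tts import ParlerTTSForConditionalGeneration

# "eager" and "sdpa" work out of the box; "flash_attention_2" additionally
# requires the flash-attn package and a supported GPU.
model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler_tts_mini_v0.1",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to("cuda")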

File tree

12 files changed: +1325, -267 lines

README.md

Lines changed: 2 additions & 1 deletion
@@ -53,7 +53,8 @@ if torch.xpu.is_available():
     device = "xpu"
 torch_dtype = torch.float16 if device != "cpu" else torch.float32
 
-model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)
+model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1", torch_dtype=torch_dtype).to(device)
+
 tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
 
 prompt = "Hey, how are you doing today?"

helpers/model_init_scripts/init_dummy_model.py

Lines changed: 2 additions & 2 deletions
@@ -60,8 +60,8 @@
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
 model.generation_config.do_sample = True # True
 model.generation_config.guidance_scale = 1 # 3.0
-
+
 model.config.pad_token_id = encodec_vocab_size
-model.config.decoder_start_token_id = encodec_vocab_size+1
+model.config.decoder_start_token_id = encodec_vocab_size + 1
 
 model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))

helpers/model_init_scripts/init_dummy_model_with_encodec.py

Lines changed: 3 additions & 0 deletions
@@ -58,4 +58,7 @@
 model.generation_config.do_sample = True # True
 model.generation_config.guidance_scale = 1 # 3.0
 
+model.config.pad_token_id = encodec_vocab_size
+model.config.decoder_start_token_id = encodec_vocab_size + 1
+
 model.save_pretrained(os.path.join(args.save_directory, "tiny-model"))
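
The two ids appended above sit directly after the audio codebook. A hypothetical illustration of that layout, using an assumed codebook size of 2048 purely to make the arithmetic concrete; the real value comes from the codec the script instantiates:

# Assumed codebook size, for illustration only.
encodec_vocab_size = 2048
pad_token_id = encodec_vocab_size                # 2048 -> first id after the codebook
decoder_start_token_id = encodec_vocab_size + 1  # 2049 -> next free id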

helpers/model_init_scripts/init_model_600M.py

Lines changed: 2 additions & 2 deletions
@@ -60,8 +60,8 @@
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
 model.generation_config.do_sample = True # True
 model.generation_config.guidance_scale = 1 # 3.0
-
+
 model.config.pad_token_id = encodec_vocab_size
-model.config.decoder_start_token_id = encodec_vocab_size+1
+model.config.decoder_start_token_id = encodec_vocab_size + 1
 
 model.save_pretrained(os.path.join(args.save_directory, "parler-tts-untrained-600M/"))

helpers/push_to_hub_scripts/push_dac_to_hub.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 from parler_tts import DACConfig, DACModel
 from transformers import AutoConfig, AutoModel
 from transformers import EncodecFeatureExtractor
+
 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
 

parler_tts/configuration_parler_tts.py

Lines changed: 53 additions & 1 deletion
@@ -47,6 +47,17 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
             Number of decoder layers.
         num_attention_heads (`int`, *optional*, defaults to 16):
             Number of attention heads for each attention layer in the Transformer block.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        num_cross_attention_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers.
+            If it is not specified, will default to `num_key_value_heads`.
         ffn_dim (`int`, *optional*, defaults to 4096):
             Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
         activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):

@@ -74,6 +85,12 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
             The number of parallel codebooks forwarded to the model.
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether input and output word embeddings should be tied.
+        rope_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to use RoPE or absolute positional embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        cross_attention_implementation_strategy (`str`, *optional*):
+            If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
     """
 
     model_type = "parler_tts_decoder"

@@ -86,6 +103,8 @@ def __init__(
         num_hidden_layers=24,
         ffn_dim=4096,
         num_attention_heads=16,
+        num_key_value_heads=None,
+        num_cross_attention_key_value_heads=None,
         layerdrop=0.0,
         use_cache=True,
         activation_function="gelu",

@@ -100,6 +119,9 @@
         bos_token_id=2049,
         eos_token_id=2048,
         tie_word_embeddings=False,
+        rope_embeddings=False,
+        rope_theta=10_000.0,
+        cross_attention_implementation_strategy=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size

@@ -108,6 +130,12 @@
         self.ffn_dim = ffn_dim
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        if num_cross_attention_key_value_heads is None:
+            num_cross_attention_key_value_heads = num_key_value_heads
+        self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads
         self.dropout = dropout
         self.attention_dropout = attention_dropout
         self.activation_dropout = activation_dropout

@@ -117,6 +145,9 @@
         self.use_cache = use_cache
         self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
         self.num_codebooks = num_codebooks
+        self.rope_embeddings = rope_embeddings
+        self.rope_theta = rope_theta
+        self.cross_attention_implementation_strategy = cross_attention_implementation_strategy
 
         super().__init__(
             pad_token_id=pad_token_id,

@@ -140,6 +171,8 @@ class ParlerTTSConfig(PretrainedConfig):
         vocab_size (`int`, *optional*, defaults to 1024):
             Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
             represented by the `prompt_inputs_ids`.
+        prompt_cross_attention (`bool`, *optional*, defaults to `False`):
+            Whether to use cross-attention conditioning for the prompt (as well as the description).
         kwargs (*optional*):
             Dictionary of keyword arguments. Notably:
 

@@ -190,7 +223,7 @@ class ParlerTTSConfig(PretrainedConfig):
     model_type = "parler_tts"
     is_composition = True
 
-    def __init__(self, vocab_size=1024, **kwargs):
+    def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
         super().__init__(**kwargs)
         if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
             raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")

@@ -204,6 +237,7 @@ def __init__(self, vocab_size=1024, **kwargs):
         decoder_config = kwargs.pop("decoder")
 
         self.vocab_size = vocab_size
+        self.prompt_cross_attention = prompt_cross_attention
         self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
         self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
         self.decoder = ParlerTTSDecoderConfig(**decoder_config)

@@ -236,3 +270,21 @@ def from_sub_models_config(
     # This is a property because you might want to change the codec model on the fly
     def sampling_rate(self):
         return self.audio_encoder.sampling_rate
+
+    # Copy from musicgen
+    @property
+    def _attn_implementation(self):
+        # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
+        if hasattr(self, "_attn_implementation_internal"):
+            if self._attn_implementation_internal is None:
+                # `config.attn_implementation` should never be None, for backward compatibility.
+                return "eager"
+            else:
+                return self._attn_implementation_internal
+        else:
+            return "eager"
+
+    @_attn_implementation.setter
+    def _attn_implementation(self, value):
+        self._attn_implementation_internal = value
+        self.decoder._attn_implementation = value
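
The new decoder options introduced in this file can be exercised directly from the config. A minimal sketch using only parameters visible in this diff; the values are illustrative, not those of any released checkpoint:

from parler_tts.configuration_parler_tts import ParlerTTSDecoderConfig

decoder_config = ParlerTTSDecoderConfig(
    num_attention_heads=16,
    num_key_value_heads=4,                   # fewer KV heads than query heads -> GQA
    num_cross_attention_key_value_heads=1,   # multi-query attention in the cross-attention layers
    rope_embeddings=True,                    # rotary instead of absolute positions
    rope_theta=10_000.0,
    cross_attention_implementation_strategy="always_sdpa",
)

# Left unset, num_key_value_heads falls back to num_attention_heads (plain MHA),
# and num_cross_attention_key_value_heads falls back to num_key_value_heads.
print(decoder_config.num_key_value_heads, decoder_config.num_cross_attention_key_value_heads)

Setting num_key_value_heads=1 instead would give multi-query attention in the self-attention layers as well, as described in the docstring above.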
