Commit 862f841

Authored by eustlb, sanchit-gandhi, ylacombe, and sang-nguyen-ts
Add static cache (#89)

* add rope
* don't include padding in rope
* possibly use cross-attn for prompt
* fix rope
* fix cross-attn
* fix self-attn
* fix dummy model
* clean up rope
* first GQA implementation
* fix WER eval
* feat: add flash attention and SDPA
* chore: add README for flash attention
* chore: add benchmark script
* chore: add benchmark attention approach
* multi-node support; fix WER and compile
* update modeling_parler_tts.py
* fix FA2 and SDPA; add cross-attn MHA and attention type forcing
* better default for the number of cross-attention key/value heads; add training arguments for attn implementation
* fix audio padding when torch compile or pad_to_max_length=True is used
* correct multi-node setup
* make rope faster
* fix encoder SDPA
* fix training with cross attention and with FA2
* use fp32 as the default model dtype; fix generation when using FA2 with autocast
* remove redundant passes in generate; clean and fix attentions
* fix edge case in WER evaluation for long-form generation
* better multi-node mapping and saving; add eval dataloader num workers
* remove old benchmarks
* faster audio encoding + checkpointing + fix generation step
* unpin transformers
* remove CFG
* imports and constants (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* attention modifications to handle static cache (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* decoder layer modification to handle static cache (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* ParlerTTSPreTrainedModel modifs to handle static cache (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* ParlerTTSDecoder modifs to handle static cache (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* ParlerTTSModel + ParlerTTSForCausalLM modifs (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* ParlerTTSForConditionalGeneration modifs (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* decoder_attention_mask for static cache (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* create inputs_embeds early to have a good cache initialization (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* _get_cache method (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* init the cache (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* ensure good device (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* pin transformers version (Co-Authored-By: sang-nguyen-ts <[email protected]>)
* fix attention_mask for FA2
* remove unnecessary method
* update parler_tts/modeling_parler_tts.py (Co-authored-by: Sanchit Gandhi <[email protected]>)
* update parler_tts/modeling_parler_tts.py (Co-authored-by: Sanchit Gandhi <[email protected]>)
* remove unnecessary imports
* replace the hardcoded cache_position with a more elegant approach
* make style
* unpin transformers
* pin transformers
* pin torch
* refactor + unpin torch
* update parler_tts/modeling_parler_tts.py (Co-authored-by: Yoach Lacombe <[email protected]>)
* update training script to match 11b209e
* update parler_tts/modeling_parler_tts.py (Co-authored-by: Yoach Lacombe <[email protected]>)
* ensure compatibility with transformers 4.43.3 (changes taken from transformers #31980)
* fix input_ids_length
* warn on full attention mask creation
* changes for training compatibility

---------

Co-authored-by: sanchit-gandhi <[email protected]>
Co-authored-by: Yoach Lacombe <[email protected]>
Co-authored-by: sang-nguyen-ts <[email protected]>
Co-authored-by: Sanchit Gandhi <[email protected]>
1 parent 11b209e commit 862f841
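The headline change preallocates the decoder's key/value cache to a fixed maximum length, so every generation step writes into buffers of constant shape (the property `torch.compile` needs to avoid recompilation). A minimal pure-Python sketch of the difference between a dynamically growing cache and a static one; all names here are illustrative, not the Parler-TTS or transformers API:

```python
# Sketch: dynamic vs. static KV-cache growth (illustrative names only).
# A dynamic cache appends at each decoding step, so its length (and hence
# tensor shapes) changes every step; a static cache is preallocated to
# max_length and written in place, keeping shapes constant.

class DynamicCache:
    def __init__(self):
        self.keys = []                    # grows every step -> shape changes

    def update(self, step, key):
        self.keys.append(key)
        return self.keys


class StaticCache:
    def __init__(self, max_length):
        self.keys = [None] * max_length   # fixed size allocated up front

    def update(self, step, key):
        self.keys[step] = key             # in-place write, no resize
        return self.keys


dyn, stat = DynamicCache(), StaticCache(max_length=4)
for step, k in enumerate(["k0", "k1"]):
    dyn.update(step, k)
    stat.update(step, k)

print(len(dyn.keys))   # 2 -- length tracks the number of steps so far
print(len(stat.keys))  # 4 -- length fixed at max_length regardless of steps
```

With the static variant, a compiled decode step sees the same cache shape at every iteration, which is why the diff below threads a `cache_position` index through the attention layers instead of relying on the cache's current length.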

File tree: 13 files changed (+576, -346 lines)


helpers/gradio_demo/app.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -1,8 +1,9 @@
 import gradio as gr
 import torch
+from transformers import AutoFeatureExtractor, AutoTokenizer, set_seed

 from parler_tts import ParlerTTSForConditionalGeneration
-from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
+

 device = "cuda:0" if torch.cuda.is_available() else "cpu"

@@ -57,7 +58,7 @@ def gen_tts(text, description):
       background-color: #000000;
       justify-content: center;
       align-items: center;
-      border-radius: 9999px !important;
+      border-radius: 9999px !important;
       width: 13rem;
       margin-top: 10px;
       margin-left: auto;
```

helpers/model_init_scripts/init_dummy_model.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -1,7 +1,9 @@
-from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
-from transformers import AutoConfig
-import os
 import argparse
+import os
+
+from transformers import AutoConfig
+
+from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration


 if __name__ == "__main__":
```

helpers/model_init_scripts/init_dummy_model_with_encodec.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -1,7 +1,10 @@
-from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
-from transformers import AutoConfig
-import os
 import argparse
+import os
+
+from transformers import AutoConfig
+
+from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
```

helpers/model_init_scripts/init_model_600M.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -1,7 +1,9 @@
-from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
-from transformers import AutoConfig
-import os
 import argparse
+import os
+
+from transformers import AutoConfig
+
+from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration


 if __name__ == "__main__":
```

helpers/push_to_hub_scripts/push_dac_to_hub.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,4 +1,6 @@
 import dac
+from transformers import AutoConfig, AutoModel, EncodecFeatureExtractor
+
 from parler_tts import DACConfig, DACModel
 from transformers import AutoConfig, AutoModel
 from transformers import EncodecFeatureExtractor
```

helpers/push_to_hub_scripts/push_trained_parler_tts_to_hub.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -1,5 +1,7 @@
+from transformers import AutoFeatureExtractor, AutoTokenizer
+
 from parler_tts import ParlerTTSForConditionalGeneration
-from transformers import AutoTokenizer, AutoFeatureExtractor
+

 path = "TODO"
 repo_id = "parler_tts_600M"
```

parler_tts/__init__.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -1,16 +1,17 @@
 __version__ = "0.1"


+from transformers import AutoConfig, AutoModel
+
 from .configuration_parler_tts import ParlerTTSConfig, ParlerTTSDecoderConfig
+from .dac_wrapper import DACConfig, DACModel
 from .modeling_parler_tts import (
     ParlerTTSForCausalLM,
     ParlerTTSForConditionalGeneration,
     apply_delay_pattern_mask,
     build_delay_pattern_mask,
 )

-from .dac_wrapper import DACConfig, DACModel
-from transformers import AutoConfig, AutoModel

 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
```
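The reordered imports leave the two registration calls reading top-down: `AutoConfig.register` maps the `"dac"` model type string to `DACConfig`, and `AutoModel.register` maps that config class to `DACModel`, which is what lets the Auto classes resolve DAC checkpoints. A toy registry sketch of that mechanism (illustrative only; the real transformers machinery is more involved):

```python
# Toy sketch of the config/model registration pattern behind
# AutoConfig.register / AutoModel.register. Names are hypothetical.

CONFIG_REGISTRY = {}   # model_type string -> config class
MODEL_REGISTRY = {}    # config class -> model class


class ToyConfig:
    model_type = "dac"


class ToyModel:
    def __init__(self, config):
        self.config = config


def register_config(model_type, config_cls):
    CONFIG_REGISTRY[model_type] = config_cls


def register_model(config_cls, model_cls):
    MODEL_REGISTRY[config_cls] = model_cls


def auto_model_for(model_type):
    # Resolve model_type -> config class -> model class, mirroring how
    # the Auto classes dispatch on a checkpoint's declared model_type.
    config_cls = CONFIG_REGISTRY[model_type]
    return MODEL_REGISTRY[config_cls](config_cls())


register_config("dac", ToyConfig)
register_model(ToyConfig, ToyModel)

model = auto_model_for("dac")
print(type(model).__name__)  # ToyModel
```

Registering at import time, as `parler_tts/__init__.py` does, means any code that imports the package can load DAC models through the generic Auto interface without knowing the concrete classes.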

parler_tts/dac_wrapper/configuration_dac.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,5 +1,5 @@
+
 from transformers import PretrainedConfig
-from typing import List


 class DACConfig(PretrainedConfig):
```

parler_tts/dac_wrapper/modeling_dac.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -1,10 +1,9 @@
 import torch
-
+from dac.model import DAC
 from transformers import PreTrainedModel
-from transformers.models.encodec.modeling_encodec import EncodecEncoderOutput, EncodecDecoderOutput
-from .configuration_dac import DACConfig
+from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput

-from dac.model import DAC
+from .configuration_dac import DACConfig


 # model doesn't support batching yet
@@ -134,4 +133,4 @@ def decode(
         return EncodecDecoderOutput(audio_values)

     def forward(self, tensor):
-        raise ValueError(f"`DACModel.forward` not implemented yet")
+        raise ValueError("`DACModel.forward` not implemented yet")
```
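The last hunk in this file also drops a needless `f` prefix from a string literal with no placeholders, a common lint finding (e.g. ruff's F541). A tiny illustration of the same stub-method pattern, with hypothetical names:

```python
# Illustrative stub pattern (hypothetical ToyCodec, not DACModel itself):
# real work goes through encode/decode, while forward is an explicit
# not-implemented guard using a plain string literal (no f-prefix needed,
# since there are no {placeholders} to interpolate).

class ToyCodec:
    def encode(self, samples):
        return [v * 2 for v in samples]   # stand-in for real encoding

    def forward(self, tensor):
        raise ValueError("`ToyCodec.forward` not implemented yet")


codec = ToyCodec()
print(codec.encode([1, 2]))  # [2, 4]

try:
    codec.forward(None)
except ValueError as err:
    print(err)  # `ToyCodec.forward` not implemented yet
```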
