@@ -47,6 +47,17 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
             Number of decoder layers.
         num_attention_heads (`int`, *optional*, defaults to 16):
             Number of attention heads for each attention layer in the Transformer block.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by mean-pooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        num_cross_attention_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention in the cross-attention layers.
+            If it is not specified, will default to `num_key_value_heads`.
         ffn_dim (`int`, *optional*, defaults to 4096):
             Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer block.
         activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
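
The mean-pooling conversion described in the `num_key_value_heads` docstring can be sketched as follows. This is a standalone illustration of the idea, not code from this PR; the function name and tensor shapes are assumptions.

```python
import torch

def meanpool_kv_heads(weight: torch.Tensor, num_heads: int, num_key_value_heads: int, head_dim: int) -> torch.Tensor:
    """Mean-pool the key (or value) projection of an MHA checkpoint into GQA heads."""
    # weight has shape (num_heads * head_dim, hidden_size); each group of
    # num_heads // num_key_value_heads original heads is averaged into one GQA head.
    group_size = num_heads // num_key_value_heads
    per_head = weight.view(num_key_value_heads, group_size, head_dim, weight.shape[-1])
    pooled = per_head.mean(dim=1)  # average the heads inside each group
    return pooled.reshape(num_key_value_heads * head_dim, weight.shape[-1])

# 16 original heads, 4 key/value heads, head_dim 64, hidden_size 1024
k_proj = torch.randn(16 * 64, 1024)
print(meanpool_kv_heads(k_proj, num_heads=16, num_key_value_heads=4, head_dim=64).shape)  # torch.Size([256, 1024])
```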
@@ -74,6 +85,12 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
             The number of parallel codebooks forwarded to the model.
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether input and output word embeddings should be tied.
+        rope_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to use RoPE or absolute positional embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        cross_attention_implementation_strategy (`str`, *optional*):
+            If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
     """

     model_type = "parler_tts_decoder"
@@ -86,6 +103,8 @@ def __init__(
         num_hidden_layers=24,
         ffn_dim=4096,
         num_attention_heads=16,
+        num_key_value_heads=None,
+        num_cross_attention_key_value_heads=None,
         layerdrop=0.0,
         use_cache=True,
         activation_function="gelu",
@@ -100,6 +119,9 @@ def __init__(
         bos_token_id=2049,
         eos_token_id=2048,
         tie_word_embeddings=False,
+        rope_embeddings=False,
+        rope_theta=10_000.0,
+        cross_attention_implementation_strategy=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -108,6 +130,12 @@ def __init__(
         self.ffn_dim = ffn_dim
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        if num_cross_attention_key_value_heads is None:
+            num_cross_attention_key_value_heads = num_key_value_heads
+        self.num_cross_attention_key_value_heads = num_cross_attention_key_value_heads
         self.dropout = dropout
         self.attention_dropout = attention_dropout
         self.activation_dropout = activation_dropout
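
With this defaulting logic, `num_key_value_heads == num_attention_heads` gives plain MHA and `num_key_value_heads == 1` gives MQA; in between, each key/value head is shared by a group of query heads. Below is a standalone sketch of the usual expansion step at attention time, in the spirit of the `repeat_kv` helper used in transformers-style GQA attention; it is not code from this diff.

```python
import torch

def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # hidden: (batch, num_key_value_heads, seq_len, head_dim)
    # Repeat each key/value head n_rep times so it lines up with the query heads.
    if n_rep == 1:
        return hidden  # MHA case: nothing to do
    batch, num_kv_heads, seq_len, head_dim = hidden.shape
    hidden = hidden[:, :, None].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

num_attention_heads, num_key_value_heads = 16, 4
k = torch.randn(2, num_key_value_heads, 10, 64)
print(repeat_kv(k, num_attention_heads // num_key_value_heads).shape)  # torch.Size([2, 16, 10, 64])
```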
@@ -117,6 +145,9 @@ def __init__(
         self.use_cache = use_cache
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
         self.num_codebooks = num_codebooks
+        self.rope_embeddings = rope_embeddings
+        self.rope_theta = rope_theta
+        self.cross_attention_implementation_strategy = cross_attention_implementation_strategy

         super().__init__(
             pad_token_id=pad_token_id,
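
Put together, a decoder config exercising the new fields might be built as below. This is a usage sketch only; it assumes the `parler_tts` package from this branch is installed and that `ParlerTTSDecoderConfig` is importable from the top-level package.

```python
from parler_tts import ParlerTTSDecoderConfig  # assumed top-level export

decoder_config = ParlerTTSDecoderConfig(
    num_attention_heads=16,
    num_key_value_heads=4,                      # GQA: 4 groups of 4 query heads
    num_cross_attention_key_value_heads=4,      # same grouping in cross-attention
    rope_embeddings=True,                       # rotary instead of absolute positions
    rope_theta=10_000.0,
    cross_attention_implementation_strategy="always_sdpa",
)
print(decoder_config.num_key_value_heads, decoder_config.rope_embeddings)
```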
@@ -140,6 +171,8 @@ class ParlerTTSConfig(PretrainedConfig):
         vocab_size (`int`, *optional*, defaults to 1024):
             Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
             represented by the `prompt_inputs_ids`.
+        prompt_cross_attention (`bool`, *optional*, defaults to `False`):
+            Whether to use cross-attention conditioning for the prompt (as well as the description).
         kwargs (*optional*):
             Dictionary of keyword arguments. Notably:

@@ -190,7 +223,7 @@ class ParlerTTSConfig(PretrainedConfig):
     model_type = "parler_tts"
     is_composition = True

-    def __init__(self, vocab_size=1024, **kwargs):
+    def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
         super().__init__(**kwargs)
         if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
             raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
@@ -204,6 +237,7 @@ def __init__(self, vocab_size=1024, **kwargs):
         decoder_config = kwargs.pop("decoder")

         self.vocab_size = vocab_size
+        self.prompt_cross_attention = prompt_cross_attention
         self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
         self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
         self.decoder = ParlerTTSDecoderConfig(**decoder_config)
@@ -236,3 +270,21 @@ def from_sub_models_config(
     # This is a property because you might want to change the codec model on the fly
     def sampling_rate(self):
         return self.audio_encoder.sampling_rate
+
+    # Copy from musicgen
+    @property
+    def _attn_implementation(self):
+        # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
+        if hasattr(self, "_attn_implementation_internal"):
+            if self._attn_implementation_internal is None:
+                # `config.attn_implementation` should never be None, for backward compatibility.
+                return "eager"
+            else:
+                return self._attn_implementation_internal
+        else:
+            return "eager"
+
+    @_attn_implementation.setter
+    def _attn_implementation(self, value):
+        self._attn_implementation_internal = value
+        self.decoder._attn_implementation = value
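
The setter mirrors the musicgen composite config: assigning the attention implementation on the top-level config also forwards it to the decoder sub-config. A hypothetical usage sketch, assuming `config` is an already-constructed `ParlerTTSConfig`:

```python
# Hypothetical usage; assumes `config` is a ParlerTTSConfig built from
# text_encoder / audio_encoder / decoder sub-configs.
config._attn_implementation = "sdpa"   # setter stores the value and forwards it to the decoder
assert config._attn_implementation == "sdpa"
assert config.decoder._attn_implementation == "sdpa"

# Cross-attention follows `_attn_implementation` unless the decoder's
# `cross_attention_implementation_strategy` pins it to "always_eager" or "always_sdpa".
config.decoder.cross_attention_implementation_strategy = "always_sdpa"
```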