from transformers import AutoConfig, logging
from transformers.configuration_utils import PretrainedConfig

+from importlib.metadata import version
+from packaging.version import Version
+
+use_dac_on_the_hub = Version(version("transformers")) > Version("4.44.2dev")


logger = logging.get_logger(__name__)
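For readers unsure how `packaging` orders the cutoff: `"4.44.2dev"` is normalized to `4.44.2.dev0`, which sorts before the final `4.44.2` release, so the flag is true for `4.44.2` itself and anything newer, but false for `4.44.1` and older. A minimal sketch of that ordering (not part of the PR):

```python
# Sanity-check the "4.44.2dev" cutoff used for `use_dac_on_the_hub`.
from packaging.version import Version

assert Version("4.44.2dev") == Version("4.44.2.dev0")    # implicit dev number is normalized
assert Version("4.44.2") > Version("4.44.2dev")          # dev releases sort before the final release
assert not (Version("4.44.1") > Version("4.44.2dev"))    # older stable releases do not pass the gate
```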
@@ -91,6 +95,10 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
            The base period of the RoPE embeddings.
        cross_attention_implementation_strategy (`str`, *optional*):
            If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
+        use_fused_lm_heads (`bool`, *optional*, defaults to `False`):
+            Whether to fuse audio LM heads instead of applying them sequentially.
+        codebook_weights (`List[int]`, *optional*):
+            Weights applied to each codebook when computing the loss.
    """

    model_type = "parler_tts_decoder"
@@ -122,6 +130,8 @@ def __init__(
        rope_embeddings=False,
        rope_theta=10_000.0,
        cross_attention_implementation_strategy=None,
+        use_fused_lm_heads=False,
+        codebook_weights=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
@@ -148,7 +158,11 @@ def __init__(
        self.rope_embeddings = rope_embeddings
        self.rope_theta = rope_theta
        self.cross_attention_implementation_strategy = cross_attention_implementation_strategy
+        self.use_fused_lm_heads = use_fused_lm_heads
+        self.codebook_weights = codebook_weights

+        if codebook_weights is not None and len(codebook_weights) != num_codebooks:
+            raise ValueError(f"`codebook_weights` has length {len(codebook_weights)} when it should be of length {num_codebooks}.")
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
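A short usage sketch of the two new arguments, assuming `ParlerTTSDecoderConfig` is exported at the package root as in the Parler-TTS repository (all other arguments keep their defaults):

```python
from parler_tts import ParlerTTSDecoderConfig

# Fuse the audio LM heads and weight every codebook equally in the loss.
config = ParlerTTSDecoderConfig(
    num_codebooks=9,
    use_fused_lm_heads=True,
    codebook_weights=[1] * 9,
)

# A length mismatch is rejected before the config is built:
try:
    ParlerTTSDecoderConfig(num_codebooks=9, codebook_weights=[1, 2, 3])
except ValueError as err:
    print(err)  # `codebook_weights` has length 3 when it should be of length 9.
```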
@@ -234,6 +248,11 @@ def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
        audio_encoder_config = kwargs.pop("audio_encoder")
        audio_encoder_model_type = audio_encoder_config.pop("model_type")

+        model_version = kwargs.get("transformers_version", None)
+        if model_version is not None and Version(model_version) <= Version("4.44.2dev") and use_dac_on_the_hub and audio_encoder_model_type == "dac":
+            # manually remap the model type for DAC checkpoints, based on the transformers version
+            audio_encoder_model_type = "dac_on_the_hub"
+
        decoder_config = kwargs.pop("decoder")

        self.vocab_size = vocab_size
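To make the gating explicit, here is an illustrative standalone helper (the name and example versions are mine, not part of the PR) mirroring the condition above: a checkpoint serialized with transformers <= 4.44.2dev whose audio encoder is `dac` gets remapped to `dac_on_the_hub` when `use_dac_on_the_hub` is set:

```python
from packaging.version import Version

_CUTOFF = Version("4.44.2dev")

def resolve_audio_encoder_type(saved_version, model_type, use_dac_on_the_hub=True):
    """Illustrative mirror of the gate above; `use_dac_on_the_hub` stands in for the module-level flag."""
    if (
        saved_version is not None
        and Version(saved_version) <= _CUTOFF
        and use_dac_on_the_hub
        and model_type == "dac"
    ):
        return "dac_on_the_hub"
    return model_type

print(resolve_audio_encoder_type("4.43.0", "dac"))  # dac_on_the_hub (old checkpoint, newer runtime)
print(resolve_audio_encoder_type("4.46.0", "dac"))  # dac (recent checkpoint, left untouched)
```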
@@ -269,22 +288,4 @@ def from_sub_models_config(
    @property
    # This is a property because you might want to change the codec model on the fly
    def sampling_rate(self):
-        return self.audio_encoder.sampling_rate
-
-    # Copy from musicgen
-    @property
-    def _attn_implementation(self):
-        # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
-        if hasattr(self, "_attn_implementation_internal"):
-            if self._attn_implementation_internal is None:
-                # `config.attn_implementation` should never be None, for backward compatibility.
-                return "eager"
-            else:
-                return self._attn_implementation_internal
-        else:
-            return "eager"
-
-    @_attn_implementation.setter
-    def _attn_implementation(self, value):
-        self._attn_implementation_internal = value
-        self.decoder._attn_implementation = value
+        return self.audio_encoder.sampling_rate