Skip to content

Commit 32736a7

Browse files
authored
Bump transformers version (#150)
* add mechanism to reconcile DACs * add mimi compatibility * correct batched generation * update dac on the hub script * update quantizes in training * introduce fused lm heads * update training code and modeling code with per codebook losses and smarter eval * bump transformers version * update transformers version * update bump * update transformers version * update dependencies --------- Co-authored-by: [email protected] <Yoach Lacombe>
1 parent 5d0aca9 commit 32736a7

File tree

10 files changed

+375
-151
lines changed

10 files changed

+375
-151
lines changed

helpers/push_to_hub_scripts/push_dac_to_hub.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@
55
from transformers import AutoConfig, AutoModel
66
from transformers import EncodecFeatureExtractor
77

8-
AutoConfig.register("dac", DACConfig)
8+
from importlib.metadata import version
9+
from packaging.version import Version
10+
11+
if Version(version("transformers"))<= Version("4.44.2dev"):
12+
AutoConfig.register("dac", DACConfig)
13+
else:
14+
AutoConfig.register("dac_on_the_hub", DACConfig)
15+
916
AutoModel.register(DACConfig, DACModel)
1017

1118
# Download a model

parler_tts/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,12 @@
1414

1515
from .streamer import ParlerTTSStreamer
1616

17-
AutoConfig.register("dac", DACConfig)
17+
from importlib.metadata import version
18+
from packaging.version import Version
19+
20+
if Version(version("transformers"))<= Version("4.44.2dev"):
21+
AutoConfig.register("dac", DACConfig)
22+
else:
23+
AutoConfig.register("dac_on_the_hub", DACConfig)
24+
1825
AutoModel.register(DACConfig, DACModel)

parler_tts/configuration_parler_tts.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
from transformers import AutoConfig, logging
1818
from transformers.configuration_utils import PretrainedConfig
1919

20+
from importlib.metadata import version
21+
from packaging.version import Version
22+
23+
use_dac_on_the_hub = Version(version("transformers")) > Version("4.44.2dev")
2024

2125
logger = logging.get_logger(__name__)
2226

@@ -91,6 +95,10 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
9195
The base period of the RoPE embeddings.
9296
cross_attention_implementation_strategy (`str`, *optional*):
9397
If not specified, the cross-attention implementation will be the same as `_attn_implementation`. If `always_eager`, it will always be the eager implementation. If `always_sdpa`, it will always be the sdpa implementation.
98+
use_fused_lm_heads(`bool`, *optional*, defaults to `False`):
99+
Whether to fuse audio LM heads instead of applying them sequentially.
100+
codebook_weights(`List[int]`, *optional*):
101+
Weights applied to each codebook when computing the loss.
94102
"""
95103

96104
model_type = "parler_tts_decoder"
@@ -122,6 +130,8 @@ def __init__(
122130
rope_embeddings=False,
123131
rope_theta=10_000.0,
124132
cross_attention_implementation_strategy=None,
133+
use_fused_lm_heads=False,
134+
codebook_weights=None,
125135
**kwargs,
126136
):
127137
self.vocab_size = vocab_size
@@ -148,7 +158,11 @@ def __init__(
148158
self.rope_embeddings = rope_embeddings
149159
self.rope_theta = rope_theta
150160
self.cross_attention_implementation_strategy = cross_attention_implementation_strategy
161+
self.use_fused_lm_heads = use_fused_lm_heads
162+
self.codebook_weights = codebook_weights
151163

164+
if codebook_weights is not None and len(codebook_weights) != num_codebooks:
165+
raise ValueError(f"`codebook_weights` has length {len(codebook_weights)} when it should be of length {num_codebooks}.")
152166
super().__init__(
153167
pad_token_id=pad_token_id,
154168
bos_token_id=bos_token_id,
@@ -234,6 +248,11 @@ def __init__(self, vocab_size=1024, prompt_cross_attention=False, **kwargs):
234248
audio_encoder_config = kwargs.pop("audio_encoder")
235249
audio_encoder_model_type = audio_encoder_config.pop("model_type")
236250

251+
model_version = kwargs.get("transformers_version", None)
252+
if model_version is not None and Version(model_version) <= Version("4.44.2dev") and use_dac_on_the_hub and audio_encoder_model_type=="dac":
253+
# here we have to manually change model type if DAC based on transformers version
254+
audio_encoder_model_type = "dac_on_the_hub"
255+
237256
decoder_config = kwargs.pop("decoder")
238257

239258
self.vocab_size = vocab_size
@@ -269,22 +288,4 @@ def from_sub_models_config(
269288
@property
270289
# This is a property because you might want to change the codec model on the fly
271290
def sampling_rate(self):
272-
return self.audio_encoder.sampling_rate
273-
274-
# Copy from musicgen
275-
@property
276-
def _attn_implementation(self):
277-
# This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
278-
if hasattr(self, "_attn_implementation_internal"):
279-
if self._attn_implementation_internal is None:
280-
# `config.attn_implementation` should never be None, for backward compatibility.
281-
return "eager"
282-
else:
283-
return self._attn_implementation_internal
284-
else:
285-
return "eager"
286-
287-
@_attn_implementation.setter
288-
def _attn_implementation(self, value):
289-
self._attn_implementation_internal = value
290-
self.decoder._attn_implementation = value
291+
return self.audio_encoder.sampling_rate

parler_tts/dac_wrapper/configuration_dac.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11

22
from transformers import PretrainedConfig
3+
from importlib.metadata import version
4+
from packaging.version import Version
35

46

57
class DACConfig(PretrainedConfig):
6-
model_type = "dac"
8+
model_type = "dac" if Version(version("transformers"))<= Version("4.44.2dev") else "dac_on_the_hub"
79

810
def __init__(
911
self,

parler_tts/dac_wrapper/modeling_dac.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import torch
22
from dac.model import DAC
3+
from torch import nn
4+
35
from transformers import PreTrainedModel
46
from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput
57

@@ -11,6 +13,7 @@
1113

1214
class DACModel(PreTrainedModel):
1315
config_class = DACConfig
16+
main_input_name = "input_values"
1417

1518
# Set main input to 'input_values' for voice steering
1619
main_input_name = "input_values"
@@ -23,6 +26,9 @@ def __init__(self, config):
2326
latent_dim=config.latent_dim,
2427
codebook_size=config.codebook_size,
2528
)
29+
30+
self.remove_weight_norm()
31+
self.apply_weight_norm()
2632

2733
def encode(
2834
self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
@@ -137,3 +143,22 @@ def decode(
137143

138144
def forward(self, tensor):
139145
raise ValueError("`DACModel.forward` not implemented yet")
146+
147+
148+
def apply_weight_norm(self):
    """Wrap every 1-D convolution of this model with weight normalization.

    The parametrization-based ``weight_norm`` is used when the installed
    torch exposes it under ``nn.utils.parametrizations``; otherwise the
    ``nn.utils.weight_norm`` helper is used instead.
    """
    # Pick the normalisation helper once, up front.
    if hasattr(nn.utils.parametrizations, "weight_norm"):
        normalize = nn.utils.parametrizations.weight_norm
    else:
        normalize = nn.utils.weight_norm

    conv_types = (nn.Conv1d, nn.ConvTranspose1d)

    def _apply_weight_norm(layer):
        # Only (transposed) 1-D convolutions are touched; everything else is skipped.
        if isinstance(layer, conv_types):
            normalize(layer)

    self.apply(_apply_weight_norm)
158+
159+
160+
def remove_weight_norm(self):
    """Strip weight normalization from every 1-D convolution of this model.

    Two flavours of weight norm are handled:

    * parametrization-based weight norm (registered on ``weight`` through
      ``torch.nn.utils.parametrize``), which is removed with
      ``remove_parametrizations`` while keeping the current effective weight;
    * hook-based weight norm, which is removed with
      ``nn.utils.remove_weight_norm`` exactly as before.

    Modules that are not ``nn.Conv1d`` / ``nn.ConvTranspose1d`` are left
    untouched. For a convolution whose ``weight`` is not parametrized, the
    call to ``nn.utils.remove_weight_norm`` is unchanged.
    """
    # Local import keeps this method self-contained; very old torch builds
    # may not ship the parametrize module, in which case only the hook-based
    # path is available.
    try:
        from torch.nn.utils import parametrize
    except ImportError:
        parametrize = None

    def _remove_weight_norm(module):
        if not isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
            return
        if parametrize is not None and parametrize.is_parametrized(module, "weight"):
            # leave_parametrized=True bakes the normalised weight back into a
            # plain parameter, so the module's output does not change.
            parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)
        else:
            nn.utils.remove_weight_norm(module)

    self.apply(_remove_weight_norm)

0 commit comments

Comments
 (0)