
Commit fd38dc8

ebezzam authored and vasqu committed
Clean up XCodec and other codecs (huggingface#40348)
* Clean up xcodec addition.
* Clean up config.
* Switch to fixtures test.
* Small stuff.
* Polish XCodec and standardize across codecs.
* Update src/transformers/models/xcodec/modeling_xcodec.py
* Format and fix test.
* Update tol.

Co-authored-by: Anton Vlasjuk <[email protected]>
1 parent 3c9dbe9 commit fd38dc8

6 files changed: +50 additions, −69 deletions

src/transformers/models/dac/feature_extraction_dac.py

Lines changed: 3 additions & 2 deletions
@@ -150,10 +150,11 @@ def __call__(
             max_length=max_length,
             truncation=truncation,
             padding=padding,
-            return_attention_mask=False,
+            return_attention_mask=padding,
             pad_to_multiple_of=self.hop_length,
         )
-
+        if padding:
+            padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
         if padding:
             padded_inputs.input_values = padded_inputs.input_values[:, np.newaxis, :]
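
A minimal usage sketch of the new padding behaviour (the checkpoint name and dummy inputs are assumptions, not part of the diff): with padding=True the extractor now also returns the mask, exposed under the "padding_mask" key.

from transformers import DacFeatureExtractor

# Hypothetical checkpoint; any DAC feature extractor behaves the same way.
feature_extractor = DacFeatureExtractor.from_pretrained("descript/dac_16khz")
inputs = feature_extractor(
    [[0.0] * 16000, [0.0] * 24000],  # two mono clips of different lengths
    sampling_rate=16000,
    padding=True,
    return_tensors="pt",
)
print(inputs["input_values"].shape)  # (batch, 1, padded_num_samples)
print(inputs["padding_mask"].shape)  # marks real vs. padded samples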

src/transformers/models/dac/modeling_dac.py

Lines changed: 3 additions & 0 deletions
@@ -613,6 +613,8 @@ def decode(
             The codebook indices for each codebook, representing the quantized discrete
             representation of the input. This parameter should be provided if you want
             to decode directly from the audio codes (it will overwrite quantized_representation).
+        return_dict (`bool`, *optional*, defaults to `True`):
+            Whether to return a [`DacDecoderOutput`] instead of a plain tuple.
         """

         if quantized_representation is None and audio_codes is None:
@@ -667,6 +669,7 @@ def forward(

         return_dict = return_dict if return_dict is not None else self.config.return_dict
         length = input_values.shape[-1]
+
         loss, quantized_representation, audio_codes, projected_latents = self.encode(
             input_values, n_quantizers, return_dict=False
         )
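
A hedged sketch of the newly documented return_dict switch on DacModel.decode (the checkpoint name is an assumption):

import torch
from transformers import DacModel

model = DacModel.from_pretrained("descript/dac_16khz").eval()
with torch.no_grad():
    encoded = model.encode(torch.randn(1, 1, 16000))
    # return_dict=True (the default) wraps the waveform in a DacDecoderOutput...
    decoded = model.decode(audio_codes=encoded.audio_codes, return_dict=True)
    print(decoded.audio_values.shape)
    # ...while return_dict=False yields a plain tuple.
    (audio_values,) = model.decode(audio_codes=encoded.audio_codes, return_dict=False)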

src/transformers/models/encodec/modeling_encodec.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 import torch
 from torch import nn

-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import PreTrainedAudioTokenizerBase
 from ...utils import (
     ModelOutput,
     auto_docstring,
@@ -449,7 +449,7 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor:


 @auto_docstring
-class EncodecPreTrainedModel(PreTrainedModel):
+class EncodecPreTrainedModel(PreTrainedAudioTokenizerBase):
     config: EncodecConfig
     base_model_prefix = "encodec"
     main_input_name = "input_values"
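
One observable consequence, sketched under the assumption of a public Encodec checkpoint: every Encodec model is now an instance of the shared audio-tokenizer base class.

from transformers import EncodecModel
from transformers.modeling_utils import PreTrainedAudioTokenizerBase

# "facebook/encodec_24khz" is an assumption; any Encodec checkpoint would do.
model = EncodecModel.from_pretrained("facebook/encodec_24khz")
print(isinstance(model, PreTrainedAudioTokenizerBase))  # True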

src/transformers/models/xcodec/modeling_xcodec.py

Lines changed: 36 additions & 58 deletions
@@ -22,7 +22,7 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import PreTrainedAudioTokenizerBase
 from ...utils import ModelOutput, auto_docstring
 from ..auto import AutoModel
 from .configuration_xcodec import XcodecConfig
@@ -316,7 +316,7 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor:


 @auto_docstring
-class XcodecPreTrainedModel(PreTrainedModel):
+class XcodecPreTrainedModel(PreTrainedAudioTokenizerBase):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
     models.
@@ -325,7 +325,6 @@ class XcodecPreTrainedModel(PreTrainedModel):
     config_class = XcodecConfig
     base_model_prefix = "xcodec"
     main_input_name = "input_values"
-    supports_gradient_checkpointing = False

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -427,34 +426,24 @@ def encode(
         input_values: torch.Tensor,
         bandwidth: Optional[float] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,
     ) -> Union[torch.Tensor, XcodecEncoderOutput]:
-        """
-        Encodes the input audio waveform into discrete audio codes.
-
-        Args:
-            input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
-                Float values of the input audio waveform.
-            bandwidth (`float`, *optional*):
-                The target bandwidth in (kbps) supports only values in `config.target_bandwidths`.
-                Defaults to the highest available bandwidth `4.0` kbps.
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`].
+        r"""
+        input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
+            Float values of the input audio waveform.
+        bandwidth (`float`, *optional*):
+            The target bandwidth in (kbps) supports only values in `config.target_bandwidths`.
+            Defaults to the highest available bandwidth `4.0` kbps.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`].

         Returns:
             `torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes.
         """
         return_dict = return_dict if return_dict is not None else self.config.return_dict

-        if input_values.ndim != 3:
-            raise ValueError(
-                f"Expected input shape (batch_size, channels, num_samples), but got shape {input_values.shape}"
-            )
-
-        _, channels, self._input_length = input_values.shape
-
-        if channels not in (1, 2):
-            raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")
+        channels = input_values.shape[1]
+        if channels != 1:
+            raise ValueError(f"Audio must be mono, but got {channels}")

         if bandwidth is None:
             bandwidth = self.config.target_bandwidths[-1]
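
The hunk above narrows encode to mono audio and drops the cached input length. A small sketch of the new contract (the checkpoint name is an assumption):

import torch
from transformers import XcodecModel

model = XcodecModel.from_pretrained("hf-audio/xcodec-hubert-librispeech").eval()
mono = torch.randn(1, 1, 16000)    # (batch, channels=1, samples): accepted
stereo = torch.randn(1, 2, 16000)  # channels=2: now rejected
codes = model.encode(mono, return_dict=False)
try:
    model.encode(stereo)
except ValueError as err:
    print(err)  # "Audio must be mono, but got 2"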
@@ -483,22 +472,19 @@ def encode(

     @auto_docstring
     def decode(
-        self, audio_codes: torch.Tensor, return_dict: Optional[bool] = None, **kwargs
+        self,
+        audio_codes: torch.Tensor,
+        return_dict: Optional[bool] = None,
     ) -> Union[torch.Tensor, XcodecDecoderOutput]:
-        """
-        Decode the given discrete codes into an output audio waveform.
-
-        The produced audio waveform is longer than the audio input, so it's automatically trimmed to match the original input.
-
-        Args:
-            audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`):
-                Discrete code indices computed using `model.encode`.
-
-            return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`]
+        r"""
+        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`):
+            Discrete code indices computed using `model.encode`.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`]

         Returns:
-            Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of Xcodec.
+            Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of
+            Xcodec.
         """
         return_dict = return_dict if return_dict is not None else self.config.return_dict


@@ -507,13 +493,6 @@ def decode(
         quantized_acoustic = self.fc2(quantized.transpose(1, 2)).transpose(1, 2)
         audio_values = self.acoustic_decoder(quantized_acoustic)

-        if getattr(self, "_input_length", None) is not None:
-            output_length = audio_values.shape[-1]
-            if self._input_length != output_length:
-                extra = output_length - self._input_length
-                start = extra // 2
-                audio_values = audio_values[..., start : start + self._input_length]
-
         if not return_dict:
             return audio_values


@@ -526,20 +505,18 @@ def forward(
         audio_codes: Optional[torch.Tensor] = None,
         bandwidth: Optional[float] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,
     ) -> Union[tuple[torch.Tensor, torch.Tensor], XcodecOutput]:
-        """
-        Encodes and quantizes the input audio into discrete codes, then decodes those codes back into an audio waveform.
-
-        Args:
-            input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
-                The raw float values of the input audio waveform.
-            audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`:
-                Discrete code indices computed using `model.encode`.
-            bandwidth (`float`, *optional*):
-                Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
-            return_dict (`bool`, *optional*):
-                Whether to return a [`XcodecOutput`] instead of a plain tuple.
+        r"""
+        input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
+            The raw float values of the input audio waveform.
+        audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`:
+            Discrete code indices computed using `model.encode`.
+        bandwidth (`float`, *optional*):
+            Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
+        bandwidth (`float`, *optional*):
+            Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
+        return_dict (`bool`, *optional*):
+            Whether to return a [`XcodecOutput`] instead of a plain tuple.

         Returns:
             `XcodecOutput` or tuple `(audio_codes, audio_values)`:
@@ -568,11 +545,12 @@ def forward(
         ```
         """
         return_dict = return_dict if return_dict is not None else self.config.return_dict
+        length = input_values.shape[-1]

         if audio_codes is None:
             audio_codes = self.encode(input_values, bandwidth, return_dict=False)

-        audio_values = self.decode(audio_codes, return_dict=return_dict)[0]
+        audio_values = self.decode(audio_codes, return_dict=return_dict)[0][..., :length]

         if not return_dict:
             return (audio_codes, audio_values)
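
With the module-level _input_length state gone, trimming now happens in forward itself. A hedged round-trip sketch (the checkpoint name is an assumption):

import torch
from transformers import XcodecModel

model = XcodecModel.from_pretrained("hf-audio/xcodec-hubert-librispeech").eval()
input_values = torch.randn(1, 1, 16000)
with torch.no_grad():
    audio_codes, audio_values = model(input_values, return_dict=False)
# forward clips the decoded waveform back to the input length.
print(audio_values.shape[-1] == input_values.shape[-1])  # True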

tests/models/dac/test_modeling_dac.py

Lines changed: 2 additions & 2 deletions
@@ -527,7 +527,7 @@ def compute_rmse(arr1, arr2):
         }
         EXPECTED_QUANT_CODEBOOK_LOSS = {
             "dac_16khz": 20.7299,
-            "dac_24khz": 22.6652,
+            "dac_24khz": 22.6602,
             "dac_44khz": 16.2168,
         }
         EXPECTED_CODEC_ERROR = {
@@ -793,7 +793,7 @@ def test_integration(self, model_name):
             atol=1e-6,
         )
         torch.testing.assert_close(
-            quantizer_outputs[4].squeeze().item(), EXPECTED_QUANT_CODEBOOK_LOSS[model_name], rtol=1e-6, atol=1e-6
+            quantizer_outputs[4].squeeze().item(), EXPECTED_QUANT_CODEBOOK_LOSS[model_name], rtol=1e-4, atol=1e-4
         )

         # compare decoder outputs
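
For reference, torch.testing.assert_close passes when |actual - expected| <= atol + rtol * |expected|, so the relaxed tolerance absorbs small numerical drift in the codebook loss across platforms. A toy illustration (the values here are made up):

import torch

torch.testing.assert_close(22.6604, 22.6602, rtol=1e-4, atol=1e-4)  # passes
# With the old rtol=1e-6, atol=1e-6 the same 2e-4 drift would raise an AssertionError.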

tests/models/xcodec/test_modeling_xcodec.py

Lines changed: 4 additions & 5 deletions
@@ -67,9 +67,10 @@ def __init__(
         self.num_samples = num_samples

     def prepare_config_and_inputs(self):
-        input_values = floats_tensor([self.batch_size, self.num_channels, self.num_samples], scale=1.0)
         config = self.get_config()
-        inputs_dict = {"input_values": input_values}
+        inputs_dict = {
+            "input_values": floats_tensor([self.batch_size, self.num_channels, self.num_samples], scale=1.0)
+        }
         return config, inputs_dict

     def prepare_config_and_inputs_for_common(self):
@@ -82,7 +83,6 @@ def prepare_config_and_inputs_for_model_class(self, model_class):
         inputs_dict["audio_codes"] = ids_tensor(
             [self.batch_size, config.num_quantizers, codes_length], config.codebook_size
         )
-
         return config, inputs_dict

     def get_config(self):
@@ -94,8 +94,7 @@ def get_config(self):

     def create_and_check_model_forward(self, config, inputs_dict):
         model = XcodecModel(config=config).to(torch_device).eval()
-        input_values = inputs_dict["input_values"]
-        result = model(input_values)
+        result = model(input_values=inputs_dict["input_values"])
         self.parent.assertEqual(result.audio_values.shape, (self.batch_size, self.num_channels, self.num_samples))
