22
22
import torch .nn as nn
23
23
import torch .nn .functional as F
24
24
25
- from ...modeling_utils import PreTrainedModel
25
+ from ...modeling_utils import PreTrainedAudioTokenizerBase
26
26
from ...utils import ModelOutput , auto_docstring
27
27
from ..auto import AutoModel
28
28
from .configuration_xcodec import XcodecConfig
@@ -316,7 +316,7 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor:
316
316
317
317
318
318
@auto_docstring
319
- class XcodecPreTrainedModel (PreTrainedModel ):
319
+ class XcodecPreTrainedModel (PreTrainedAudioTokenizerBase ):
320
320
"""
321
321
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
322
322
models.
@@ -325,7 +325,6 @@ class XcodecPreTrainedModel(PreTrainedModel):
325
325
config_class = XcodecConfig
326
326
base_model_prefix = "xcodec"
327
327
main_input_name = "input_values"
328
- supports_gradient_checkpointing = False
329
328
330
329
def _init_weights (self , module ):
331
330
"""Initialize the weights"""
@@ -427,34 +426,24 @@ def encode(
427
426
input_values : torch .Tensor ,
428
427
bandwidth : Optional [float ] = None ,
429
428
return_dict : Optional [bool ] = None ,
430
- ** kwargs ,
431
429
) -> Union [torch .Tensor , XcodecEncoderOutput ]:
432
- """
433
- Encodes the input audio waveform into discrete audio codes.
434
-
435
- Args:
436
- input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
437
- Float values of the input audio waveform.
438
- bandwidth (`float`, *optional*):
439
- The target bandwidth in (kbps) supports only values in `config.target_bandwidths`.
440
- Defaults to the highest available bandwidth `4.0` kbps.
441
- return_dict (`bool`, *optional*):
442
- Whether or not to return a [`~utils.ModelOutput`].
430
+ r"""
431
+ input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
432
+ Float values of the input audio waveform.
433
+ bandwidth (`float`, *optional*):
434
+ The target bandwidth in kbps. Only values in `config.target_bandwidths` are supported.
435
+ Defaults to the highest available bandwidth `4.0` kbps.
436
+ return_dict (`bool`, *optional*):
437
+ Whether or not to return a [`~utils.ModelOutput`].
443
438
444
439
Returns:
445
440
`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)` containing the discrete encoded audio codes.
446
441
"""
447
442
return_dict = return_dict if return_dict is not None else self .config .return_dict
448
443
449
- if input_values .ndim != 3 :
450
- raise ValueError (
451
- f"Expected input shape (batch_size, channels, num_samples), but got shape { input_values .shape } "
452
- )
453
-
454
- _ , channels , self ._input_length = input_values .shape
455
-
456
- if channels not in (1 , 2 ):
457
- raise ValueError (f"Number of audio channels must be 1 or 2, but got { channels } " )
444
+ channels = input_values .shape [1 ]
445
+ if channels != 1 :
446
+ raise ValueError (f"Audio must be mono, but got { channels } " )
458
447
459
448
if bandwidth is None :
460
449
bandwidth = self .config .target_bandwidths [- 1 ]
@@ -483,22 +472,19 @@ def encode(
483
472
484
473
@auto_docstring
485
474
def decode (
486
- self , audio_codes : torch .Tensor , return_dict : Optional [bool ] = None , ** kwargs
475
+ self ,
476
+ audio_codes : torch .Tensor ,
477
+ return_dict : Optional [bool ] = None ,
487
478
) -> Union [torch .Tensor , XcodecDecoderOutput ]:
488
- """
489
- Decode the given discrete codes into an output audio waveform.
490
-
491
- The produced audio waveform is longer than the audio input, so it's automatically trimmed to match the original input.
492
-
493
- Args:
494
- audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`):
495
- Discrete code indices computed using `model.encode`.
496
-
497
- return_dict (`bool`, *optional*):
498
- Whether or not to return a [`~utils.ModelOutput`]
479
+ r"""
480
+ audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`):
481
+ Discrete code indices computed using `model.encode`.
482
+ return_dict (`bool`, *optional*):
483
+ Whether or not to return a [`~utils.ModelOutput`].
499
484
500
485
Returns:
501
- Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of Xcodec.
486
+ Decoded audio values of shape `(batch_size, channels, num_samples)` obtained using the decoder part of
487
+ Xcodec.
502
488
"""
503
489
return_dict = return_dict if return_dict is not None else self .config .return_dict
504
490
@@ -507,13 +493,6 @@ def decode(
507
493
quantized_acoustic = self .fc2 (quantized .transpose (1 , 2 )).transpose (1 , 2 )
508
494
audio_values = self .acoustic_decoder (quantized_acoustic )
509
495
510
- if getattr (self , "_input_length" , None ) is not None :
511
- output_length = audio_values .shape [- 1 ]
512
- if self ._input_length != output_length :
513
- extra = output_length - self ._input_length
514
- start = extra // 2
515
- audio_values = audio_values [..., start : start + self ._input_length ]
516
-
517
496
if not return_dict :
518
497
return audio_values
519
498
@@ -526,20 +505,18 @@ def forward(
526
505
audio_codes : Optional [torch .Tensor ] = None ,
527
506
bandwidth : Optional [float ] = None ,
528
507
return_dict : Optional [bool ] = None ,
529
- ** kwargs ,
530
508
) -> Union [tuple [torch .Tensor , torch .Tensor ], XcodecOutput ]:
531
- """
532
- Encodes and quantizes the input audio into discrete codes, then decodes those codes back into an audio waveform.
533
-
534
- Args:
535
- input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
536
- The raw float values of the input audio waveform.
537
- audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`:
538
- Discrete code indices computed using `model.encode`.
539
- bandwidth (`float`, *optional*):
540
- Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
541
- return_dict (`bool`, *optional*):
542
- Whether to return a [`XcodecOutput`] instead of a plain tuple.
509
+ r"""
510
+ input_values (`torch.FloatTensor` of shape `(batch_size, channels, num_samples)`):
511
+ The raw float values of the input audio waveform.
512
+ audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`):
513
+ Discrete code indices computed using `model.encode`.
514
+ bandwidth (`float`, *optional*):
515
+ Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
516
+ bandwidth (`float`, *optional*):
517
+ Target bandwidth in kbps. Must be one of `config.target_bandwidths`. Defaults to the highest available bandwidth.
518
+ return_dict (`bool`, *optional*):
519
+ Whether to return a [`XcodecOutput`] instead of a plain tuple.
543
520
544
521
Returns:
545
522
`XcodecOutput` or tuple `(audio_codes, audio_values)`:
@@ -568,11 +545,12 @@ def forward(
568
545
```
569
546
"""
570
547
return_dict = return_dict if return_dict is not None else self .config .return_dict
548
+ length = input_values .shape [- 1 ]
571
549
572
550
if audio_codes is None :
573
551
audio_codes = self .encode (input_values , bandwidth , return_dict = False )
574
552
575
- audio_values = self .decode (audio_codes , return_dict = return_dict )[0 ]
553
+ audio_values = self .decode (audio_codes , return_dict = return_dict )[0 ][..., : length ]
576
554
577
555
if not return_dict :
578
556
return (audio_codes , audio_values )
0 commit comments