1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from typing import Tuple
16+
1517import numpy as np
1618import torch
1719from einops import rearrange
2426
2527
2628class CodecEmbedder (nn .Module ):
29+ """
30+ Converts codec codes to dequantized codec embeddings.
31+ The class implements the right API to be used as a custom feature extractor
32+ provided to `torchmetrics.image.fid`.
33+ """
34+
2735 def __init__ (self , codec : AudioCodecModel ):
2836 super ().__init__ ()
2937 self .codec = codec
3038
3139 def forward (self , x : Tensor ) -> Tensor :
3240 """
33- Embeds a batch of audio codec codes into the codec's (dequantized) embedding space.
41+ Embeds a batch of audio codes into the codec's (dequantized) embedding space.
42+ Each frame is treated independently.
43+
44+ Args:
45+ x: Audio codes tensor of shape (B*T, C)
46+
47+ Returns:
48+ Embeddings tensor of shape (B*T, D)
3449 """
35- # x: (B*T, C)
36- x_len = torch .tensor (x .shape [0 ], device = x .device , dtype = torch .long ).unsqueeze (0 ) # (1, 1)
37- # pretend it's one huge batch element, since codec requires (B, C, T) input and
50+ # We treat all frames as one large batch element, since the codec requires (B, C, T) input and
3851 # we don't have the per-batch-element lengths at this point due to FID API limitations
52+
53+ # Construct a length tensor: one batch element, all frames.
54+ x_len = torch .tensor (x .shape [0 ], device = x .device , dtype = torch .long ).unsqueeze (0 ) # (1, 1)
3955 tokens = x .permute (1 , 0 ).unsqueeze (0 ) # 1, C, B*T
4056 embeddings = self .codec .dequantize (tokens = tokens , tokens_len = x_len ) # (B, D, T)
4157 # we treat each time step as a separate example
@@ -48,7 +64,26 @@ def num_features(self) -> int:
4864
4965
5066class FrechetCodecDistance (FrechetInceptionDistance ):
67+ """
68+ A metric that measures the Frechet Distance between a collection of real and
69+ generated codec frames. The distance is measured in the codec's embedding space,
70+ i.e. the continuous vectors obtained by dequantizing the codec frames. Each
71+ multi-codebook frame is treated as a separate example.
72+
73+ We subclass `torchmetrics.image.fid.FrechetInceptionDistance` and use the codec
74+ embedder as a custom feature extractor.
75+ """
76+
5177 def __init__ (self , codec_name : str ):
78+ """
79+ Initializes the FrechetCodecDistance metric.
80+
81+ Args:
82+ codec_name: The name of the codec model to use.
83+ Can be a local .nemo file or a HuggingFace or NGC model.
84+ If the name ends with ".nemo", it is assumed to be a local .nemo file.
85+ Otherwise, it should start with "nvidia/", and is assumed to be a HuggingFace or NGC model.
86+ """
5287 if codec_name .endswith (".nemo" ):
5388 # Local .nemo file
5489 codec = AudioCodecModel .restore_from (codec_name , strict = False )
@@ -65,9 +100,15 @@ def __init__(self, codec_name: str):
65100 self .codec = codec
66101 self .updated_since_last_reset = False
67102
68- def encode_from_file (self , audio_path : str ) -> Tensor :
103+ def _encode_audio_file (self , audio_path : str ) -> Tuple [ Tensor , Tensor ] :
69104 """
70- Encodes an audio file into audio codec codes.
105+ Encodes an audio file using the audio codec.
106+
107+ Args:
108+ audio_path: Path to the audio file.
109+
110+ Returns:
111+ Tuple of tensors containing the codec codes and the lengths of the codec codes.
71112 """
72113 audio_segment = AudioSegment .from_file (audio_path , target_sr = self .codec .sample_rate )
73114 assert np .issubdtype (audio_segment .samples .dtype , np .floating )
@@ -82,6 +123,14 @@ def encode_from_file(self, audio_path: str) -> Tensor:
82123 return codes , codes_len
83124
84125 def update (self , codes : Tensor , codes_len : Tensor , is_real : bool ):
126+ """
127+ Updates the metric with a batch of codec frames.
128+
129+ Args:
130+ codes: Tensor of shape (B, C, T) containing the codec codes.
131+ codes_len: Tensor of shape (B,) containing the lengths of the codec codes.
132+ is_real: Boolean indicating whether the codes are real or generated.
133+ """
85134 if codes .numel () == 0 :
86135 logging .warning ("FCD: No valid codes to update, skipping update" )
87136 return
@@ -90,29 +139,46 @@ def update(self, codes: Tensor, codes_len: Tensor, is_real: bool):
90139 f"FCD: Number of codebooks mismatch: { codes .shape [1 ]} != { self .codec .num_codebooks } , skipping update"
91140 )
92141 return
93- # keep only valid codes
142+
143+ # Keep only valid frames
94144 codes_batch_all = []
95145 for batch_idx in range (codes .shape [0 ]):
96146 codes_batch = codes [batch_idx , :, : codes_len [batch_idx ]] # (C, T)
97147 codes_batch_all .append (codes_batch )
98- # combine into a single tensor. We treat each timestep independently so we can concatenate them all.
148+
149+ # Combine into a single tensor. We treat each frame independently so we can concatenate them all.
99150 codes_batch_all = torch .cat (codes_batch_all , dim = - 1 ).permute (1 , 0 ) # (B*T, C)
100151 if len (codes_batch_all ) == 0 :
101152 logging .warning ("FCD: No valid codes to update, skipping update" )
102153 return
103- # update
154+
155+ # Update the metric
104156 super ().update (codes_batch_all , real = is_real )
105157 self .updated_since_last_reset = True
106158
def reset(self):
    """
    Reset the metric so that a new evaluation round can start.

    Calls the parent class's `reset` and then clears the
    `updated_since_last_reset` flag, which records whether `update` has
    been called since the last reset. Should be called after each compute.
    """
    super().reset()
    # No update has been seen since this reset.
    self.updated_since_last_reset = False
110165
def update_from_audio_file(self, audio_path: str, is_real: bool):
    """
    Update the metric from a single audio file.

    The file is first encoded into codec codes with `_encode_audio_file`;
    the resulting codes and their lengths are then forwarded to `update`.

    Args:
        audio_path: Path to the audio file.
        is_real: True if the audio file is real, False if it is generated.
    """
    file_codes, file_codes_len = self._encode_audio_file(audio_path=audio_path)
    self.update(codes=file_codes, codes_len=file_codes_len, is_real=is_real)
114177
115178 def compute (self ) -> Tensor :
179+ """
180+ Computes the Frechet Distance between the real and generated codec frame distributions.
181+ """
116182 if not self .updated_since_last_reset :
117183 logging .warning ("FCD: No updates since last reset, returning 0" )
118184 return torch .tensor (0.0 , device = self .device )
0 commit comments