Skip to content

Commit 78d64ed

Browse files
committed
Comments and cleanup
Signed-off-by: Fejgin, Roy <[email protected]>
1 parent 3fc5f37 commit 78d64ed

File tree

1 file changed

+76
-10
lines changed

1 file changed

+76
-10
lines changed

nemo/collections/tts/metrics/frechet_codec_distance.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Tuple
16+
1517
import numpy as np
1618
import torch
1719
from einops import rearrange
@@ -24,18 +26,32 @@
2426

2527

2628
class CodecEmbedder(nn.Module):
29+
"""
30+
Converts codec codes to dequantized codec embeddings.
31+
The class implements the right API to be used as a custom feature extractor
32+
provided to `torchmetrics.image.fid`.
33+
"""
34+
2735
def __init__(self, codec: AudioCodecModel):
2836
super().__init__()
2937
self.codec = codec
3038

3139
def forward(self, x: Tensor) -> Tensor:
3240
"""
33-
Embeds a batch of audio codec codes into the codec's (dequantized) embedding space.
41+
Embeds a batch of audio codes into the codec's (dequantized) embedding space.
42+
Each frame is treated independently.
43+
44+
Args:
45+
x: Audio codes tensor of shape (B*T, C)
46+
47+
Returns:
48+
Embeddings tensor of shape (B*T, D)
3449
"""
35-
# x: (B*T, C)
36-
x_len = torch.tensor(x.shape[0], device=x.device, dtype=torch.long).unsqueeze(0) # (1, 1)
37-
# pretend it's one huge batch element, since codec requires (B, C, T) input and
50+
# We treat all frames as one large batch element, since the codec requires (B, C, T) input and
3851
# we don't have the per-batch-element lengths at this point due to FID API limitations
52+
53+
# Consturct a length tensor: one batch element, all frames.
54+
x_len = torch.tensor(x.shape[0], device=x.device, dtype=torch.long).unsqueeze(0) # (1, 1)
3955
tokens = x.permute(1, 0).unsqueeze(0) # 1, C, B*T
4056
embeddings = self.codec.dequantize(tokens=tokens, tokens_len=x_len) # (B, D, T)
4157
# we treat each time step as a separate example
@@ -48,7 +64,26 @@ def num_features(self) -> int:
4864

4965

5066
class FrechetCodecDistance(FrechetInceptionDistance):
67+
"""
68+
A metric that measures the Frechet Distance between a collection of real and
69+
generated codec frames. The distance is measured in the codec's embedding space,
70+
i.e. the continuous vectors obtained by dequantizing the codec frames. Each
71+
multi-codebook frame is treated as a separate example.
72+
73+
We subclass `torchmetrics.image.fid.FrechetInceptionDistance` and use the codec
74+
embedder as a custom feature extractor.
75+
"""
76+
5177
def __init__(self, codec_name: str):
78+
"""
79+
Initializes the FrechetCodecDistance metric.
80+
81+
Args:
82+
codec_name: The name of the codec model to use.
83+
Can be a local .nemo file or a HuggingFace or NGC model.
84+
If the name ends with ".nemo", it is assumed to be a local .nemo file.
85+
Otherwise, it should start with "nvidia/", and is assumed to be a HuggingFace or NGC model.
86+
"""
5287
if codec_name.endswith(".nemo"):
5388
# Local .nemo file
5489
codec = AudioCodecModel.restore_from(codec_name, strict=False)
@@ -65,9 +100,15 @@ def __init__(self, codec_name: str):
65100
self.codec = codec
66101
self.updated_since_last_reset = False
67102

68-
def encode_from_file(self, audio_path: str) -> Tensor:
103+
def _encode_audio_file(self, audio_path: str) -> Tuple[Tensor, Tensor]:
69104
"""
70-
Encodes an audio file into audio codec codes.
105+
Encodes an audio file using the audio codec.
106+
107+
Args:
108+
audio_path: Path to the audio file.
109+
110+
Returns:
111+
Tuple of tensors containing the codec codes and the lengths of the codec codes.
71112
"""
72113
audio_segment = AudioSegment.from_file(audio_path, target_sr=self.codec.sample_rate)
73114
assert np.issubdtype(audio_segment.samples.dtype, np.floating)
@@ -82,6 +123,14 @@ def encode_from_file(self, audio_path: str) -> Tensor:
82123
return codes, codes_len
83124

84125
def update(self, codes: Tensor, codes_len: Tensor, is_real: bool):
126+
"""
127+
Updates the metric with a batch of codec frames.
128+
129+
Args:
130+
codes: Tensor of shape (B, C, T) containing the codec codes.
131+
codes_len: Tensor of shape (B,) containing the lengths of the codec codes.
132+
is_real: Boolean indicating whether the codes are real or generated.
133+
"""
85134
if codes.numel() == 0:
86135
logging.warning("FCD: No valid codes to update, skipping update")
87136
return
@@ -90,29 +139,46 @@ def update(self, codes: Tensor, codes_len: Tensor, is_real: bool):
90139
f"FCD: Number of codebooks mismatch: {codes.shape[1]} != {self.codec.num_codebooks}, skipping update"
91140
)
92141
return
93-
# keep only valid codes
142+
143+
# Keep only valid frames
94144
codes_batch_all = []
95145
for batch_idx in range(codes.shape[0]):
96146
codes_batch = codes[batch_idx, :, : codes_len[batch_idx]] # (C, T)
97147
codes_batch_all.append(codes_batch)
98-
# combine into a single tensor. We treat each timestep independently so we can concatenate them all.
148+
149+
# Combine into a single tensor. We treat each frame independently so we can concatenate them all.
99150
codes_batch_all = torch.cat(codes_batch_all, dim=-1).permute(1, 0) # (B*T, C)
100151
if len(codes_batch_all) == 0:
101152
logging.warning("FCD: No valid codes to update, skipping update")
102153
return
103-
# update
154+
155+
# Update the metric
104156
super().update(codes_batch_all, real=is_real)
105157
self.updated_since_last_reset = True
106158

107159
def reset(self):
160+
"""
161+
Resets the metric. Should be called after each compute.
162+
"""
108163
super().reset()
109164
self.updated_since_last_reset = False
110165

111166
def update_from_audio_file(self, audio_path: str, is_real: bool):
112-
codes, codes_len = self.encode_from_file(audio_path=audio_path)
167+
"""
168+
Updates the metric with codes representing a single audio file.
169+
Uses the codec to encode the audio file into codec codes and updates the metric.
170+
171+
Args:
172+
audio_path: Path to the audio file.
173+
is_real: Boolean indicating whether the audio file is real or generated.
174+
"""
175+
codes, codes_len = self._encode_audio_file(audio_path=audio_path)
113176
self.update(codes=codes, codes_len=codes_len, is_real=is_real)
114177

115178
def compute(self) -> Tensor:
179+
"""
180+
Computes the Frechet Distance between the real and generated codec frame distributions.
181+
"""
116182
if not self.updated_since_last_reset:
117183
logging.warning("FCD: No updates since last reset, returning 0")
118184
return torch.tensor(0.0, device=self.device)

0 commit comments

Comments
 (0)