Skip to content

Commit a0d9b5c

Browse files
limintang authored and facebook-github-bot committed
Update mimi export test (#9755)
Summary: Feed embeddings to mimi decoder directly.

Reviewed By: iseeyuan

Differential Revision: D72091091
1 parent 7d35c68 commit a0d9b5c

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

examples/models/moshi/mimi/test_mimi.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"https": "http://fwdproxy:8080",
2929
}
3030

31-
3231
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
3332
assert x.shape == y.shape, "Tensor shapes do not match"
3433
x = x.float()
@@ -39,7 +38,6 @@ def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
3938
sqnr = 10 * torch.log10(original_power / error_power)
4039
return sqnr.item()
4140

42-
4341
def read_mp3_from_url(url):
4442
try:
4543
response = requests.get(url)
@@ -173,15 +171,23 @@ def __init__(self, mimi: nn.Module):
173171
self.mimi_model = mimi
174172

175173
def forward(self, x):
176-
return self.mimi_model.decode(x)
174+
x = x.transpose(1, 2)
175+
x = self.mimi_model.upsample(x)
176+
(emb,) = self.mimi_model.decoder_transformer(x)
177+
emb.transpose(1, 2)
178+
with self.mimi_model._context_for_encoder_decoder:
179+
out = self.mimi_model.decoder(emb)
180+
return out
177181

178-
sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
179-
pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
180-
chunk = sample_pcm[..., 0:pcm_chunk_size]
181-
input = self.mimi.encode(chunk)
182+
emb_input = torch.rand(1, 1, 512, device="cpu")
182183

183184
mimi_decode = MimiDecode(self.mimi)
184-
exported_decode: ExportedProgram = export(mimi_decode, (input,), strict=False)
185+
mimi_decode.eval()
186+
mimi_decode(emb_input)
187+
188+
exported_decode: ExportedProgram = export(
189+
mimi_decode, (emb_input,), strict=False
190+
)
185191
quantization_config = get_symmetric_quantization_config(
186192
is_per_channel=True,
187193
is_dynamic=True,
@@ -190,12 +196,12 @@ def forward(self, x):
190196
quantizer.set_global(quantization_config)
191197
m = exported_decode.module()
192198
m = prepare_pt2e(m, quantizer)
193-
m(input)
199+
m(emb_input)
194200
m = convert_pt2e(m)
195201
print("quantized graph:")
196202
print(m.graph)
197203
# Export quantized module
198-
exported_decode: ExportedProgram = export(m, (input,), strict=False)
204+
exported_decode: ExportedProgram = export(m, (emb_input,), strict=False)
199205
# Lower
200206
edge_manager = to_edge_transform_and_lower(
201207
exported_decode,
@@ -208,16 +214,16 @@ def forward(self, x):
208214
with open(output_file, "wb") as file:
209215
exec_prog.write_to_file(file)
210216

211-
eager_res = mimi_decode(input)
217+
eager_res = mimi_decode(emb_input)
212218
runtime = Runtime.get()
213219
program = runtime.load_program(output_file)
214220
method = program.load_method("forward")
215-
flattened_x = tree_flatten(input)[0]
221+
flattened_x = tree_flatten(emb_input)[0]
216222
res = method.execute(flattened_x)
217223
# Compare results
218224
sqnr = compute_sqnr(eager_res, res[0])
219225
print(f"SQNR: {sqnr}")
220-
torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3)
226+
torch.testing.assert_close(eager_res, res[0], atol=4e-3, rtol=1e-3)
221227

222228

223229
if __name__ == "__main__":

0 commit comments

Comments (0)