diff --git a/examples/models/moshi/mimi/test_mimi.py b/examples/models/moshi/mimi/test_mimi.py index 8160b5df79c..350595e9cf7 100644 --- a/examples/models/moshi/mimi/test_mimi.py +++ b/examples/models/moshi/mimi/test_mimi.py @@ -173,15 +173,23 @@ def __init__(self, mimi: nn.Module): self.mimi_model = mimi def forward(self, x): - return self.mimi_model.decode(x) + x = x.transpose(1, 2) + x = self.mimi_model.upsample(x) + (emb,) = self.mimi_model.decoder_transformer(x) + emb.transpose(1, 2) + with self.mimi_model._context_for_encoder_decoder: + out = self.mimi_model.decoder(emb) + return out - sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None] - pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate) - chunk = sample_pcm[..., 0:pcm_chunk_size] - input = self.mimi.encode(chunk) + emb_input = torch.rand(1, 1, 512, device="cpu") mimi_decode = MimiDecode(self.mimi) - exported_decode: ExportedProgram = export(mimi_decode, (input,), strict=False) + mimi_decode.eval() + mimi_decode(emb_input) + + exported_decode: ExportedProgram = export( + mimi_decode, (emb_input,), strict=False + ) quantization_config = get_symmetric_quantization_config( is_per_channel=True, is_dynamic=True, @@ -190,12 +198,12 @@ def forward(self, x): quantizer.set_global(quantization_config) m = exported_decode.module() m = prepare_pt2e(m, quantizer) - m(input) + m(emb_input) m = convert_pt2e(m) print("quantized graph:") print(m.graph) # Export quantized module - exported_decode: ExportedProgram = export(m, (input,), strict=False) + exported_decode: ExportedProgram = export(m, (emb_input,), strict=False) # Lower edge_manager = to_edge_transform_and_lower( exported_decode, @@ -208,16 +216,16 @@ def forward(self, x): with open(output_file, "wb") as file: exec_prog.write_to_file(file) - eager_res = mimi_decode(input) + eager_res = mimi_decode(emb_input) runtime = Runtime.get() program = runtime.load_program(output_file) method = program.load_method("forward") - flattened_x = tree_flatten(input)[0] + flattened_x = tree_flatten(emb_input)[0] res = method.execute(flattened_x) # Compare results sqnr = compute_sqnr(eager_res, res[0]) print(f"SQNR: {sqnr}") - torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3) + torch.testing.assert_close(eager_res, res[0], atol=4e-3, rtol=1e-3) if __name__ == "__main__":