def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return the signal-to-quantization-noise ratio of ``y`` against ``x`` in dB.

    Both tensors are converted to float before the computation. The result is
    ``10 * log10(mean(x**2) / mean((x - y)**2))``.

    Args:
        x: Reference tensor (the "signal").
        y: Tensor to compare against the reference; must have the same shape.

    Returns:
        The SQNR in decibels as a Python float. ``inf`` is returned when the
        two tensors are exactly equal, including the all-zero case.

    Raises:
        AssertionError: If the shapes of ``x`` and ``y`` differ.
    """
    # Raise explicitly rather than using ``assert`` so that the check is not
    # stripped when Python runs with ``-O``; the exception type is unchanged.
    if x.shape != y.shape:
        raise AssertionError("Tensor shapes do not match")

    signal = x.float()
    noise = signal - y.float()
    signal_power = torch.mean(torch.square(signal))
    noise_power = torch.mean(torch.square(noise))

    # A zero error means a perfect match. Short-circuit so that two all-zero
    # tensors do not evaluate 0 / 0 and yield NaN.
    if noise_power.item() == 0.0:
        return float("inf")

    return (10 * torch.log10(signal_power / noise_power)).item()
def test_exported_decoder_xnnpack(self):
    """Quantize, lower and run the Mimi decoder, then compare against eager.

    Steps: export the decoder, apply PT2E quantization with the
    XNNPACKQuantizer (symmetric, per-channel, dynamic), lower with the
    XnnpackPartitioner, serialize to a ``.pte`` file, execute its ``forward``
    method through the ExecuTorch runtime and check the output against the
    eager decoder with ``atol=1e-3`` / ``rtol=1e-3``.

    Reads ``self.sample_pcm``, ``self.device`` and ``self.mimi``, which are
    presumably prepared by the test class set-up (not visible here).
    """
    # Local imports keep this method self-contained.
    import os
    import tempfile

    class MimiDecode(nn.Module):
        """Expose ``mimi.decode`` as ``forward`` so ``export`` can trace it."""

        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.decode(x)

    # Encode one chunk of PCM to obtain a decoder input.
    sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
    pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
    chunk = sample_pcm[..., 0:pcm_chunk_size]
    codes = self.mimi.encode(chunk)

    mimi_decode = MimiDecode(self.mimi)
    exported_decode: ExportedProgram = export(mimi_decode, (codes,), strict=False)

    # PT2E quantization: prepare, run once on the sample input, convert.
    quantization_config = get_symmetric_quantization_config(
        is_per_channel=True,
        is_dynamic=True,
    )
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(quantization_config)
    m = exported_decode.module()
    m = prepare_pt2e(m, quantizer)
    m(codes)
    m = convert_pt2e(m)
    print("quantized graph:")
    print(m.graph)

    # Export the quantized module and lower it to the XNNPACK delegate.
    exported_quantized: ExportedProgram = export(m, (codes,), strict=False)
    edge_manager = to_edge_transform_and_lower(
        exported_quantized,
        partitioner=[XnnpackPartitioner()],
    )
    print("delegate graph:")
    print_delegation_info(edge_manager.exported_program().graph_module)
    exec_prog = edge_manager.to_executorch()

    eager_res = mimi_decode(codes)

    # Write the program into a private temporary directory instead of a fixed
    # /tmp path: portable, no clashes between concurrent runs, and the file
    # is removed afterwards. Execution stays inside the ``with`` block so the
    # file exists for as long as the loaded program is in use.
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_file = os.path.join(tmp_dir, "mimi_decode.pte")
        with open(output_file, "wb") as file:
            exec_prog.write_to_file(file)

        runtime = Runtime.get()
        program = runtime.load_program(output_file)
        method = program.load_method("forward")
        flattened_x = tree_flatten(codes)[0]
        res = method.execute(flattened_x)

        # Compare results
        sqnr = compute_sqnr(eager_res, res[0])
        print(f"SQNR: {sqnr}")
        torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3)