Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 70 additions & 4 deletions examples/models/moshi/mimi/test_mimi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,28 @@
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from executorch.devtools.backend_debug import print_delegation_info
from executorch.exir import to_edge_transform_and_lower
from executorch.runtime import Runtime

from huggingface_hub import hf_hub_download
from moshi.models import loaders
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export, ExportedProgram
from torch.utils._pytree import tree_flatten

os.environ["https_proxy"] = "http://fwdproxy:8080"


def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
assert x.shape == y.shape, "Tensor shapes do not match"
x = x.float()
y = y.float()
error = x - y
original_power = torch.mean(torch.pow(x, 2))
error_power = torch.mean(torch.pow(error, 2))
sqnr = 10 * torch.log10(original_power / error_power)
return sqnr.item()


def read_mp3_from_url(url):
Expand Down Expand Up @@ -189,6 +202,59 @@ def forward(self, x):
ep_encode_output = exported_encode.module()(chunk)
self.assertTrue(torch.allclose(ep_encode_output, ref_encode_output, atol=1e-6))

def test_exported_decoder_xnnpack(self):
class MimiDecode(nn.Module):
def __init__(self, mimi: nn.Module):
super().__init__()
self.mimi_model = mimi

def forward(self, x):
return self.mimi_model.decode(x)

sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
chunk = sample_pcm[..., 0:pcm_chunk_size]
input = self.mimi.encode(chunk)

mimi_decode = MimiDecode(self.mimi)
exported_decode: ExportedProgram = export(mimi_decode, (input,), strict=False)
quantization_config = get_symmetric_quantization_config(
is_per_channel=True,
is_dynamic=True,
)
quantizer = XNNPACKQuantizer()
quantizer.set_global(quantization_config)
m = exported_decode.module()
m = prepare_pt2e(m, quantizer)
m(input)
m = convert_pt2e(m)
print("quantized graph:")
print(m.graph)
# Export quantized module
exported_decode: ExportedProgram = export(m, (input,), strict=False)
# Lower
edge_manager = to_edge_transform_and_lower(
exported_decode,
partitioner=[XnnpackPartitioner()],
)
print("delegate graph:")
print_delegation_info(edge_manager.exported_program().graph_module)
exec_prog = edge_manager.to_executorch()
output_file = "/tmp/mimi_decode.pte"
with open(output_file, "wb") as file:
exec_prog.write_to_file(file)

eager_res = mimi_decode(input)
runtime = Runtime.get()
program = runtime.load_program(output_file)
method = program.load_method("forward")
flattened_x = tree_flatten(input)[0]
res = method.execute(flattened_x)
# Compare results
sqnr = compute_sqnr(eager_res, res[0])
print(f"SQNR: {sqnr}")
torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
unittest.main()
Loading