Open
Labels
type:support: For use-related issues
Description
I am trying to get a basic example working that converts a simple PyTorch model to a quantized TFLite file. I am using Python 3.12 with PyTorch 2.7.0 and ai-edge-torch 0.5.0, and wrote the code below. I tried following this and also the quantization section located here, but I usually end up with an error like:

error: 'stablehlo.uniform_dequantize' op operand #0 must be ranked tensor of per-tensor integer quantized or per-axis integer quantized values, but got 'tensor<10x5xi8>'

I suspect something is wrong with my quantizer. The code shows a few different quantizer attempts, but none of them worked. Can anyone reproduce my problem and point out what I am doing wrong? Thanks!
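The error message distinguishes plain integer tensors from per-tensor or per-axis (per-channel) quantized values. As a rough illustration of what `is_per_channel=True` asks for, the sketch below quantizes a weight matrix symmetrically with one scale per output row (axis 0). For `Linear(5, 10)` those would be the 10 rows of its 10x5 weight. It is plain Python, not the ai-edge-torch or PyTorch implementation, and the [-127, 127] range and zero point of 0 are assumptions.

```python
from typing import List, Tuple

QMIN, QMAX = -127, 127  # assumed symmetric int8 range; zero point is 0


def quantize_per_channel(
    weight: List[List[float]],
) -> Tuple[List[List[int]], List[float]]:
    """Symmetric per-channel quantization along axis 0 (one scale per output row)."""
    q_rows: List[List[int]] = []
    scales: List[float] = []
    for row in weight:
        max_abs = max(abs(v) for v in row)
        # Map the largest magnitude in this row to QMAX; avoid dividing by zero.
        scale = max_abs / QMAX if max_abs > 0 else 1.0
        q_rows.append([max(QMIN, min(QMAX, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


def dequantize_per_channel(
    q_rows: List[List[int]], scales: List[float]
) -> List[List[float]]:
    """Recover approximate float values: each row uses its own scale."""
    return [[q * s for q in row] for row, s in zip(q_rows, scales)]


if __name__ == "__main__":
    w = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.5]]
    q, s = quantize_per_channel(w)
    print(q, s)
    print(dequantize_per_channel(q, s))
```

The scales are part of the result: a per-axis quantized value is the int8 data together with one scale per channel. Judging from the error text, the `tensor<10x5xi8>` operand reached `stablehlo.uniform_dequantize` as bare int8 data without that quantization information attached.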
import torch
import torch.ao.quantization
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from ai_edge_torch import convert
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
# from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
#     get_symmetric_quantization_config,
#     XNNPACKQuantizer,
# )
from ai_edge_torch.quantize.quant_config import QuantConfig
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.linear(x)
example_inputs = torch.randn(1, 5)
m = M().eval()
m = torch.export.export(m, (example_inputs,)).module()
# Quantizer Attempt1
quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
    # get_symmetric_quantization_config()
)
# # Quantizer Attempt2
# quantizer = X86InductorQuantizer()
# quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
# # Quantizer Attempt3
# quantizer = XNNPACKQuantizer()
# quantizer.set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
with torch.no_grad():
    m(example_inputs)  # Use your actual calibration data here
m = convert_pt2e(m)
edge_model = convert(m, (example_inputs,))
# edge_model = convert(m, (example_inputs,), quant_config=QuantConfig(pt2e_quantizer=quantizer))
edge_model.export("out/Test.tflite") # Export to TFLite format
print("Done")
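Attempt 1 also passes `is_dynamic=True`, which appears to request dynamic-range quantization: weights are quantized ahead of time, while the activation scale is computed from each input at run time. The plain-Python sketch below shows that idea for one linear layer. The symmetric int8 range, the per-tensor activation scale, and the float bias are assumptions for illustration; `quantize_symmetric` and `dynamic_linear` are hypothetical names, not part of ai-edge-torch, PyTorch, or TFLite.

```python
from typing import List, Tuple

QMAX = 127  # assumed symmetric int8 range [-127, 127]


def quantize_symmetric(values: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8 quantization of a flat list of floats."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / QMAX if max_abs > 0 else 1.0
    return [max(-QMAX, min(QMAX, round(v / scale))) for v in values], scale


def dynamic_linear(
    x: List[float],
    q_weight: List[List[int]],
    w_scales: List[float],
    bias: List[float],
) -> List[float]:
    """y = W x + b with int8 weights and an activation scale chosen per call."""
    # The "dynamic" part: the activation scale comes from this particular input.
    q_x, x_scale = quantize_symmetric(x)
    out: List[float] = []
    for q_row, w_scale, b in zip(q_weight, w_scales, bias):
        acc = sum(qx * qw for qx, qw in zip(q_x, q_row))  # integer accumulation
        out.append(acc * x_scale * w_scale + b)  # rescale to float, add float bias
    return out


if __name__ == "__main__":
    W = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.5]]
    bias = [0.1, -0.2]
    # Weights are quantized once, offline, one scale per output row.
    q_w, w_s = zip(*(quantize_symmetric(row) for row in W))
    print(dynamic_linear([1.0, 2.0, -3.0], list(q_w), list(w_s), bias))
```

The integer accumulator can only be turned back into floats because both the weight scales and the activation scale are kept next to the int8 data; dropping the weight scales would leave nothing to rescale with.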