Description of the bug:
I am hitting the same error as #355 and still can't find a way to solve it, even though I don't use SiLU or GELU.
Environment:
- torch==2.5.1+cu124
- ai-edge-torch==0.4.0
Minimal repro:

import torch
from torch import nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e

import ai_edge_torch
from ai_edge_torch.quantize import pt2e_quantizer, quant_config


class SmallCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16)
        self.act = nn.ReLU(inplace=False)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.act(self.bn(self.conv(x)))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


model = SmallCNN().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# Export to a pre-autograd ATen graph.
exported = capture_pre_autograd_graph(model, example_inputs)

# Static per-channel symmetric QAT config.
quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
    pt2e_quantizer.get_symmetric_quantization_config(
        is_per_channel=True,
        is_dynamic=False,
        is_qat=True,
    )
)

# Fails here (see traceback below).
qat_model = prepare_qat_pt2e(exported, quantizer)
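For completeness, these are the steps I intend to run after prepare_qat_pt2e, following the pt2e quantization example in the ai-edge-torch README (a sketch; the QAT training loop is elided, and this code is never reached because prepare_qat_pt2e already fails):

# ... QAT fine-tuning / calibration of qat_model would go here ...
qat_model = convert_pt2e(qat_model, fold_quantize=False)
edge_model = ai_edge_torch.convert(
    qat_model,
    example_inputs,
    quant_config=quant_config.QuantConfig(pt2e_quantizer=quantizer),
)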
Actual vs expected behavior:

Traceback (most recent call last):
  File "/media/ipcam/RTK-IPCAM/jimmy/MediaPipePyTorch/qat_example.py", line 38, in <module>
    qat_model = prepare_qat_pt2e(exported, quantizer)
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/torch/ao/quantization/quantize_pt2e.py", line 169, in prepare_qat_pt2e
    quantizer.annotate(model)
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/ai_edge_torch/quantize/pt2e_quantizer.py", line 384, in annotate
    model = self._annotate_for_static_quantization_config(model)
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/ai_edge_torch/quantize/pt2e_quantizer.py", line 432, in _annotate_for_static_quantization_config
    self._annotate_all_static_patterns(
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/ai_edge_torch/quantize/pt2e_quantizer.py", line 399, in _annotate_all_static_patterns
    OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/ai_edge_torch/quantize/pt2e_quantizer_utils.py", line 461, in _annotate_conv_bn_relu
    return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/ai_edge_torch/quantize/pt2e_quantizer_utils.py", line 517, in _do_annotate_conv_bn
    pattern = _get_aten_graph_module_for_pattern(
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/torch/ao/quantization/pt2e/utils.py", line 375, in _get_aten_graph_module_for_pattern
    aten_pattern = capture_pre_autograd_graph(
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/torch/_export/__init__.py", line 123, in capture_pre_autograd_graph
    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
AssertionError: Expected an nn.Module instance.
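From the traceback, the assertion fires inside torch's _get_aten_graph_module_for_pattern: in torch 2.5, capture_pre_autograd_graph asserts that its first argument is an nn.Module, but the conv-bn pattern that ai_edge_torch's _do_annotate_conv_bn passes down appears to be a plain Python function. Torch's own pt2e quantizer utilities wrap such pattern functions in a small nn.Module wrapper before exporting them. The sketch below is a possible workaround along those lines; it is untested, _FnModule is my own hypothetical wrapper, and it assumes ai_edge_torch's pt2e_quantizer_utils binds the helper by name at import time (which the traceback suggests):

import torch
from ai_edge_torch.quantize import pt2e_quantizer_utils


class _FnModule(torch.nn.Module):
    # Hypothetical wrapper that turns a pattern function into an nn.Module,
    # mirroring what torch's own quantizer utilities do before export.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


_original = pt2e_quantizer_utils._get_aten_graph_module_for_pattern


def _patched(pattern, example_inputs, *args, **kwargs):
    # Wrap plain callables so the nn.Module assertion inside
    # capture_pre_autograd_graph passes; modules go through unchanged.
    if not isinstance(pattern, torch.nn.Module):
        pattern = _FnModule(pattern)
    return _original(pattern, example_inputs, *args, **kwargs)


pt2e_quantizer_utils._get_aten_graph_module_for_pattern = _patched

Pinning torch to the version ai-edge-torch 0.4.0 was released against may also avoid this, but I haven't confirmed which version that is.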
Any other information you'd like to share?

No response