Failed to prepare model for QAT in Pt2e #801

@jimmy133719

Description of the bug:

I am hitting the same error as in #355 but have not found a way to resolve it. Note that my model does not use SiLU or GELU.

Environment:

  • torch==2.5.1+cu124
  • ai-edge-torch==0.4.0

import torch
from torch import nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e

import ai_edge_torch
from ai_edge_torch.quantize import pt2e_quantizer, quant_config

class SmallCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn   = nn.BatchNorm2d(16)
        self.act  = nn.ReLU(inplace=False)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc   = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.act(self.bn(self.conv(x)))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

model = SmallCNN().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

exported = capture_pre_autograd_graph(model, example_inputs)

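quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
    pt2e_quantizer.get_symmetric_quantization_config(
        is_per_channel=True,  # per-channel weight quantization
        is_dynamic=False,     # static (not dynamic) quantization
        is_qat=True           # quantization-aware training
    )
)

# This call raises the AssertionError shown in the traceback below.
qat_model = prepare_qat_pt2e(exported, quantizer)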

Actual vs expected behavior:

Traceback (most recent call last):
  File "/media/ipcam/RTK-IPCAM/jimmy/MediaPipePyTorch/qat_example.py", line 38, in <module>
    qat_model = prepare_qat_pt2e(exported, quantizer)
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/torch/ao/quantization/quantize_pt2e.py", line 169, in prepare_qat_pt2e
    quantizer.annotate(model)
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/ai_edge_torch/quantize/pt2e_quantizer.py", line 384, in annotate
    model = self._annotate_for_static_quantization_config(model)
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/ai_edge_torch/quantize/pt2e_quantizer.py", line 432, in _annotate_for_static_quantization_config
    self._annotate_all_static_patterns(
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/ai_edge_torch/quantize/pt2e_quantizer.py", line 399, in _annotate_all_static_patterns
    OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/ai_edge_torch/quantize/pt2e_quantizer_utils.py", line 461, in _annotate_conv_bn_relu
    return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/ai_edge_torch/quantize/pt2e_quantizer_utils.py", line 517, in _do_annotate_conv_bn
    pattern = _get_aten_graph_module_for_pattern(
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/torch/ao/quantization/pt2e/utils.py", line 375, in _get_aten_graph_module_for_pattern
    aten_pattern = capture_pre_autograd_graph(
  File "/home/ipcam/anaconda3/envs/ai_edge_torch/lib/python3.10/site-packages/torch/_export/__init__.py", line 123, in capture_pre_autograd_graph
    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."
AssertionError: Expected an nn.Module instance.
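
The last frame fails on `isinstance(f, torch.nn.Module)`, so the object passed to `capture_pre_autograd_graph` at that point is presumably not a module. The sketch below only illustrates that check, under the assumption that the conv-bn pattern is a plain Python function. `conv_bn_pattern` and `CallableWrapper` are hypothetical names made up for this example, not part of torch or ai-edge-torch, and I have not verified that wrapping the pattern this way resolves the issue.

```python
import torch
from torch import nn


def conv_bn_pattern(x, weight, bias):
    # Hypothetical stand-in for a pattern defined as a plain function.
    return torch.nn.functional.conv2d(x, weight, bias)


class CallableWrapper(nn.Module):
    """Hypothetical helper that exposes a plain callable as an nn.Module."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


# Same condition as the assert in the last traceback frame.
is_module_before = isinstance(conv_bn_pattern, nn.Module)

wrapped = CallableWrapper(conv_bn_pattern)
is_module_after = isinstance(wrapped, nn.Module)

# The wrapper forwards its arguments unchanged to the wrapped function.
x = torch.randn(1, 3, 8, 8)
w = torch.randn(16, 3, 3, 3)
b = torch.randn(16)
out = wrapped(x, w, b)

print(is_module_before, is_module_after, tuple(out.shape))
```

A bare function fails the `isinstance` check while the wrapped one passes it, which would be consistent with the assertion message above if the pattern is indeed passed as a function.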

Any other information you'd like to share?

No response
