
QAT model (with Q/DQ, explicit quant) slower than PTQ model (no Q/DQ, implicit quant) #4412

@BowenXu2020

Description

Hello, please let me describe my question.
I tried two quantization methods:

  1. PTQ (no Q/DQ nodes in the ONNX model, implicit quantization)
  2. QAT (with Q/DQ nodes in the ONNX model, explicit quantization)

Both TensorRT engines show reasonable accuracy, but the QAT one runs slower.

I looked through the related issues, but they mostly deal with specific ops, and the suggested solution is to place the Q/DQ nodes manually.
I am currently using the automatic methods to insert Q/DQ (for pytorch_quantization, the monkey patch; for modelopt, modelopt.torch.quantization.quantize(model, config, calib_loop)).
I am new to this field and do not yet understand the principles of correct Q/DQ placement, especially for a very complicated model, so I ran a test on a simple toy model, hoping the automatic method would work well. However, the conclusion on speed is the same: the QAT model is much slower after quantization.
My toy model is CBR + CBR + AdaptiveAvgPool + flatten (BCHW to BC') + FC. I used the modelopt automatic method (a rough sketch of the quantization call is included after the screenshots below), and the exported QAT ONNX model looks like this:

[Images: exported QAT ONNX graph showing the inserted Q/DQ nodes]
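For reference, here is a minimal sketch of how the toy model is built and quantized with modelopt. The layer sizes, calibration data, and config choice (INT8_DEFAULT_CFG) are illustrative assumptions, not the exact script used for the results below:

```python
import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

class ToyModel(nn.Module):
    """CBR + CBR + AdaptiveAvgPool + flatten + FC, as described above."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = self.pool(x).flatten(1)  # BCHW -> BC'
        return self.fc(x)

model = ToyModel().eval().cuda()

def calib_loop(m):
    # Run a few representative batches so the quantizers can collect ranges;
    # random data here stands in for a real calibration loader.
    for _ in range(16):
        m(torch.randn(8, 3, 224, 224, device="cuda"))

# Insert Q/DQ (fake-quant) modules automatically and calibrate them.
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_loop)
```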

For the TensorRT engines, I profiled each layer's time consumption. The PTQ engine looks good:

[Image: per-layer profile of the PTQ engine]

However, the QAT engine is much slower:

[Image: per-layer profile of the QAT engine]

conv2 and the pooling layer (__myl_MulAddRel...) take most of the time. Part of the QAT engine graph is shown below:

[Image: portion of the QAT engine graph around conv2 and the pooling layer]
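For anyone reproducing this: per-layer timings like the ones above can be dumped with a trtexec profiling run along these lines (the engine file name is the one from the Materials section):
trtexec --loadEngine=toy_qat_model_sim.trt --dumpProfile --separateProfileRun --profilingVerbosity=detailed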

My questions are:

  1. In the QAT TensorRT engine, is "conv.weight + /conv/weight_quantizer/QuantizeLinear + /conv/Conv" a reasonable pattern? It looks like a redundant Q step on the weights, i.e. the weights are stored as fp32 and quantized at runtime, instead of being saved directly as int8 as in the PTQ engine. In my opinion this is probably not the issue, though, because only conv2 takes a long time while conv1 is fine, and the weight dtype of CaskConvolution (both conv1 and conv2) really is int8.

[Image: CaskConvolution layer details showing int8 weight dtype]

  2. In the TensorRT graph above, the output of CaskConvolution (conv2) is fp32, and there are some kgen kernels that I do not quite understand. I think they are the reason why conv2 and "__myl_MulAddRel..." take so much longer. If so, what could cause this kind of issue? Since it is a very simple network, I would expect the automatic method to handle it; maybe I made some silly mistake.

  3. One suggested solution is to "manually insert Q/DQ according to the PTQ quantization graph, making the QAT quantization result as close to PTQ as possible". Is that technically possible? The TensorRT graphs for QAT and PTQ look quite different, and I do not know whether the underlying quantization logic is the same with and without Q/DQ. For example, for this toy model, can the QAT and PTQ engine graphs be made identical by placing Q/DQ at the correct positions? (A rough sketch of how quantizer placement can be adjusted in modelopt is included right after this list.)
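Related to question 3, here is a minimal sketch of how Q/DQ placement could be adjusted with modelopt's wildcard quant_cfg patterns; the "*pool*" / "*fc*" patterns refer to the toy model's module names above and are illustrative assumptions, not something modelopt mandates:

```python
import copy
import modelopt.torch.quantization as mtq

# Start from the stock int8 config and disable quantizers whose names match
# the given wildcard patterns, so Q/DQ only ends up where we want it.
cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
cfg["quant_cfg"]["*pool*"] = {"enable": False}  # e.g. skip the pooling path
cfg["quant_cfg"]["*fc*"] = {"enable": False}    # e.g. leave the FC layer unquantized

model = mtq.quantize(model, cfg, calib_loop)

# Individual quantizer modules can also be toggled after quantize():
for name, module in model.named_modules():
    if name.endswith("fc") and hasattr(module, "input_quantizer"):
        module.input_quantizer.disable()
```

After re-exporting to ONNX, the resulting Q/DQ layout can be compared against the PTQ engine graph to see whether TensorRT fuses the layers the same way.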

Full QAT model TensorRT graph:

[Images: full QAT engine graph, three screenshots]

Full PTQ model TensorRT graph:

[Images: full PTQ engine graph, three screenshots]

Really appreciate your time and help!!

Environment

TensorRT Version: 10.7.0

Platform: Jetson Orin NX

Jetpack: 6.2

CUDA Version: 12.0

Operating System: Linux

Python Version (if applicable): 3.9

PyTorch Version (if applicable): 2.6.0

Materials

ONNX model: https://drive.google.com/drive/folders/17vbf_YIUUQIyEUjEBZ38aedcuqhggYks?usp=sharing
Convert to TRT:
trtexec --onnx=toy_qat_model_sim.onnx --saveEngine=toy_qat_model_sim.trt --int8 --profilingVerbosity=detailed
ONNX export opset_version: 13
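For completeness, a minimal sketch of the ONNX export step under the same assumptions as the quantization sketch above (the input shape and output file name are placeholders; depending on the modelopt version, additional export settings may be needed):

```python
import torch

# Export the quantized model; the fake-quant modules become
# QuantizeLinear/DequantizeLinear (Q/DQ) nodes. opset 13 as noted above.
dummy = torch.randn(1, 3, 224, 224, device="cuda")
torch.onnx.export(
    model,
    dummy,
    "toy_qat_model.onnx",
    opset_version=13,
    input_names=["input"],
    output_names=["output"],
)
```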
