[Tutorial] OpenVINOQuantizer #2
base: main
Changes from 2 commits
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,211 @@ | ||||||||
| PyTorch 2 Export Quantization with OpenVINO backend | ||||||||
| =========================================================================== | ||||||||
|
|
||||||||
| **Author**: dlyakhov, asuslov, aamir, # TODO: add required authors | ||||||||
|
|
||||||||
|
||||||||
| Introduction | ||||||||
| -------------- | ||||||||
|
|
||||||||
| This tutorial shows how to use the `Neural Network Compression Framework (NNCF) <https://github.com/openvinotoolkit/nncf/tree/develop>`_ to generate a quantized model customized | ||||||||
| for the `OpenVINO torch.compile backend <https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html>`_ and explains how to lower the quantized model into the `OpenVINO <https://docs.openvino.ai/2024/index.html>`_ representation. | ||||||||
|
||||||||
|
|
||||||||
| The PyTorch 2 export quantization flow uses ``torch.export`` to capture the model into a graph and performs quantization transformations on top of the ATen graph. | ||||||||
|
||||||||
| Compared with the earlier FX Graph Mode Quantization, this approach is expected to have significantly higher model coverage, better programmability, and a simplified UX. | ||||||||
| The OpenVINO backend compiles the FX Graph generated by TorchDynamo into an optimized OpenVINO model. | ||||||||
|
||||||||
|
|
||||||||
| The quantization flow consists of four main steps: | ||||||||
|
|
||||||||
| - Step 1: Install OpenVINO and NNCF. | ||||||||
|
||||||||
| - Step 2: Capture the FX Graph from the eager model using the `torch export mechanism <https://pytorch.org/docs/main/export.html>`_. | ||||||||
| - Step 3: Apply the quantization flow to the captured FX Graph. | ||||||||
|
||||||||
| - Step 4: Lower the quantized model into the OpenVINO representation using the ``torch.compile`` API. | ||||||||
|
||||||||
|
|
||||||||
| The high-level architecture of this flow looks like this: | ||||||||
|
|
||||||||
| :: | ||||||||
|
|
||||||||
|     float_model(Python)                          Example Input | ||||||||
|         \                                              / | ||||||||
|          \                                            / | ||||||||
|     -------------------------------------------------------- | ||||||||
|     |                        export                        | | ||||||||
|     -------------------------------------------------------- | ||||||||
|                                | | ||||||||
|                         FX Graph in ATen | ||||||||
|                                | | ||||||||
|                                |           OpenVINOQuantizer | ||||||||
|                                |                 / | ||||||||
|     -------------------------------------------------------- | ||||||||
|     |                     prepare_pt2e                     | | ||||||||
|     |                          |                           | | ||||||||
|     |                      Calibrate                       | | ||||||||
|     |                          |                           | | ||||||||
|     |                     convert_pt2e                     | | ||||||||
|     -------------------------------------------------------- | ||||||||
|                                | | ||||||||
|                         Quantized Model | ||||||||
|                                | | ||||||||
|     -------------------------------------------------------- | ||||||||
|     |         Lower into OpenVINO (torch.compile)          | | ||||||||
|     -------------------------------------------------------- | ||||||||
|                                | | ||||||||
|                          OpenVINO model | ||||||||
|
|
||||||||
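To give a feel for what the "Calibrate" and ``convert_pt2e`` stages of the flow above do, here is a minimal, dependency-free sketch of 8-bit affine quantization driven by observed value ranges. It is only a conceptual toy: ``MinMaxObserver``, ``quantize`` and ``dequantize`` are hypothetical helpers written for this illustration, and the min/max range estimation is an assumed, simplified scheme, not the algorithm that NNCF or ``OpenVINOQuantizer`` actually applies.

```python
from dataclasses import dataclass


@dataclass
class MinMaxObserver:
    """Toy observer (hypothetical): tracks the min/max of all values it sees."""

    min_val: float = float("inf")
    max_val: float = float("-inf")

    def observe(self, values):
        # "Calibrate": update the running range from one batch of values.
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def qparams(self, qmin=-128, qmax=127):
        # Extend the range so that 0.0 is always exactly representable.
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        # Fall back to a scale of 1.0 if the observed range is empty.
        scale = (hi - lo) / (qmax - qmin) or 1.0
        zero_point = round(qmin - lo / scale)
        zero_point = max(qmin, min(qmax, zero_point))
        return scale, zero_point


def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    """Map floats to integers in [qmin, qmax] (the "convert" step)."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


def dequantize(q_values, scale, zero_point):
    """Map quantized integers back to approximate float values."""
    return [(q - zero_point) * scale for q in q_values]


# Calibrate on two small batches, then quantize and dequantize a few values.
observer = MinMaxObserver()
for batch in ([-1.0, 0.5], [2.0, 0.25]):
    observer.observe(batch)
scale, zero_point = observer.qparams()
restored = dequantize(quantize([0.5, 2.0], scale, zero_point), scale, zero_point)
```

In the real flow, ``prepare_pt2e`` inserts observers into the exported graph, running calibration data through the prepared model plays the role of ``observe``, and ``convert_pt2e`` replaces the observers with quantize/dequantize operations.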
| Post Training Quantization | ||||||||
| ---------------------------- | ||||||||
|
|
||||||||
| This section walks step by step through applying the flow to the `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_ | ||||||||
| for post-training quantization. | ||||||||
|
|
||||||||
| 1. OpenVINO and NNCF installation | ||||||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||||
| OpenVINO and NNCF can be installed via the `pip distribution <https://docs.openvino.ai/2024/get-started/install-openvino.html>`_: | ||||||||
|
|
||||||||
| .. code-block:: bash | ||||||||
|
|
||||||||
|     pip install -U pip | ||||||||
|     pip install openvino nncf | ||||||||
|
|
||||||||
|
|
||||||||
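Before moving on, it can help to confirm that the packages are visible to the Python interpreter you intend to use. The snippet below is a small sketch that relies only on the standard library; ``installed_versions`` is a hypothetical helper name, and the package list assumes the pip distribution names ``openvino``, ``nncf`` and ``torch``.

```python
from importlib import metadata


def installed_versions(packages=("openvino", "nncf", "torch")):
    """Return a dict mapping each distribution name to its version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```

If any entry is reported as not installed, rerun the ``pip install`` command above in the same environment.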
| 2. Capture FX Graph | ||||||||
| ^^^^^^^^^^^^^^^^^^^^^ | ||||||||
|
|
||||||||
| We start by performing the necessary imports and capturing the FX Graph from the eager module. | ||||||||
|
|
||||||||
| .. code-block:: python | ||||||||
|
|
||||||||
| import copy | ||||||||
| import openvino.torch | ||||||||
| import torch | ||||||||
| import torchvision.models as models | ||||||||
| from torch.ao.quantization.quantize_pt2e import convert_pt2e | ||||||||
| from torch.ao.quantization.quantize_pt2e import prepare_pt2e | ||||||||
| from torch.ao.quantization.quantizer.openvino_quantizer import OpenVINOQuantizer | ||||||||
|
||||||||
|
|
||||||||
| import nncf | ||||||||
| from nncf.torch import disable_patching | ||||||||
|
||||||||
| import nncf | |
| from nncf.torch import disable_patching | |
| import nncf |
Unfortunately, that does not work. We can do import nncf.torch and then do nncf.torch.disable_patching
import nncf.torch is introduced, please check
why do we need the memory format to be channels_last?
This was a copy-paste from the original tutorial. Removed, thanks!
is disable_patching() needed both during export and inference with torch.compile?
Unfortunately yes: export will fail with an error and performance of the compiled model will be ruined without it
Where can I find more information about OpenVINOQuantizer parameters?
That's a good question. We don't have a dedicated page about the OpenVINOQuantizer yet. There is a dedicated page for nncf.quantize and its parameters, but its parameter set is not equivalent.
I've added a link to nncf API docs, which should be updated with this PR: openvinotoolkit/nncf#3277
I would suggest adding something like this:
For more information about NNCF and NNCF Quantization Flow for PyTorch models, please visit
Done, please check