
Conversation

daniil-lyakhov
Owner

Fixes #ISSUE_NUMBER

Description

Checklist

  • The issue that is being fixed is referenced in the description (see above "Fixes #ISSUE_NUMBER")
  • Only one issue is addressed in this pull request
  • Labels from the issue that this PR is fixing are added to this pull request
  • No unnecessary issues are included in this pull request.

@daniil-lyakhov daniil-lyakhov force-pushed the dl/fx/openvino_quantizer branch 8 times, most recently from 4b67782 to acf1647 Compare January 28, 2025 19:05
@daniil-lyakhov daniil-lyakhov changed the title Dl/fx/openvino quantizer [Tutorial] OpenVINOQuantizer Jan 28, 2025

# Create the data, using the dummy data here as an example
traced_bs = 50
x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last)

why do we need the memory format to be channels_last?

Owner Author

@daniil-lyakhov daniil-lyakhov Feb 7, 2025


This is a copy-paste from the original tutorial; removed, thanks!

example_inputs = (x,)

# Capture the FX Graph to be quantized
with torch.no_grad(), disable_patching():

is disable_patching() needed both during export and inference with torch.compile?

Owner Author


Unfortunately, yes: export will fail with an error, and the performance of the compiled model will be ruined without it.

===========================================================================

**Author**: dlyakhov, asuslov, aamir, # TODO: add required authors


Owner Author


Done

Comment on lines 85 to 86
import nncf
from nncf.torch import disable_patching

Suggested change
 import nncf
-from nncf.torch import disable_patching
Owner Author


Unfortunately, that does not work. We can do `import nncf.torch` and then call `nncf.torch.disable_patching`.

Owner Author


import nncf.torch is introduced, please check

# from input to output nodes will be excluded from the quantization process.
subgraph = nncf.Subgraph(inputs=['layer_1', 'layer_2'], outputs=['layer_3'])
OpenVINOQuantizer(ignored_scope=nncf.IgnoredScope(subgraphs=[subgraph]))


Where can I find more information about OpenVINOQuantizer parameters?

Owner Author


That's a good question; we don't have a dedicated page about the OpenVINOQuantizer yet. We have a dedicated page for `nncf.quantize` and its parameters, but the subset of parameters is not equivalent.

Owner Author


I've added a link to nncf API docs, which should be updated with this PR: openvinotoolkit/nncf#3277

Conclusion
------------

In this tutorial, we introduced how to use torch.compile with the OpenVINO backend and the OpenVINO quantizer.

I would suggest adding something like this:
For more information about NNCF and the NNCF Quantization Flow for PyTorch models, please visit

Owner Author


Done, please check

Co-authored-by: Alexander Suslov <[email protected]>
Co-authored-by: Yamini Nimmagadda <[email protected]>
@daniil-lyakhov daniil-lyakhov force-pushed the dl/fx/openvino_quantizer branch 3 times, most recently from f4f592f to af4eb02 Compare February 24, 2025 12:33
@daniil-lyakhov daniil-lyakhov force-pushed the dl/fx/openvino_quantizer branch from af4eb02 to 810899a Compare February 24, 2025 12:35

The quantization flow mainly includes four steps:

- Step 1: Install OpenVINO and NNCF.


I think the quantization flow itself does not include step 1. It is just a prerequisite.

Owner Author


Agree, fixed

Introduction
--------------

This tutorial demonstrates how to use `OpenVINOQuantizer` from `Neural Network Compression Framework (NNCF) <https://github.com/openvinotoolkit/nncf/tree/develop>`_ in PyTorch 2 Export Quantization flow to generate a quantized model customized for the `OpenVINO torch.compile backend <https://docs.openvino.ai/2024/openvino-workflow/torch-compile.html>`_ and explains how to lower the quantized model into the `OpenVINO <https://docs.openvino.ai/2024/index.html>`_ representation.


It would be more attractive to give the user an idea of why they may need to use OpenVINOQuantizer (e.g., it is more accurate, more performant, etc.).

Owner Author


Makes sense! A description of the advantages of OpenVINOQuantizer was added.

@daniil-lyakhov daniil-lyakhov force-pushed the dl/fx/openvino_quantizer branch from 1c6bc7c to f09a85f Compare April 14, 2025 13:58
daniil-lyakhov and others added 18 commits April 14, 2025 16:05
Co-authored-by: Svetlana Karslioglu <[email protected]>
Removing the docs survey banner
Fix code snippet format issue in inductor_windows
---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
* Add a note that foreach feature is a prototype
Update the What's New section.

---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
* Adjust torch.compile() best practices

1. Add best practice to prefer `mod.compile` over `torch.compile(mod)`, which avoids `_orig_` naming problems.
Repro steps:
- opt_mod = torch.compile(mod)
- train opt_mod
- save checkpoint
In another script, potentially on a machine that does NOT support `torch.compile`: load checkpoint.
This fails with an error, because the checkpoint on `opt_mod` got its params renamed by `torch.compile`:
```
RuntimeError: Error(s) in loading state_dict for VQVAE:
	Missing key(s) in state_dict: "embedding.weight", "encoder.encoder.net.0.weight", "encoder.encoder.net.0.bias", ...
	Unexpected key(s) in state_dict: "_orig_mod.embedding.weight", "_orig_mod.encoder.encoder.net.0.weight", "_orig_mod.encoder.encoder.net.0.bias", ...
```
2. Add best practice to use, or at least try, `fullgraph=True`. This doesn't always work, but we should encourage it.

---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>