PyTorch Geometric quantization support #494
Open: i-riyad wants to merge 1 commit into main from rislam/pytorch-geometric-quant.
Changes from all commits
New file (89 additions): PyTorch Geometric quantization plugin
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch Geometric quantization plugin.

This plugin enables quantization support for PyTorch Geometric (PyG) layers by registering
PyG's custom Linear layer with ModelOpt's quantization registry.

Example:
    >>> import torch.nn as nn
    >>> import modelopt.torch.quantization as mtq
    >>> from torch_geometric.nn import GATConv
    >>>
    >>> # Create a model with PyG layers
    >>> class GATModel(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.gat1 = GATConv(10, 64, heads=4)
    ...         self.gat2 = GATConv(64 * 4, 32, heads=1)
    >>> model = GATModel()
    >>> # PyG layers are now automatically quantizable
    >>> # (`calibrate` is a forward-loop callable that feeds calibration data).
    >>> quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
"""

import torch
from torch_geometric.nn.dense.linear import Linear as PyGLinear

from modelopt.torch.quantization.nn.modules.quant_module import (
    QuantLinearConvBase,
    QuantModuleRegistry,
)
from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW


class QuantPyGLinear(QuantLinearConvBase):
    """Quantized version of PyTorch Geometric's Linear layer.

    PyTorch Geometric uses a custom Linear layer that is functionally equivalent to
    torch.nn.Linear but has a different API (in_channels/out_channels instead of
    in_features/out_features). This class enables quantization of PyG Linear layers.

    Note:
        Many PyTorch Geometric layers (GCNConv, GATConv, SAGEConv, TransformerConv, etc.)
        internally use PyG Linear layers, so registering this class enables quantization
        for a wide range of graph neural network layers.
    """

    default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW

    def forward(self, input, *args, **kwargs):
        """Forward pass with quantization.

        Args:
            input: Input tensor to the linear layer.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            Quantized output tensor.
        """
        # Quantize input activations
        input_q = self.input_quantizer(input)

        # Quantize weights
        weight_q = self.weight_quantizer(self.weight)

        # Perform the linear operation
        output = torch.nn.functional.linear(
            input_q,
            weight_q,
            self.bias if hasattr(self, "bias") and self.bias is not None else None,
        )

        # Quantize the output (typically disabled by default)
        return self.output_quantizer(output)


QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})(QuantPyGLinear)
```
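For context, here is a minimal end-to-end sketch of how the registration is intended to be used, assuming this plugin module has been imported so the registry entry is active; `TinyGNN` and the random graph below are illustrative, not part of the PR:

```python
import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq
from torch_geometric.nn import GCNConv


class TinyGNN(nn.Module):
    def __init__(self):
        super().__init__()
        # GCNConv internally builds a PyG Linear (in_channels/out_channels API),
        # which is exactly the module the registry entry above targets.
        self.conv = GCNConv(16, 32)

    def forward(self, x, edge_index):
        return self.conv(x, edge_index)


model = TinyGNN()
x = torch.randn(50, 16)
edge_index = torch.randint(0, 50, (2, 100))


def calibrate(m):
    """Run representative data through the model to collect activation ranges."""
    m.eval()
    with torch.no_grad():
        m(x, edge_index)


quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
```

Because GCNConv builds its projection with PyG's Linear rather than torch.nn.Linear, the registry entry is what lets `mtq.quantize` recurse into it and insert quantizers.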
New file (174 additions): tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for PyTorch Geometric quantization plugin."""

import pytest
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, TransformerConv

import modelopt.torch.quantization as mtq


class TestPyTorchGeometricPlugin:
    """Test PyTorch Geometric quantization support."""

    @pytest.fixture
    def device(self):
        """Get test device."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def create_graph_data(self, batch_size=2, num_nodes=20, in_channels=16, device="cpu"):
        """Create sample graph data for testing."""
        x = torch.randn(batch_size * num_nodes, in_channels, device=device)
        # Create batch assignment
        batch = torch.cat([torch.full((num_nodes,), i, device=device) for i in range(batch_size)])

        # Create edge indices for each graph
        edge_list = []
        offset = 0
        for _ in range(batch_size):
            # Create random edges within each graph
            src = torch.randint(0, num_nodes, (50,), device=device) + offset
            dst = torch.randint(0, num_nodes, (50,), device=device) + offset
            edge_list.append(torch.stack([src, dst]))
            offset += num_nodes

        edge_index = torch.cat(edge_list, dim=1)
        edge_attr = torch.randn(edge_index.size(1), 32, device=device)

        return x, edge_index, edge_attr, batch

    def test_gat_conv_quantization(self, device):
        """Test GATConv layer quantization."""

        class GATModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.gat1 = GATConv(16, 64, heads=4, edge_dim=32)
                self.gat2 = GATConv(256, 32, heads=1, edge_dim=32)

            def forward(self, x, edge_index, edge_attr):
                x = torch.relu(self.gat1(x, edge_index, edge_attr))
                return self.gat2(x, edge_index, edge_attr)

        model = GATModel().to(device)

        # Calibration function
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                for _ in range(5):
                    x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
                    _ = m(x, edge_index, edge_attr)

        # Quantize model
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Verify quantization
        quantizer_count = sum(
            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
        )
        assert quantizer_count > 0, "No quantizers were inserted"

        # Test forward pass
        x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
        with torch.no_grad():
            output = quantized(x, edge_index, edge_attr)
        assert output is not None

    def test_multiple_layer_types(self, device):
        """Test quantization of multiple PyG layer types."""

        class MultiLayerGNN(nn.Module):
            def __init__(self):
                super().__init__()
                self.gcn = GCNConv(16, 32)
                self.sage = SAGEConv(32, 64)
                self.transformer = TransformerConv(64, 32, heads=2)

            def forward(self, x, edge_index):
                x = torch.relu(self.gcn(x, edge_index))
                x = torch.relu(self.sage(x, edge_index))
                return self.transformer(x, edge_index)

        model = MultiLayerGNN().to(device)

        # Calibration
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                for _ in range(3):
                    x = torch.randn(50, 16, device=device)
                    edge_index = torch.randint(0, 50, (2, 100), device=device)
                    _ = m(x, edge_index)

        # Quantize
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Check that PyG Linear layers were quantized
        pyg_linear_count = 0
        for _, module in model.named_modules():
            if hasattr(module, "lin") and "torch_geometric" in str(type(module.lin)):
                pyg_linear_count += 1

        quantizer_count = sum(
            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
        )

        # Each PyG linear should have at least 2 quantizers (input, weight)
        assert quantizer_count >= pyg_linear_count * 2, (
            f"Expected at least {pyg_linear_count * 2} quantizers, got {quantizer_count}"
        )

    def test_quantization_accuracy(self, device):
        """Test that quantization maintains reasonable accuracy."""
        model = GATConv(16, 32, heads=2, edge_dim=16).to(device)

        # Create test data
        x, edge_index, edge_attr, _ = self.create_graph_data(
            batch_size=1, in_channels=16, device=device
        )
        edge_attr = edge_attr[:, :16]  # Match edge_dim

        # Get original output
        model.eval()
        with torch.no_grad():
            original_output = model(x, edge_index, edge_attr)

        # Calibration
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                _ = m(x, edge_index, edge_attr)

        # Quantize
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Get quantized output
        with torch.no_grad():
            quantized_output = quantized(x, edge_index, edge_attr)

        # Check relative error
        abs_diff = torch.abs(original_output - quantized_output)
        relative_error = abs_diff / (torch.abs(original_output) + 1e-8)
        mean_relative_error = relative_error.mean().item()

        assert mean_relative_error < 0.1, f"Quantization error too large: {mean_relative_error:.2%}"


if __name__ == "__main__":
    pytest.main([__file__])
```
Review comment on lines +173 to +174:

nit: For debugging, I directly run the unit tests using the VSCode/Cursor test explorer.

Suggested change:
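The suggestion body is not shown; given the commented line range (+173 to +174), it presumably removes the test file's `__main__` entry point, i.e. these two lines:

```python
# Presumed content of the suggestion (an assumption, not rendered on the page):
# delete these lines, since the tests are run via pytest or a test explorer.
if __name__ == "__main__":
    pytest.main([__file__])
```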
Review comment:

nit: This might not be needed.
Review comment:

For example, see the definition for the ConvNd layers in modelopt/torch/quantization/nn/modules/quant_conv.py (line 39 at commit ed58324). The forward path should work as is, inherited from QuantLinearConvBase.
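A minimal sketch of the simplification the reviewer is pointing at, assuming the forward inherited from QuantLinearConvBase already quantizes input and weight (as in the ConvNd registrations); this is an interpretation of the comment, not code from the PR:

```python
# Sketch of the reviewer's suggestion: with the inherited forward doing the
# quantization, the subclass reduces to a quant-descriptor override and the
# custom forward above can be dropped entirely.
class QuantPyGLinear(QuantLinearConvBase):
    """Quantized PyG Linear relying on the forward inherited from QuantLinearConvBase."""

    default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW


QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})(QuantPyGLinear)
```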