7 changes: 7 additions & 0 deletions CHANGELOG.rst
@@ -1,6 +1,13 @@
Model Optimizer Changelog (Linux)
=================================

0.40 (2025-12-xx)
^^^^^^^^^^^^^^^^^

**New Features**

- Add support for PyTorch Geometric quantization.

0.39 (2025-11-07)
^^^^^^^^^^^^^^^^^

4 changes: 4 additions & 0 deletions modelopt/torch/quantization/plugins/__init__.py
@@ -25,6 +25,7 @@
- :meth:`huggingface<modelopt.torch.quantization.plugins.huggingface>`
- :meth:`megatron<modelopt.torch.quantization.plugins.megatron>`
- :meth:`peft<modelopt.torch.quantization.plugins.peft>`
- :meth:`pytorch_geometric<modelopt.torch.quantization.plugins.pytorch_geometric>`
- :meth:`transformer_engine<modelopt.torch.quantization.plugins.transformer_engine>`
"""

@@ -57,6 +58,9 @@
with import_plugin("peft"):
    from .peft import *

with import_plugin("torch_geometric"):
    from .pytorch_geometric import *

with import_plugin("transformer_engine"):
    from .transformer_engine import *

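For context, import_plugin is the optional-dependency guard used throughout this package: the guarded import block is simply skipped when the third-party package is not installed. A minimal sketch of that pattern, assuming a context manager that swallows the ImportError (the real helper in modelopt may log or track plugin availability differently):

from contextlib import contextmanager


@contextmanager
def import_plugin(package_name):
    """Run the guarded import block, skipping it if the optional package is absent."""
    try:
        yield
    except ImportError:
        # The optional dependency (e.g. torch_geometric) is not installed,
        # so the corresponding plugin is simply not registered.
        pass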
89 changes: 89 additions & 0 deletions modelopt/torch/quantization/plugins/pytorch_geometric.py
@@ -0,0 +1,89 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch Geometric quantization plugin.

This plugin enables quantization support for PyTorch Geometric (PyG) layers by registering
PyG's custom Linear layer with ModelOpt's quantization registry.

Example:
    >>> import torch
    >>> import torch.nn as nn
    >>> import modelopt.torch.quantization as mtq
    >>> from torch_geometric.nn import GATConv
    >>>
    >>> # Create a model with PyG layers
    >>> class GATModel(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.gat1 = GATConv(10, 64, heads=4)
    ...         self.gat2 = GATConv(64 * 4, 32, heads=1)
    ...
    ...     def forward(self, x, edge_index):
    ...         return self.gat2(torch.relu(self.gat1(x, edge_index)), edge_index)
    >>> model = GATModel()
    >>> # PyG layers are now automatically quantizable; ``calibrate`` is any callable
    >>> # that runs a few batches of representative graph data through the model.
    >>> quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
"""

import torch
from torch_geometric.nn.dense.linear import Linear as PyGLinear

from modelopt.torch.quantization.nn.modules.quant_module import (
    QuantLinearConvBase,
    QuantModuleRegistry,
)
from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW


class QuantPyGLinear(QuantLinearConvBase):
"""Quantized version of PyTorch Geometric's Linear layer.

PyTorch Geometric uses a custom Linear layer that is functionally equivalent to
torch.nn.Linear but has a different API (in_channels/out_channels instead of
in_features/out_features). This class enables quantization of PyG Linear layers.

Note:
Many PyTorch Geometric layers (GCNConv, GATConv, SAGEConv, TransformerConv, etc.)
internally use PyG Linear layers, so registering this class enables quantization
for a wide range of graph neural network layers.
"""

default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW

def forward(self, input, *args, **kwargs):
"""Forward pass with quantization.

Args:
input: Input tensor to the linear layer
*args: Additional positional arguments
**kwargs: Additional keyword arguments

Returns:
Quantized output tensor
"""
# Quantize input activations
input_q = self.input_quantizer(input)

# Quantize weights
weight_q = self.weight_quantizer(self.weight)

# Perform linear operation
output = torch.nn.functional.linear(
input_q,
weight_q,
self.bias if hasattr(self, "bias") and self.bias is not None else None,
)

# Quantize output (typically disabled by default)
return self.output_quantizer(output)
Comment on lines +61 to +86
Contributor:

nit: This might not be needed

Suggested change: delete the forward override above (lines +61 to +86).
Contributor:

For example, see the definition for ConvNd layers:

    @QuantModuleRegistry.register({nn.Conv1d: "nn.Conv1d"})

The forward path should work (inherited from QuantLinearConvBase).
QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})(QuantPyGLinear)
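Following the reviewer's suggestion, a minimal sketch of what the plugin body could reduce to if the forward override is dropped (a hypothetical simplification that relies on the forward inherited from QuantLinearConvBase, not what this diff ships):

# Hypothetical simplified plugin: registration only, no forward override.
# Assumes the inherited QuantLinearConvBase forward handles input/weight
# quantization, as the review comment suggests.
from torch_geometric.nn.dense.linear import Linear as PyGLinear

from modelopt.torch.quantization.nn.modules.quant_module import (
    QuantLinearConvBase,
    QuantModuleRegistry,
)
from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW


@QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})
class QuantPyGLinear(QuantLinearConvBase):
    """Quantized PyG Linear layer; forward is inherited from QuantLinearConvBase."""

    default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW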
1 change: 1 addition & 0 deletions setup.py
@@ -79,6 +79,7 @@
"pytest-timeout",
"timm",
"torchvision",
"torch-geometric",
"tox>4.18",
"tox-current-env>=0.0.12",
],
174 changes: 174 additions & 0 deletions tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py
@@ -0,0 +1,174 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for PyTorch Geometric quantization plugin."""

import pytest
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, TransformerConv

import modelopt.torch.quantization as mtq


class TestPyTorchGeometricPlugin:
"""Test PyTorch Geometric quantization support."""

@pytest.fixture
def device(self):
"""Get test device."""
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def create_graph_data(self, batch_size=2, num_nodes=20, in_channels=16, device="cpu"):
"""Create sample graph data for testing."""
x = torch.randn(batch_size * num_nodes, in_channels, device=device)
# Create batch assignment
batch = torch.cat([torch.full((num_nodes,), i, device=device) for i in range(batch_size)])

# Create edge indices for each graph
edge_list = []
offset = 0
for _ in range(batch_size):
# Create random edges within each graph
src = torch.randint(0, num_nodes, (50,), device=device) + offset
dst = torch.randint(0, num_nodes, (50,), device=device) + offset
edge_list.append(torch.stack([src, dst]))
offset += num_nodes

edge_index = torch.cat(edge_list, dim=1)
edge_attr = torch.randn(edge_index.size(1), 32, device=device)

return x, edge_index, edge_attr, batch

def test_gat_conv_quantization(self, device):
"""Test GATConv layer quantization."""

class GATModel(nn.Module):
def __init__(self):
super().__init__()
self.gat1 = GATConv(16, 64, heads=4, edge_dim=32)
self.gat2 = GATConv(256, 32, heads=1, edge_dim=32)

def forward(self, x, edge_index, edge_attr):
x = torch.relu(self.gat1(x, edge_index, edge_attr))
return self.gat2(x, edge_index, edge_attr)

model = GATModel().to(device)

# Calibration function
def calibrate(m):
m.eval()
with torch.no_grad():
for _ in range(5):
x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
_ = m(x, edge_index, edge_attr)

# Quantize model
quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

# Verify quantization
quantizer_count = sum(
1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
)
assert quantizer_count > 0, "No quantizers were inserted"

# Test forward pass
x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
with torch.no_grad():
output = quantized(x, edge_index, edge_attr)
assert output is not None

def test_multiple_layer_types(self, device):
"""Test quantization of multiple PyG layer types."""

class MultiLayerGNN(nn.Module):
def __init__(self):
super().__init__()
self.gcn = GCNConv(16, 32)
self.sage = SAGEConv(32, 64)
self.transformer = TransformerConv(64, 32, heads=2)

def forward(self, x, edge_index):
x = torch.relu(self.gcn(x, edge_index))
x = torch.relu(self.sage(x, edge_index))
return self.transformer(x, edge_index)

model = MultiLayerGNN().to(device)

# Calibration
def calibrate(m):
m.eval()
with torch.no_grad():
for _ in range(3):
x = torch.randn(50, 16, device=device)
edge_index = torch.randint(0, 50, (2, 100), device=device)
_ = m(x, edge_index)

# Quantize
quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

# Check that PyG Linear layers were quantized
pyg_linear_count = 0
for name, module in model.named_modules():
if hasattr(module, "lin") and "torch_geometric" in str(type(module.lin)):
pyg_linear_count += 1

quantizer_count = sum(
1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
)

# Each PyG linear should have at least 2 quantizers (input, weight)
assert quantizer_count >= pyg_linear_count * 2, (
f"Expected at least {pyg_linear_count * 2} quantizers, got {quantizer_count}"
)

def test_quantization_accuracy(self, device):
"""Test that quantization maintains reasonable accuracy."""
model = GATConv(16, 32, heads=2, edge_dim=16).to(device)

# Create test data
x, edge_index, edge_attr, _ = self.create_graph_data(
batch_size=1, in_channels=16, device=device
)
edge_attr = edge_attr[:, :16] # Match edge_dim

# Get original output
model.eval()
with torch.no_grad():
original_output = model(x, edge_index, edge_attr)

# Calibration
def calibrate(m):
m.eval()
with torch.no_grad():
_ = m(x, edge_index, edge_attr)

# Quantize
quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

# Get quantized output
with torch.no_grad():
quantized_output = quantized(x, edge_index, edge_attr)

# Check relative error
abs_diff = torch.abs(original_output - quantized_output)
relative_error = abs_diff / (torch.abs(original_output) + 1e-8)
mean_relative_error = relative_error.mean().item()

assert mean_relative_error < 0.1, f"Quantization error too large: {mean_relative_error:.2%}"


if __name__ == "__main__":
pytest.main([__file__])
Comment on lines +173 to +174
Contributor:

nit: For debugging, I directly run the unit tests using the VSCode/Cursor test explorer.

Suggested change: remove the trailing if __name__ == "__main__": pytest.main([__file__]) block.

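As a quick smoke test outside the test suite (a minimal sketch reusing the same mtq.quantize calling convention as the tests above; not part of this diff), one can check that the plugin registers PyG's Linear layer:

import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq
from torch_geometric.nn.dense.linear import Linear as PyGLinear

# Wrap a single PyG Linear in a container module, mirroring how the tests
# quantize models that contain PyG layers.
model = nn.Sequential(PyGLinear(16, 32))


def calibrate(m):
    m.eval()
    with torch.no_grad():
        m(torch.randn(8, 16))


quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
# If the plugin is active, quantizer submodules should have been inserted.
print(any("quantizer" in type(m).__name__.lower() for m in quantized.modules()))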