Commit 243af33

PyTorch Geometric quantization support

Signed-off-by: Riyad Islam <[email protected]>
1 parent: 5f0ef3b

File tree: 5 files changed, +274 -0 lines changed

CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,13 @@
 Model Optimizer Changelog (Linux)
 =================================

+0.40 (2025-12-09)
+^^^^^^^^^^^^^^^^^
+
+**New Features**
+
+- Add support for PyTorch Geometric quantization.
+
 0.39 (2025-11-07)
 ^^^^^^^^^^^^^^^^^

modelopt/torch/quantization/plugins/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -25,6 +25,7 @@
 - :meth:`huggingface<modelopt.torch.quantization.plugins.huggingface>`
 - :meth:`megatron<modelopt.torch.quantization.plugins.megatron>`
 - :meth:`peft<modelopt.torch.quantization.plugins.peft>`
+- :meth:`pytorch_geometric<modelopt.torch.quantization.plugins.pytorch_geometric>`
 - :meth:`transformer_engine<modelopt.torch.quantization.plugins.transformer_engine>`
 """

@@ -57,6 +58,9 @@
 with import_plugin("peft"):
     from .peft import *

+with import_plugin("torch_geometric"):
+    from .pytorch_geometric import *
+
 with import_plugin("transformer_engine"):
     from .transformer_engine import *

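The import_plugin guard above is what keeps the plugin optional: if torch-geometric is not installed, the import simply does not happen and the rest of the plugin registry loads normally. A minimal sketch of the pattern, using a simplified, hypothetical stand-in for ModelOpt's actual import_plugin helper (the real one may also log or track which plugins loaded):

    # Hypothetical, simplified stand-in for modelopt's import_plugin helper.
    from contextlib import contextmanager

    @contextmanager
    def import_plugin(name):
        try:
            yield  # the body of the `with` block executes here
        except ImportError:
            pass  # the package backing `name` is missing; skip this plugin silently

    with import_plugin("torch_geometric"):
        from torch_geometric.nn.dense.linear import Linear  # skipped if PyG is absent
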
modelopt/torch/quantization/plugins/pytorch_geometric.py (new file)

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch Geometric quantization plugin.

This plugin enables quantization support for PyTorch Geometric (PyG) layers by registering
PyG's custom Linear layer with ModelOpt's quantization registry.

Example:
    >>> import torch.nn as nn
    >>> import modelopt.torch.quantization as mtq
    >>> from torch_geometric.nn import GATConv
    >>>
    >>> # Create a model with PyG layers
    >>> class GATModel(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.gat1 = GATConv(10, 64, heads=4)
    ...         self.gat2 = GATConv(64 * 4, 32, heads=1)
    >>> model = GATModel()
    >>> # PyG layers are now automatically quantizable; ``calibrate`` is a
    >>> # user-supplied callback that feeds representative data through the model.
    >>> quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
"""

import torch
from torch_geometric.nn.dense.linear import Linear as PyGLinear

from modelopt.torch.quantization.nn.modules.quant_module import (
    QuantLinearConvBase,
    QuantModuleRegistry,
)
from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW


@QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})
class QuantPyGLinear(QuantLinearConvBase):
    """Quantized version of PyTorch Geometric's Linear layer.

    PyTorch Geometric uses a custom Linear layer that is functionally equivalent to
    torch.nn.Linear but has a different API (in_channels/out_channels instead of
    in_features/out_features). This class enables quantization of PyG Linear layers.

    Note:
        Many PyTorch Geometric layers (GCNConv, GATConv, SAGEConv, TransformerConv, etc.)
        internally use PyG Linear layers, so registering this class enables quantization
        for a wide range of graph neural network layers.
    """

    default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW

    def forward(self, input, *args, **kwargs):
        """Forward pass with quantization.

        Args:
            input: Input tensor to the linear layer.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            Quantized output tensor.
        """
        # Quantize input activations
        input_q = self.input_quantizer(input)

        # Quantize weights
        weight_q = self.weight_quantizer(self.weight)

        # Perform the linear operation
        output = torch.nn.functional.linear(input_q, weight_q, getattr(self, "bias", None))

        # Quantize the output (typically disabled by default)
        return self.output_quantizer(output)
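
The module docstring above calls mtq.quantize with a calibrate callback it never defines. A minimal sketch of such a callback for a PyG model, mirroring the pattern used in this commit's tests (the layer sizes and random graph data are illustrative assumptions):

    import torch
    import modelopt.torch.quantization as mtq
    from torch_geometric.nn import GCNConv

    model = GCNConv(16, 32)  # GCNConv uses a PyG Linear internally

    def calibrate(m):
        """Feed a few representative batches so activation ranges can be observed."""
        m.eval()
        with torch.no_grad():
            for _ in range(8):
                x = torch.randn(50, 16)  # 50 nodes with 16 features (illustrative)
                edge_index = torch.randint(0, 50, (2, 100))  # 100 random edges
                m(x, edge_index)

    quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)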

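QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW selects 8-bit weight quantization with one scale per output row (i.e. per output channel). A self-contained numeric sketch of symmetric per-row INT8 fake quantization, illustrating the idea rather than reproducing ModelOpt's implementation:

    import torch

    def fake_quant_int8_per_row(weight: torch.Tensor) -> torch.Tensor:
        """Quantize-dequantize a 2D weight with one symmetric scale per row (illustrative)."""
        amax = weight.abs().amax(dim=1, keepdim=True)  # per-row max magnitude
        scale = (amax / 127.0).clamp_min(1e-12)  # map each row's amax to 127
        q = torch.clamp(torch.round(weight / scale), -128, 127)  # snap to the INT8 grid
        return q * scale  # dequantize back to float

    w = torch.randn(4, 8)
    err = (w - fake_quant_int8_per_row(w)).abs().max().item()
    print(f"max per-element error: {err:.4f}")  # bounded by half a quantization step per row
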
setup.py

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@
         "pytest-timeout",
         "timm",
         "torchvision",
+        "torch-geometric",
         "tox>4.18",
         "tox-current-env>=0.0.12",
     ],
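
Note that torch-geometric is added only to the test extras, so it remains an optional runtime dependency; the plugin activates only when the package is importable. To exercise the plugin locally, it can be installed with the usual pip install torch-geometric.
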
New test file for the PyTorch Geometric plugin

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for PyTorch Geometric quantization plugin."""

import pytest
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, TransformerConv

import modelopt.torch.quantization as mtq


class TestPyTorchGeometricPlugin:
    """Test PyTorch Geometric quantization support."""

    @pytest.fixture
    def device(self):
        """Get test device."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def create_graph_data(self, batch_size=2, num_nodes=20, in_channels=16, device="cpu"):
        """Create sample graph data for testing."""
        x = torch.randn(batch_size * num_nodes, in_channels, device=device)
        # Create batch assignment
        batch = torch.cat([torch.full((num_nodes,), i, device=device) for i in range(batch_size)])

        # Create edge indices for each graph
        edge_list = []
        offset = 0
        for _ in range(batch_size):
            # Create random edges within each graph
            src = torch.randint(0, num_nodes, (50,), device=device) + offset
            dst = torch.randint(0, num_nodes, (50,), device=device) + offset
            edge_list.append(torch.stack([src, dst]))
            offset += num_nodes

        edge_index = torch.cat(edge_list, dim=1)
        edge_attr = torch.randn(edge_index.size(1), 32, device=device)

        return x, edge_index, edge_attr, batch

    def test_gat_conv_quantization(self, device):
        """Test GATConv layer quantization."""

        class GATModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.gat1 = GATConv(16, 64, heads=4, edge_dim=32)
                self.gat2 = GATConv(256, 32, heads=1, edge_dim=32)

            def forward(self, x, edge_index, edge_attr):
                x = torch.relu(self.gat1(x, edge_index, edge_attr))
                return self.gat2(x, edge_index, edge_attr)

        model = GATModel().to(device)

        # Calibration function
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                for _ in range(5):
                    x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
                    _ = m(x, edge_index, edge_attr)

        # Quantize model
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Verify quantization
        quantizer_count = sum(
            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
        )
        assert quantizer_count > 0, "No quantizers were inserted"

        # Test forward pass
        x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
        with torch.no_grad():
            output = quantized(x, edge_index, edge_attr)
        assert output is not None

    def test_multiple_layer_types(self, device):
        """Test quantization of multiple PyG layer types."""

        class MultiLayerGNN(nn.Module):
            def __init__(self):
                super().__init__()
                self.gcn = GCNConv(16, 32)
                self.sage = SAGEConv(32, 64)
                self.transformer = TransformerConv(64, 32, heads=2)

            def forward(self, x, edge_index):
                x = torch.relu(self.gcn(x, edge_index))
                x = torch.relu(self.sage(x, edge_index))
                return self.transformer(x, edge_index)

        model = MultiLayerGNN().to(device)

        # Calibration
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                for _ in range(3):
                    x = torch.randn(50, 16, device=device)
                    edge_index = torch.randint(0, 50, (2, 100), device=device)
                    _ = m(x, edge_index)

        # Quantize
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Check that PyG Linear layers were quantized
        pyg_linear_count = 0
        for _, module in quantized.named_modules():
            if hasattr(module, "lin") and "torch_geometric" in str(type(module.lin)):
                pyg_linear_count += 1

        quantizer_count = sum(
            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
        )

        # Each PyG linear should have at least 2 quantizers (input, weight)
        assert quantizer_count >= pyg_linear_count * 2, (
            f"Expected at least {pyg_linear_count * 2} quantizers, got {quantizer_count}"
        )

    def test_quantization_accuracy(self, device):
        """Test that quantization maintains reasonable accuracy."""
        model = GATConv(16, 32, heads=2, edge_dim=16).to(device)

        # Create test data
        x, edge_index, edge_attr, _ = self.create_graph_data(
            batch_size=1, in_channels=16, device=device
        )
        edge_attr = edge_attr[:, :16]  # Match edge_dim

        # Get original output
        model.eval()
        with torch.no_grad():
            original_output = model(x, edge_index, edge_attr)

        # Calibration
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                _ = m(x, edge_index, edge_attr)

        # Quantize
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Get quantized output
        with torch.no_grad():
            quantized_output = quantized(x, edge_index, edge_attr)

        # Check relative error
        abs_diff = torch.abs(original_output - quantized_output)
        relative_error = abs_diff / (torch.abs(original_output) + 1e-8)
        mean_relative_error = relative_error.mean().item()

        assert mean_relative_error < 0.1, f"Quantization error too large: {mean_relative_error:.2%}"


if __name__ == "__main__":
    pytest.main([__file__])
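
The __main__ guard makes the file runnable directly with Python as well as through a normal pytest invocation. The 0.1 bound in test_quantization_accuracy tolerates up to 10% mean relative deviation between the FP32 and INT8 outputs, a loose sanity check rather than an accuracy guarantee.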
