
Commit ca267f1

pytorch geometric quantization support

Signed-off-by: Riyad Islam <[email protected]>
1 parent 72f23dc

File tree

3 files changed: +283 −0 lines changed


modelopt/torch/quantization/plugins/__init__.py
Lines changed: 4 additions & 0 deletions

@@ -25,6 +25,7 @@
 - :meth:`huggingface<modelopt.torch.quantization.plugins.huggingface>`
 - :meth:`megatron<modelopt.torch.quantization.plugins.megatron>`
 - :meth:`peft<modelopt.torch.quantization.plugins.peft>`
+- :meth:`pytorch_geometric<modelopt.torch.quantization.plugins.pytorch_geometric>`
 - :meth:`transformer_engine<modelopt.torch.quantization.plugins.transformer_engine>`
 """

@@ -57,6 +58,9 @@
 with import_plugin("peft"):
     from .peft import *

+with import_plugin("torch_geometric"):
+    from .pytorch_geometric import *
+
 with import_plugin("transformer_engine"):
     from .transformer_engine import *
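With this guarded import in place, simply importing the quantization package (which pulls in the plugins listed above) registers the PyG Linear quantization class whenever torch_geometric is installed. A minimal end-to-end sketch of what that enables, assuming torch_geometric is available; the two-layer GCN, the random graph, and the helper names here are illustrative and not part of the commit:

import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq  # loads the plugins, including pytorch_geometric
from torch_geometric.nn import GCNConv

class TinyGCN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(16, 32)
        self.conv2 = GCNConv(32, 8)

    def forward(self, x, edge_index):
        return self.conv2(torch.relu(self.conv1(x, edge_index)), edge_index)

model = TinyGCN()
x = torch.randn(50, 16)                      # 50 nodes, 16 features per node
edge_index = torch.randint(0, 50, (2, 200))  # 200 random directed edges

def calibrate(m):
    # Run a few representative forward passes so the quantizers can collect activation ranges.
    m.eval()
    with torch.no_grad():
        for _ in range(4):
            m(x, edge_index)

quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)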

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch Geometric quantization plugin.

This plugin enables quantization support for PyTorch Geometric (PyG) layers by registering
PyG's custom Linear layer with ModelOpt's quantization registry.

Example:
    >>> import torch.nn as nn
    >>> import modelopt.torch.quantization as mtq
    >>> from torch_geometric.nn import GATConv
    >>>
    >>> # Create a model with PyG layers
    >>> class GATModel(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.gat1 = GATConv(10, 64, heads=4)
    ...         self.gat2 = GATConv(64 * 4, 32, heads=1)
    >>> model = GATModel()
    >>> # PyG layers are now automatically quantizable.
    >>> quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
"""

import torch

from modelopt.torch.quantization.nn.modules.quant_module import (
    QuantLinearConvBase,
    QuantModuleRegistry,
)
from modelopt.torch.quantization.tensor_quant import QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
from modelopt.torch.utils import import_plugin

__all__ = []


class QuantPyGLinear(QuantLinearConvBase):
    """Quantized version of PyTorch Geometric's Linear layer.

    PyTorch Geometric uses a custom Linear layer that is functionally equivalent to
    torch.nn.Linear but has a different API (in_channels/out_channels instead of
    in_features/out_features). This class enables quantization of PyG Linear layers.

    Note:
        Many PyTorch Geometric layers (GCNConv, GATConv, SAGEConv, TransformerConv, etc.)
        internally use PyG Linear layers, so registering this class enables quantization
        for a wide range of graph neural network layers.
    """

    default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW

    def forward(self, input, *args, **kwargs):
        """Forward pass with quantization.

        Args:
            input: Input tensor to the linear layer
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Quantized output tensor
        """
        # Quantize input activations
        input_q = self.input_quantizer(input)

        # Quantize weights
        weight_q = self.weight_quantizer(self.weight)

        # Perform linear operation
        output = torch.nn.functional.linear(
            input_q,
            weight_q,
            self.bias if hasattr(self, "bias") and self.bias is not None else None,
        )

        # Quantize output (typically disabled by default)
        return self.output_quantizer(output)


# Register only if torch_geometric is available
with import_plugin("torch_geometric"):
    from torch_geometric.nn.dense.linear import Linear as PyGLinear

    # Register the quantized version
    QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})(
        QuantPyGLinear
    )

    # Export QuantPyGLinear only if registration succeeded
    __all__.append("QuantPyGLinear")
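QuantPyGLinear defaults its weight quantizer to QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW, i.e. symmetric 8-bit quantization with one scale per output row. As a rough numeric illustration of that general scheme only (a standalone sketch; ModelOpt's own TensorQuantizer performs the real fake-quantization inside input_quantizer/weight_quantizer, and its exact rounding/clamping conventions may differ):

import torch

def fake_quant_per_row(weight: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Symmetric per-row (per-output-channel) fake quantization of a 2-D weight matrix."""
    qmax = 2 ** (num_bits - 1) - 1                  # 127 for INT8
    amax = weight.abs().amax(dim=1, keepdim=True)   # one dynamic range per output row
    scale = amax.clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(weight / scale), -qmax - 1, qmax)
    return q * scale                                # dequantized ("fake-quant") weights

w = torch.randn(4, 16)
w_q = fake_quant_per_row(w)
print((w - w_q).abs().max())  # small per-row rounding error

Per-row scaling keeps each output channel's dynamic range independent, which is why it is the usual default for linear-layer weights.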
Lines changed: 178 additions & 0 deletions

@@ -0,0 +1,178 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for PyTorch Geometric quantization plugin."""

import pytest
import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq

torch_geometric = pytest.importorskip(
    "torch_geometric", reason="PyTorch Geometric is not installed"
)
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, TransformerConv


class TestPyTorchGeometricPlugin:
    """Test PyTorch Geometric quantization support."""

    @pytest.fixture
    def device(self):
        """Get test device."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def create_graph_data(self, batch_size=2, num_nodes=20, in_channels=16, device="cpu"):
        """Create sample graph data for testing."""
        x = torch.randn(batch_size * num_nodes, in_channels, device=device)
        # Create batch assignment
        batch = torch.cat([torch.full((num_nodes,), i, device=device) for i in range(batch_size)])

        # Create edge indices for each graph
        edge_list = []
        offset = 0
        for _ in range(batch_size):
            # Create random edges within each graph
            src = torch.randint(0, num_nodes, (50,), device=device) + offset
            dst = torch.randint(0, num_nodes, (50,), device=device) + offset
            edge_list.append(torch.stack([src, dst]))
            offset += num_nodes

        edge_index = torch.cat(edge_list, dim=1)
        edge_attr = torch.randn(edge_index.size(1), 32, device=device)

        return x, edge_index, edge_attr, batch

    def test_gat_conv_quantization(self, device):
        """Test GATConv layer quantization."""

        class GATModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.gat1 = GATConv(16, 64, heads=4, edge_dim=32)
                self.gat2 = GATConv(256, 32, heads=1, edge_dim=32)

            def forward(self, x, edge_index, edge_attr):
                x = torch.relu(self.gat1(x, edge_index, edge_attr))
                return self.gat2(x, edge_index, edge_attr)

        model = GATModel().to(device)

        # Calibration function
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                for _ in range(5):
                    x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
                    _ = m(x, edge_index, edge_attr)

        # Quantize model
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Verify quantization
        quantizer_count = sum(
            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
        )
        assert quantizer_count > 0, "No quantizers were inserted"

        # Test forward pass
        x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
        with torch.no_grad():
            output = quantized(x, edge_index, edge_attr)
        assert output is not None

    def test_multiple_layer_types(self, device):
        """Test quantization of multiple PyG layer types."""

        class MultiLayerGNN(nn.Module):
            def __init__(self):
                super().__init__()
                self.gcn = GCNConv(16, 32)
                self.sage = SAGEConv(32, 64)
                self.transformer = TransformerConv(64, 32, heads=2)

            def forward(self, x, edge_index):
                x = torch.relu(self.gcn(x, edge_index))
                x = torch.relu(self.sage(x, edge_index))
                return self.transformer(x, edge_index)

        model = MultiLayerGNN().to(device)

        # Calibration
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                for _ in range(3):
                    x = torch.randn(50, 16, device=device)
                    edge_index = torch.randint(0, 50, (2, 100), device=device)
                    _ = m(x, edge_index)

        # Quantize
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Check that PyG Linear layers were quantized
        pyg_linear_count = 0
        for name, module in model.named_modules():
            if hasattr(module, "lin") and "torch_geometric" in str(type(module.lin)):
                pyg_linear_count += 1

        quantizer_count = sum(
            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
        )

        # Each PyG linear should have at least 2 quantizers (input, weight)
        assert quantizer_count >= pyg_linear_count * 2, (
            f"Expected at least {pyg_linear_count * 2} quantizers, got {quantizer_count}"
        )

    def test_quantization_accuracy(self, device):
        """Test that quantization maintains reasonable accuracy."""
        model = GATConv(16, 32, heads=2, edge_dim=16).to(device)

        # Create test data
        x, edge_index, edge_attr, _ = self.create_graph_data(
            batch_size=1, in_channels=16, device=device
        )
        edge_attr = edge_attr[:, :16]  # Match edge_dim

        # Get original output
        model.eval()
        with torch.no_grad():
            original_output = model(x, edge_index, edge_attr)

        # Calibration
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                _ = m(x, edge_index, edge_attr)

        # Quantize
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Get quantized output
        with torch.no_grad():
            quantized_output = quantized(x, edge_index, edge_attr)

        # Check relative error
        abs_diff = torch.abs(original_output - quantized_output)
        relative_error = abs_diff / (torch.abs(original_output) + 1e-8)
        mean_relative_error = relative_error.mean().item()

        assert mean_relative_error < 0.1, f"Quantization error too large: {mean_relative_error:.2%}"


if __name__ == "__main__":
    pytest.main([__file__])
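The assertions above only count quantizer submodules by type name; the same introspection works interactively to see exactly where quantizers were inserted in a quantized GNN. A sketch reusing the pattern from the tests, where `quantized` is any model returned by mtq.quantize:

def list_quantizers(quantized):
    """Print every quantizer submodule inserted by mtq.quantize."""
    for name, module in quantized.named_modules():
        if "quantizer" in type(module).__name__.lower():
            print(f"{name}: {type(module).__name__}")

# Each wrapped PyG Linear contributes at least an input_quantizer and a weight_quantizer
# (this is what the assertion in test_multiple_layer_types relies on).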
