# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 16 | +"""Tests for PyTorch Geometric quantization plugin.""" |

import pytest
import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq

pytest.importorskip("torch_geometric", reason="PyTorch Geometric is not installed")
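# Import PyG layers only after the importorskip check so collection of this
# module is skipped cleanly when torch_geometric is not installed.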
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, TransformerConv  # noqa: E402


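# NOTE: importing modelopt.torch.quantization above is assumed to auto-register
# the PyTorch Geometric plugin when torch_geometric is importable, which is the
# behavior these tests rely on.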
class TestPyTorchGeometricPlugin:
    """Test PyTorch Geometric quantization support."""

    @pytest.fixture
    def device(self):
        """Get test device."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def create_graph_data(self, batch_size=2, num_nodes=20, in_channels=16, device="cpu"):
        """Create sample graph data for testing."""
        x = torch.randn(batch_size * num_nodes, in_channels, device=device)
        # Batch assignment: batch[j] is the index of the graph node j belongs to
        batch = torch.cat([torch.full((num_nodes,), i, device=device) for i in range(batch_size)])

        # Create random edge indices for each graph; the running offset keeps a
        # graph's edges within its own block of node indices
        edge_list = []
        offset = 0
        for _ in range(batch_size):
            src = torch.randint(0, num_nodes, (50,), device=device) + offset
            dst = torch.randint(0, num_nodes, (50,), device=device) + offset
            edge_list.append(torch.stack([src, dst]))
            offset += num_nodes

        edge_index = torch.cat(edge_list, dim=1)
        edge_attr = torch.randn(edge_index.size(1), 32, device=device)

        return x, edge_index, edge_attr, batch

    def test_gat_conv_quantization(self, device):
        """Test GATConv layer quantization."""

        class GATModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.gat1 = GATConv(16, 64, heads=4, edge_dim=32)
                # Input is 256 = 64 * 4 because GATConv concatenates head outputs
                self.gat2 = GATConv(256, 32, heads=1, edge_dim=32)

            def forward(self, x, edge_index, edge_attr):
                x = torch.relu(self.gat1(x, edge_index, edge_attr))
                return self.gat2(x, edge_index, edge_attr)

        model = GATModel().to(device)

        # Calibration function
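        # (passed to mtq.quantize as the forward loop; it runs representative
        # data through the model so quantizers can collect ranges)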
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                for _ in range(5):
                    x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
                    _ = m(x, edge_index, edge_attr)

        # Quantize model
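        # (mtq.quantize inserts quantizers per the config, runs `calibrate` for
        # calibration, and modifies the model in place before returning it)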
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Verify quantization
        quantizer_count = sum(
            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
        )
        assert quantizer_count > 0, "No quantizers were inserted"

        # Test forward pass on the quantized model
        x, edge_index, edge_attr, _ = self.create_graph_data(device=device)
        with torch.no_grad():
            output = quantized(x, edge_index, edge_attr)
        assert output.shape == (x.size(0), 32), "Unexpected output shape"

    def test_multiple_layer_types(self, device):
        """Test quantization of multiple PyG layer types."""

        class MultiLayerGNN(nn.Module):
            def __init__(self):
                super().__init__()
                self.gcn = GCNConv(16, 32)
                self.sage = SAGEConv(32, 64)
                self.transformer = TransformerConv(64, 32, heads=2)

            def forward(self, x, edge_index):
                x = torch.relu(self.gcn(x, edge_index))
                x = torch.relu(self.sage(x, edge_index))
                return self.transformer(x, edge_index)

        model = MultiLayerGNN().to(device)

        # Calibration
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                for _ in range(3):
                    x = torch.randn(50, 16, device=device)
                    edge_index = torch.randint(0, 50, (2, 100), device=device)
                    _ = m(x, edge_index)

        # Quantize
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Check that PyG Linear layers were quantized: count modules exposing a
        # PyG Linear submodule as `lin` (a heuristic; GCNConv stores its
        # transform there, while SAGEConv and TransformerConv name theirs
        # differently)
        pyg_linear_count = 0
        for _, module in quantized.named_modules():
            if hasattr(module, "lin") and "torch_geometric" in str(type(module.lin)):
                pyg_linear_count += 1

        quantizer_count = sum(
            1 for _, m in quantized.named_modules() if "quantizer" in type(m).__name__.lower()
        )

        # Each PyG linear should have at least 2 quantizers (input, weight)
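        # (an assumption about INT8_DEFAULT_CFG: output quantizers are disabled
        # by default, so input and weight quantizers are the two expected per
        # quantized linear)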
        assert quantizer_count >= pyg_linear_count * 2, (
            f"Expected at least {pyg_linear_count * 2} quantizers, got {quantizer_count}"
        )

    def test_quantization_accuracy(self, device):
        """Test that quantization maintains reasonable accuracy."""
        model = GATConv(16, 32, heads=2, edge_dim=16).to(device)

        # Create test data
        x, edge_index, edge_attr, _ = self.create_graph_data(
            batch_size=1, in_channels=16, device=device
        )
        edge_attr = edge_attr[:, :16]  # Match edge_dim

        # Get original output
        model.eval()
        with torch.no_grad():
            original_output = model(x, edge_index, edge_attr)

        # Calibration
        def calibrate(m):
            m.eval()
            with torch.no_grad():
                _ = m(x, edge_index, edge_attr)

        # Quantize
        quantized = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)

        # Get quantized output
        with torch.no_grad():
            quantized_output = quantized(x, edge_index, edge_attr)

        # Check relative error
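        # (mean relative error below 10% is a loose sanity bound for INT8 on
        # random inputs, not a guarantee for real workloads)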
        abs_diff = torch.abs(original_output - quantized_output)
        relative_error = abs_diff / (torch.abs(original_output) + 1e-8)
        mean_relative_error = relative_error.mean().item()

        assert mean_relative_error < 0.1, f"Quantization error too large: {mean_relative_error:.2%}"


if __name__ == "__main__":
    pytest.main([__file__])