Conversation

@i-riyad i-riyad commented Nov 3, 2025

What does this PR do?

Type of change: New feature

Overview: Support quantization of PyTorch Geometric (PyG) models.

# Add a code snippet demonstrating how to use this

Testing

python -m pytest tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py -v

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

@i-riyad i-riyad requested a review from a team as a code owner November 3, 2025 21:09
@i-riyad i-riyad requested a review from sychen52 November 3, 2025 21:09
@i-riyad i-riyad force-pushed the rislam/pytorch-geometric-quant branch from ca267f1 to 5386681 Compare November 3, 2025 21:09

codecov bot commented Nov 3, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.43%. Comparing base (72f23dc) to head (5386681).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #494   +/-   ##
=======================================
  Coverage   73.43%   73.43%           
=======================================
  Files         180      180           
  Lines       18146    18146           
=======================================
  Hits        13326    13326           
  Misses       4820     4820           

☔ View full report in Codecov by Sentry.

@i-riyad i-riyad force-pushed the rislam/pytorch-geometric-quant branch from 5386681 to 0bab18e Compare November 4, 2025 17:24
@i-riyad i-riyad requested a review from a team as a code owner November 4, 2025 17:24
@i-riyad i-riyad force-pushed the rislam/pytorch-geometric-quant branch from 0bab18e to 17f5636 Compare November 4, 2025 17:26
CHANGELOG.rst Outdated
Model Optimizer Changelog (Linux)
=================================

0.40 (2025-12-09)
Collaborator

Exact date might get pushed

Suggested change
0.40 (2025-12-09)
0.40 (2025-12-xx)

@i-riyad i-riyad force-pushed the rislam/pytorch-geometric-quant branch from 17f5636 to 243af33 Compare November 4, 2025 17:27
Collaborator

@kevalmorabia97 kevalmorabia97 left a comment

LGTM on a high level. Approving as codeowner

@i-riyad i-riyad force-pushed the rislam/pytorch-geometric-quant branch from 243af33 to 2f94190 Compare November 4, 2025 17:28
@i-riyad i-riyad force-pushed the rislam/pytorch-geometric-quant branch from 2f94190 to b65e0b3 Compare November 4, 2025 19:09
Comment on lines +61 to +86
def forward(self, input, *args, **kwargs):
    """Forward pass with quantization.

    Args:
        input: Input tensor to the linear layer
        *args: Additional positional arguments
        **kwargs: Additional keyword arguments

    Returns:
        Quantized output tensor
    """
    # Quantize input activations
    input_q = self.input_quantizer(input)

    # Quantize weights
    weight_q = self.weight_quantizer(self.weight)

    # Perform linear operation
    output = torch.nn.functional.linear(
        input_q,
        weight_q,
        self.bias if hasattr(self, "bias") and self.bias is not None else None,
    )

    # Quantize output (typically disabled by default)
    return self.output_quantizer(output)
Contributor

nit: This might not be needed

Suggested change
def forward(self, input, *args, **kwargs):
    """Forward pass with quantization.
    Args:
        input: Input tensor to the linear layer
        *args: Additional positional arguments
        **kwargs: Additional keyword arguments
    Returns:
        Quantized output tensor
    """
    # Quantize input activations
    input_q = self.input_quantizer(input)
    # Quantize weights
    weight_q = self.weight_quantizer(self.weight)
    # Perform linear operation
    output = torch.nn.functional.linear(
        input_q,
        weight_q,
        self.bias if hasattr(self, "bias") and self.bias is not None else None,
    )
    # Quantize output (typically disabled by default)
    return self.output_quantizer(output)

Contributor

For example, see the definition for ConvNd layers:

@QuantModuleRegistry.register({nn.Conv1d: "nn.Conv1d"})

The forward path should work (inherited from QuantLinearConvBase)

Comment on lines +173 to +174
if __name__ == "__main__":
    pytest.main([__file__])
Contributor

nit:

For debugging, I run the unit tests directly using the VSCode/Cursor test explorer.

Suggested change
if __name__ == "__main__":
    pytest.main([__file__])
