Conversation

@ksivaman
Member

Description

#2388 introduced the GroupedTensor class in the core library. This PR partially integrates that functionality into the PyTorch bindings.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Expose a Python GroupedTensor class.
  • Integrate GroupedTensor into GroupedLinear so that the weight parameters share a single contiguous storage.
  • Expose a C++ grouped_quantize API to Python, similar to split_quantize, that returns a quantized GroupedTensor which can be consumed directly by the grouped GEMMs ([common] Add support for cuBLASLt GEMM for GroupedTensor #2502). A hedged usage sketch follows this list.
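
As a rough illustration of how the new binding might be called, here is a hedged sketch. The module path, call signature, and the make_fp8_quantizer helper are assumptions for illustration, not the API introduced by this PR; the intent is simply that grouped_quantize mirrors split_quantize by taking contiguous input data plus per-tensor shapes and quantizers and returning a quantized GroupedTensor that the grouped GEMM in #2502 can consume.

# Hypothetical sketch only; the real grouped_quantize binding may differ.
import torch
import transformer_engine.pytorch as te  # assumed import path for the binding

# Per-GEMM weight shapes, with one quantizer per tensor (assumed to be
# "effectively the same" quantizer for every member of the group).
shapes = [(4096, 1024), (4096, 1024), (4096, 1024)]
quantizers = [make_fp8_quantizer() for _ in shapes]  # hypothetical helper

# One contiguous high-precision buffer holding all weights back to back.
flat = torch.empty(sum(r * c for r, c in shapes), dtype=torch.bfloat16, device="cuda")

# Assumed call shape: quantize the whole group at once and get back a
# GroupedTensor whose members can be fed directly to the grouped GEMM.
grouped = te.grouped_quantize(flat, shapes, quantizers)  # hypothetical name/location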

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ksivaman ksivaman marked this pull request as draft January 15, 2026 14:58
@ksivaman ksivaman requested a review from ptrendx January 15, 2026 14:58
@greptile-apps
Contributor

greptile-apps bot commented Jan 15, 2026

Greptile Summary

This PR integrates the GroupedTensor class from #2388 into PyTorch bindings, enabling contiguous memory storage for multiple weight tensors in GroupedLinear.

Key Changes

  • New GroupedTensor class (918 lines): Stores multiple tensors with different shapes in contiguous memory, supporting all quantization recipes (FP8, MXFP8, NVFP4, block scaling)
  • GroupedLinear integration: Added make_grouped_weights() method that converts individual weight parameters into views of a single contiguous GroupedTensor storage
  • Recipe API refactoring: Changed type-checking methods from instance methods to classmethods (isinstance → issubclass) to align with _get_compatible_recipe() returning class types; a generic sketch of this pattern follows the list
  • Quantizer enhancements: Added get_columnwise_shape() and get_scale_shape() methods for proper memory layout calculations
  • Comprehensive tests: 430-line test suite verifying contiguous memory layout and quantization correctness across all recipes
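
A minimal, generic illustration of the classmethod change described above (class and method names here are placeholders, not the actual recipe API): issubclass-based classmethods give the same answer whether the caller holds a recipe class, as returned by _get_compatible_recipe(), or a recipe instance.

# Generic sketch of the isinstance -> issubclass classmethod pattern;
# the class and method names are placeholders.
class Recipe:
    @classmethod
    def is_delayed_scaling(cls) -> bool:
        # issubclass works whether the caller holds a class or an instance,
        # because an instance call binds cls to type(self).
        return issubclass(cls, DelayedScaling)

class DelayedScaling(Recipe):
    pass

class BlockScaling(Recipe):
    pass

assert DelayedScaling.is_delayed_scaling()        # called on a class object
assert DelayedScaling().is_delayed_scaling()      # still works on an instance
assert not BlockScaling.is_delayed_scaling()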

Implementation Notes

The implementation allocates all weight data in a single contiguous buffer, then creates individual parameter views that share the underlying storage. This improves memory locality and enables future optimizations like grouped GEMMs (#2502).
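
The general idea can be sketched with plain PyTorch, independent of the actual GroupedTensor internals: allocate one flat buffer and expose each weight as a reshaped view of a slice of it, so every parameter shares the same underlying storage. This is a conceptual sketch only; the real class additionally lays out quantization state such as scale inverses.

import torch

def make_contiguous_views(shapes, dtype=torch.float32, device="cpu"):
    """Allocate one flat buffer and return per-tensor views into it (conceptual sketch)."""
    numels = [rows * cols for rows, cols in shapes]
    flat = torch.empty(sum(numels), dtype=dtype, device=device)
    views, offset = [], 0
    for shape, n in zip(shapes, numels):
        views.append(flat[offset:offset + n].view(shape))
        offset += n
    return flat, views

flat, weights = make_contiguous_views([(8, 4), (8, 4), (16, 4)])
# Every view shares the flat buffer's storage...
assert all(w.untyped_storage().data_ptr() == flat.untyped_storage().data_ptr() for w in weights)
# ...and the views are laid out back to back in memory.
assert weights[1].data_ptr() == weights[0].data_ptr() + weights[0].numel() * weights[0].element_size()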

Confidence Score: 4/5

  • This PR is safe to merge with minor caveats that should be verified through testing
  • The implementation is well-designed and comprehensive, with extensive tests covering all quantization recipes. The core GroupedTensor logic is sound, and the integration into GroupedLinear follows established patterns. However, there are two acknowledged areas needing verification: (1) the copy operation from regular tensors to quantized tensors in make_grouped_weights() has a TODO comment indicating uncertainty about correctness across all recipes, and (2) the assumption that all quantizers in a group are "effectively the same" is not strongly enforced. The recipe API change from instance methods to classmethods is correct but represents a subtle behavioral change that could affect code calling these methods on instances.
  • Pay close attention to transformer_engine/pytorch/module/grouped_linear.py (copy operation in make_grouped_weights) and transformer_engine/common/recipe/__init__.py (instance method to classmethod change)

Important Files Changed

  • transformer_engine/pytorch/tensor/storage/grouped_tensor.py: New 918-line file implementing the GroupedTensor class for contiguous storage of multiple tensors with different shapes. Comprehensive implementation with quantization support for FP8, MXFP8, NVFP4, and block scaling recipes.
  • transformer_engine/pytorch/module/grouped_linear.py: Added make_grouped_weights() method to convert weight parameters into contiguous GroupedTensor storage. Weights are copied and re-registered to share the underlying storage.
  • transformer_engine/common/recipe/__init__.py: Changed recipe type-checking methods from instance methods to classmethods, using issubclass() instead of isinstance(). This aligns with _get_compatible_recipe() returning class types.
  • tests/pytorch/test_grouped_tensor.py: New comprehensive test file with 430 lines covering GroupedTensor construction, splitting, quantization for all supported recipes, and verification of the contiguous memory layout.

Sequence Diagram

sequenceDiagram
    participant User
    participant GroupedLinear
    participant GroupedTensor
    participant Quantizer
    participant Storage

    Note over User,Storage: Initialization Phase
    User->>GroupedLinear: __init__(num_gemms, in_features, out_features)
    GroupedLinear->>GroupedLinear: register_parameter(weight0...weightN)
    GroupedLinear->>GroupedLinear: reset_parameters()
    GroupedLinear->>GroupedLinear: make_grouped_weights()
    
    Note over GroupedLinear,Storage: Weight Consolidation
    GroupedLinear->>Quantizer: _get_weight_quantizers()
    Quantizer-->>GroupedLinear: [quantizer0...quantizerN]
    GroupedLinear->>GroupedTensor: make_grouped_tensor(num_tensors, shapes, quantizers)
    
    Note over GroupedTensor,Storage: Allocate Contiguous Storage
    GroupedTensor->>GroupedTensor: analyze shape patterns
    GroupedTensor->>GroupedTensor: calculate logical_shape, offsets
    GroupedTensor->>Storage: allocate contiguous buffers (data, scale_inv, etc)
    GroupedTensor->>GroupedTensor: split_into_quantized_tensors()
    GroupedTensor-->>GroupedLinear: grouped_weights with quantized_tensors
    
    Note over GroupedLinear: Copy & Re-register Weights
    loop for each weight i
        GroupedLinear->>GroupedTensor: quantized_tensors[i].copy_(weights[i])
        GroupedLinear->>GroupedLinear: register_parameter(weightI, quantized_tensors[i])
    end
    
    Note over User,Storage: Forward Pass
    User->>GroupedLinear: forward(inp, m_splits)
    GroupedLinear->>GroupedLinear: _get_weight_tensors()
    GroupedLinear->>GroupedLinear: prepare quantizers
    GroupedLinear->>GroupedLinear: _GroupedLinear.apply()
    Note over GroupedLinear: All weights share contiguous storage
    GroupedLinear->>GroupedLinear: general_grouped_gemm(weights, inputs)
    GroupedLinear-->>User: output tensor
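
For orientation, a rough usage sketch of the flow in the diagram, assuming the existing GroupedLinear interface (constructor and forward arguments as shown above; other details are assumptions). After construction, all weights are views into one contiguous GroupedTensor storage.

# Hedged usage sketch; argument details beyond those shown in the diagram are assumptions,
# and parameters are assumed to be placed on CUDA by default.
import torch
import transformer_engine.pytorch as te

num_gemms, in_features, out_features = 4, 1024, 4096
layer = te.GroupedLinear(num_gemms, in_features, out_features, bias=False)

# One flattened input whose rows are split across the grouped GEMMs.
m_splits = [16, 32, 8, 8]
inp = torch.randn(sum(m_splits), in_features, device="cuda")

out = layer(inp, m_splits)  # forward over weights that share contiguous storage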

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 2 comments

Comment on lines +771 to +774
# TODO(ksivamani): Verify correctness of copy for all recipes.
with torch.no_grad():
    for i in range(self.num_gemms):
        grouped_weights.quantized_tensors[i].copy_(weights[i])

style: check that the copy operation works correctly for all quantization recipes (FP8, MXFP8, NVFP4, block scaling). the TODO comment on line 771 acknowledges this needs verification.
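
One way such a verification could look (a hypothetical test sketch, not code from this PR; it assumes the quantized member tensors expose a dequantize() method): copy each high-precision weight into its quantized slot and check the round trip against the original within a recipe-appropriate tolerance.

# Hypothetical verification sketch; dequantize() and the tolerances are assumptions.
import torch

def check_copy_roundtrip(grouped_weights, reference_weights, atol=0.1, rtol=0.0):
    """Copy each reference weight into its quantized slot and compare the round trip."""
    with torch.no_grad():
        for qt, ref in zip(grouped_weights.quantized_tensors, reference_weights):
            qt.copy_(ref)
            roundtrip = qt.dequantize().to(ref.dtype)
            # Quantization is lossy, so tolerances must be tuned per recipe
            # (FP8 vs MXFP8 vs NVFP4 vs block scaling).
            torch.testing.assert_close(roundtrip, ref, atol=atol, rtol=rtol)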

Comment on lines +382 to +386
# TODO(ksivaman): Do we need multiple quantizers?
# Current implementation assumes all tensors have different quantizer
# instances but effectively the same quantizer.
rowwise_usage = quantizers[0].rowwise_usage if not no_quantization else True
columnwise_usage = quantizers[0].columnwise_usage if not no_quantization else False

style: check that all quantizers in the group are compatible. the comment acknowledges uncertainty about whether multiple quantizers are needed, but the implementation assumes they're "effectively the same" - mixed quantization schemes could cause issues.
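
A lightweight guard of the kind suggested here could look like the following hypothetical sketch; only rowwise_usage and columnwise_usage are taken from the quoted code, and any further checks would depend on the quantizer type.

# Hypothetical compatibility guard; not code from this PR.
def assert_quantizers_compatible(quantizers):
    """Fail fast if the group mixes quantizers with different behavior."""
    if not quantizers:
        return
    first = quantizers[0]
    for i, q in enumerate(quantizers[1:], start=1):
        assert type(q) is type(first), f"quantizer {i} differs in type from quantizer 0"
        assert q.rowwise_usage == first.rowwise_usage, f"quantizer {i}: rowwise_usage mismatch"
        assert q.columnwise_usage == first.columnwise_usage, f"quantizer {i}: columnwise_usage mismatch"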

Signed-off-by: Zhongbo Zhu <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Zhongbo Zhu <[email protected]>
@ksivaman ksivaman force-pushed the grouped_tensor_python branch from 2b7ea40 to 40c619e Compare January 16, 2026 07:36