Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
Greptile Summary

This PR implements comprehensive CPU overhead optimizations across the Python and C++ layers of TransformerEngine.

Python optimizations:
C++ optimizations:
Key improvements from previous review rounds:
Issue found:
Confidence Score: 4/5
Important Files Changed
Last reviewed commit: 73e4d1d
Greptile Overview
Greptile Summary
This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.
Key Changes:
- Caches `requires_grad`, `dtype`, `shape`, and `is_cuda` attribute accesses to avoid expensive PyObject lookups on custom tensors
- Reorders attribute checks in `get_tensor_device()` to prioritize internal quantized tensor attributes
- Makes `num_devices` static in `nvte_is_non_tn_fp8_gemm_supported()` to cache the device count
- Stores GEMM support check results in local variables to avoid redundant function calls
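The attribute-caching idea above can be sketched in plain Python. The `FakeTensor` class and `any_needs_grad` function here are hypothetical stand-ins, for illustration only; the real change lives in TransformerEngine's `linear.py`:

```python
class FakeTensor:
    """Hypothetical stand-in for a custom tensor subclass whose attribute
    lookups go through an expensive property/__getattr__ path."""

    def __init__(self, requires_grad):
        self._requires_grad = requires_grad
        self.lookup_count = 0  # how many times requires_grad was read

    @property
    def requires_grad(self):
        self.lookup_count += 1
        return self._requires_grad


def any_needs_grad(inp, weight, bias):
    # Read each attribute exactly once into a cheap local variable,
    # instead of repeating the PyObject lookup at every use site.
    inp_rg = inp.requires_grad
    weight_rg = weight.requires_grad
    bias_rg = bias is not None and bias.requires_grad
    # ...the cached locals would then be reused throughout forward/backward
    return inp_rg or weight_rg or bias_rg


inp, weight = FakeTensor(True), FakeTensor(False)
result = any_needs_grad(inp, weight, None)  # each attribute read only once
```

The same pattern applies to `dtype`, `shape`, and `is_cuda`: hoist the lookup into a local once, then reuse the local.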
Critical Issues Found:
- Variable redeclaration error in
cublaslt_gemm.cu(line 224) will prevent compilation - Logic bug in
linear.py(line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad
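The OR-versus-AND issue can be made concrete with a minimal sketch (function names are hypothetical): the original semantics update FP8 state when any input requires grad, while the flagged change requires all of them to.

```python
def fp8_update_needed_or(inp_rg, weight_rg, bias_rg):
    # Original semantics: any tensor requiring grad triggers the update.
    return inp_rg or weight_rg or bias_rg


def fp8_update_needed_and(inp_rg, weight_rg, bias_rg):
    # Flagged behavior: all three must require grad. With bias=None
    # (bias_rg False) or a frozen bias, the update is silently skipped.
    return inp_rg and weight_rg and bias_rg


# Typical case: trainable input and weight, no bias.
print(fp8_update_needed_or(True, True, False))   # True
print(fp8_update_needed_and(True, True, False))  # False — state update lost
```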
Confidence Score: 0/5
- This PR cannot be merged due to a compilation error and a critical logic bug
- Two critical issues prevent merging: (1) C++ compilation will fail due to variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
- Pay close attention to `transformer_engine/common/gemm/cublaslt_gemm.cu` (compilation error) and `transformer_engine/pytorch/module/linear.py` (logic bug)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/gemm/cublaslt_gemm.cu | 1/5 | Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure |
| transformer_engine/common/transformer_engine.cpp | 5/5 | Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization |
| transformer_engine/pytorch/module/linear.py | 0/5 | Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior |
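The `static num_devices` change in `transformer_engine.cpp` follows the function-local-static caching pattern. A rough Python analogue (names hypothetical, with `functools.lru_cache` standing in for the C++ `static` initializer) looks like this:

```python
import functools

query_calls = 0  # counts how often the underlying device query runs


def query_device_count():
    """Hypothetical stand-in for the expensive cuda::num_devices() query."""
    global query_calls
    query_calls += 1
    return 8  # pretend the machine has 8 devices


@functools.lru_cache(maxsize=None)
def num_devices():
    # Like `static int n = query();` in C++: the query runs once per
    # process, and every later call returns the cached value.
    return query_device_count()


num_devices()
num_devices()
print(query_calls)  # 1 — the query ran only once
```

This is safe as long as the device count cannot change during the process lifetime, which is the assumption the C++ `static` makes as well.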
Sequence Diagram
```mermaid
sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant Quantizer as Quantizer/QuantizedTensor
    participant GEMM as GEMM Operations
    participant CPP as C++ Extensions

    Note over Linear,CPP: Performance Optimization Flow
    User->>Linear: forward(input, weight, bias)
    Note over Linear: Cache requires_grad checks
    Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
    Linear->>Quantizer: Check if quantized tensor
    alt QuantizedTensor
        Note over Quantizer: Use cached dtype property
        Quantizer->>Quantizer: return self._dtype
        Note over Quantizer: Use cached shape/is_cuda
        Quantizer->>Quantizer: return self._data.shape
    else Regular Tensor
        Quantizer->>Linear: Standard attribute access
    end
    Linear->>CPP: get_tensor_device(tensor)
    Note over CPP: Reordered attribute checks
    CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
    CPP-->>Linear: device_index
    Linear->>GEMM: Configure GEMM parameters
    Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
    GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
    Note over CPP: Static num_devices cached
    CPP-->>GEMM: support_flag
    GEMM->>GEMM: Store in local variable
    GEMM->>GEMM: Execute optimized GEMM
    GEMM-->>Linear: output
    Note over Linear: FP8 State Management
    alt FP8 enabled and requires_grad check
        Linear->>Linear: Update FP8 tensors<br/>based on cached flags
    end
    Linear-->>User: output
```
Additional Comments (2)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
/te-ci pytorch
Additional Comments (1)
When copying from one tensor into another, the cached dtype must be refreshed as well; add after line 1019: `self._dtype = tensor.dtype`
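A minimal sketch of the cached-`dtype` pattern this comment refers to (class and method names are hypothetical; the real code is TransformerEngine's quantized tensor type):

```python
class QuantizedTensorSketch:
    """Hypothetical quantized tensor that caches dtype at construction,
    so the dtype property is a plain attribute read rather than a
    dispatch into the wrapped data tensor."""

    def __init__(self, data, dtype):
        self._data = data
        self._dtype = dtype

    @property
    def dtype(self):
        # Cheap cached read instead of deriving dtype from wrapped data.
        return self._dtype

    def copy_(self, other):
        self._data = other._data
        # The review's point: the cached field must be refreshed on copy,
        # otherwise it goes stale when the source dtype differs.
        self._dtype = other.dtype


a = QuantizedTensorSketch(b"payload-a", "float8_e4m3")
b = QuantizedTensorSketch(b"payload-b", "bfloat16")
a.copy_(b)  # without the refresh, a.dtype would still report float8_e4m3
```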
Description
CPU overhead optimizations
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Python Optimizations
C++ Optimizations
Checklist: