|
| 1 | +# This file was modified for portability to AMDGPU |
| 2 | +# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. |
1 | 3 | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 4 | # |
3 | 5 | # See LICENSE for license information. |
@@ -282,8 +284,8 @@ def forward( |
282 | 284 | inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) |
283 | 285 | saved_inputmat = inputmat |
284 | 286 |
|
285 | | - # Weight with column-wise usage is needed for dgrad GEMM. |
286 | | - if inp.requires_grad: |
| 287 | + # Weight with column-wise usage is needed for the dgrad GEMM when the FP8 weight transpose cache is kept.
| 288 | + if inp.requires_grad and keep_fp8_weight_transpose_cache: |
287 | 289 | if isinstance(weightmat, QuantizedTensor): |
288 | 290 | weightmat.update_usage(columnwise_usage=True) |
289 | 291 |
|
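The hunk above only prepares the column-wise (transposed) FP8 weight in the forward pass when the transpose cache is kept; otherwise the backward pass has to recreate it. Below is a minimal, self-contained sketch of that trade-off; `ToyQuantizedWeight` and its methods are hypothetical stand-ins that only mirror the control flow of the hunk, not Transformer Engine's actual `QuantizedTensor` API.

```python
# Toy sketch of the cache-vs-recompute trade-off controlled by
# keep_fp8_weight_transpose_cache. All names are hypothetical stand-ins.
class ToyQuantizedWeight:
    def __init__(self):
        self.columnwise_data = None  # transposed (column-wise) copy, if cached

    def update_usage(self, columnwise_usage: bool) -> None:
        # Materialize the transposed copy on demand; it costs memory while it lives.
        if columnwise_usage and self.columnwise_data is None:
            self.columnwise_data = "transposed-fp8-weight"


def forward(weight: ToyQuantizedWeight, inp_requires_grad: bool,
            keep_fp8_weight_transpose_cache: bool) -> None:
    if inp_requires_grad and keep_fp8_weight_transpose_cache:
        # Cache the transpose now so backward can reuse it (faster, more memory).
        weight.update_usage(columnwise_usage=True)


def backward(weight: ToyQuantizedWeight) -> None:
    # If the transpose was not cached in forward, it is recreated here
    # (extra compute, but no persistent buffer held between iterations).
    weight.update_usage(columnwise_usage=True)
```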
@@ -828,10 +830,20 @@ class Linear(TransformerEngineBaseModule): |
828 | 830 | it controls the type used to allocate the initial parameters. Useful when |
829 | 831 | the model is trained with lower precision and the original FP32 parameters |
830 | 832 | would not fit in GPU memory. |
831 | | - keep_fp8_weight_transpose_cache: bool, default = 'True' |
832 | | - if set to `False`, it will not cache fp8 weight buffer instead of |
833 | | - recomputing fp8 weight transpose. Recommend set to `False` when |
834 | | - enable FSDP parallel. |
| 833 | + keep_fp8_weight_transpose_cache: bool, default = True |
| 834 | + Controls whether to cache the FP8 weight transpose buffer during training. |
| 835 | +
|
| 836 | + - If set to `True` (default), the FP8 weight transpose buffer is cached to avoid recomputation, |
| 837 | + which can improve performance but significantly increases memory usage. |
| 838 | + - If set to `False`, the buffer is not cached and the FP8 weight transpose is recomputed as needed.
| 839 | + This reduces memory consumption during both checkpoint loading and runtime.
| 840 | +
|
| 841 | + **Recommendation**: Set this to `False` when using Fully Sharded Data Parallel (FSDP) training. |
| 842 | + Caching FP8 weight transposes can double memory usage for modules such as `Linear`, |
| 843 | + `LayerNormLinear`, and `LayerNormMLP`, which may lead to excessive memory pressure and |
| 844 | + reduced efficiency of PyTorch's caching allocator. |
| 845 | +
|
| 846 | + Use this setting to balance memory usage and performance based on your training configuration. |
835 | 847 |
|
836 | 848 | """ |
837 | 849 |
|
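For reference, a hedged usage sketch of the documented flag: assuming the argument lands on `transformer_engine.pytorch.Linear` exactly as described above, an FSDP-oriented configuration would disable the transpose cache. The shapes and the default FP8 recipe used here are illustrative only.

```python
# Illustrative usage of keep_fp8_weight_transpose_cache (shapes are arbitrary).
import torch
import transformer_engine.pytorch as te

# Disable the FP8 weight-transpose cache, as recommended for FSDP training:
# the transpose is recomputed in backward instead of being kept resident.
linear = te.Linear(4096, 4096, bias=True,
                   keep_fp8_weight_transpose_cache=False).cuda()

inp = torch.randn(16, 4096, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True):
    out = linear(inp)
out.sum().backward()
```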