Commit 867687a

Authored by Sudharshan Govindan (sudhu2k) and co-authors

Release v2.2 cherry-picks and bugfixes for Megatron-LM (#362)
* Ensure weight transpose is valid for FP8 training (#1596) (#276)
* Update usage of weightmat before saving for backward
* Added keep_fp8_weight_transpose_cache checks while updating transpose in fwd pass (#298)
* Added keep_fp8_weight_transpose_cache checks while updating transpose
* Added unit test for the fix
* Added comment for the unit test
* Fixed comment
* Reverted test to single iteration; added assert statements to check for the transpose cache; modified docstring
* Fixed test_numerics spacing
* Added HIP guards
* Addressed PR comments and moved assertion statements under the FP8 check
* Reverted assertion to fix the dev ticket
* Removed spacing
* Bug fix for get_fp8_metas
* Added keep_fp8_transpose_cache fix for base.py
* Added _fp8_metas check for None
* Added comment

Co-authored-by: Sudharshan Govindan <sugovind@amd.com>

Parent: bb087d0 · Commit: 867687a

File tree

6 files changed (+73, -32 lines)

tests/pytorch/test_numerics.py

Lines changed: 14 additions & 9 deletions
@@ -1274,13 +1274,21 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
-def test_fp8_linear_without_transpose_cache_accuracy(dtype, bs, model, fp8_model_params):
+@pytest.mark.parametrize("module_str", ["linear", "layernorm_mlp", "layernorm_linear"])
+def test_fp8_linear_without_transpose_cache_accuracy(dtype, bs, model, fp8_model_params, module_str):
     reset_rng_states()
     FP8GlobalStateManager.reset()
 
+    if module_str == "linear":
+        module = Linear
+    elif module_str == "layernorm_mlp":
+        module = LayerNormMLP
+    elif module_str == "layernorm_linear":
+        module = LayerNormLinear
+
     config = model_configs[model]
     with fp8_model_init(enabled=fp8_model_params):
-        linear = Linear(
+        layer = module(
             config.hidden_size,
             4 * config.hidden_size,
             bias=True,
@@ -1289,20 +1297,17 @@ def test_fp8_linear_without_transpose_cache_accuracy(dtype, bs, model, fp8_model
             keep_fp8_weight_transpose_cache=False
         ).eval()
 
-    ref_linear = Linear(
+    reset_rng_states()
+    ref_layer = module(
         config.hidden_size,
         4 * config.hidden_size,
        bias=True,
        params_dtype=dtype,
        device="cuda",
    ).eval()
 
-    # Share params
-    with torch.no_grad():
-        ref_linear.weight = Parameter(linear.weight.clone())
-        ref_linear.bias = Parameter(linear.bias.clone())
-    outputs = _test_granular_accuracy_with_fp8(linear, bs, dtype, config)
-    ref_outputs = _test_granular_accuracy_with_fp8(ref_linear, bs, dtype, config)
+    outputs = _test_granular_accuracy_with_fp8(layer, bs, dtype, config)
+    ref_outputs = _test_granular_accuracy_with_fp8(ref_layer, bs, dtype, config)
 
     # Check output.
     for te_output, torch_output in zip(outputs, ref_outputs):
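
For local verification outside the test harness, the same construction pattern can be exercised directly. The sketch below is illustrative only: it assumes a CUDA device with FP8 support and the public transformer_engine.pytorch entry points (Linear, fp8_model_init, fp8_autocast); the shapes and dtype are arbitrary and not taken from the test configs.

import torch
import transformer_engine.pytorch as te

hidden = 256  # arbitrary size for illustration

# Layer that recomputes the FP8 weight transpose instead of caching it.
with te.fp8_model_init(enabled=True):
    layer = te.Linear(
        hidden,
        4 * hidden,
        bias=True,
        params_dtype=torch.bfloat16,
        device="cuda",
        keep_fp8_weight_transpose_cache=False,
    )

x = torch.randn(32, hidden, dtype=torch.bfloat16, device="cuda")
with te.fp8_autocast(enabled=True):
    out = layer(x)
out.sum().backward()  # numerics should match a layer built with the default cache setting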

transformer_engine/pytorch/module/base.py

Lines changed: 1 addition & 1 deletion
@@ -1009,7 +1009,7 @@ def get_weight_workspace(
         if update_workspace and quantizer is not None:
             tensor.update_usage(
                 rowwise_usage=quantizer.rowwise_usage,
-                columnwise_usage=quantizer.columnwise_usage,
+                columnwise_usage=quantizer.columnwise_usage and create_transpose_cache,
             )
         return tensor
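
The intent of this one-line change is that the refreshed weight workspace only carries column-wise (transpose) data when the transpose cache is actually being kept. A standalone sketch of that gating, using a stand-in class rather than the real QuantizedTensor (all names here are hypothetical):

class FakeQuantizedTensor:
    """Stand-in for a quantized weight workspace (illustrative only)."""

    def __init__(self):
        self.rowwise_data = True
        self.columnwise_data = True

    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
        # Dropping column-wise usage is what frees the transpose buffer.
        self.rowwise_data = rowwise_usage
        self.columnwise_data = columnwise_usage


def refresh_workspace(tensor, quantizer_columnwise_usage, create_transpose_cache):
    # Mirror of the patched logic: keep the transpose only when the cache is wanted.
    tensor.update_usage(
        rowwise_usage=True,
        columnwise_usage=quantizer_columnwise_usage and create_transpose_cache,
    )
    return tensor


ws = refresh_workspace(FakeQuantizedTensor(), True, create_transpose_cache=False)
assert ws.columnwise_data is False  # transpose buffer not retained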

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 18 additions & 7 deletions
@@ -339,9 +339,10 @@ def forward(
         if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
             ln_out.update_usage(rowwise_usage=False)
 
-        # Weight with column-wise usage is needed for dgrad GEMM.
-        if isinstance(weightmat, QuantizedTensor):
-            weightmat.update_usage(columnwise_usage=True)
+        # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache.
+        if inp.requires_grad and keep_fp8_weight_transpose_cache:
+            if isinstance(weightmat, QuantizedTensor):
+                weightmat.update_usage(columnwise_usage=True)
 
         if cpu_offloading:
             if fp8 and weightmat is not None:
@@ -975,10 +976,20 @@ class LayerNormLinear(TransformerEngineBaseModule):
                           it controls the type used to allocate the initial parameters. Useful when
                           the model is trained with lower precision and the original FP32 parameters
                           would not fit in GPU memory.
-    keep_fp8_weight_transpose_cache: bool, default = 'True'
-                          if set to `False`, it will not cache fp8 weight buffer instead of
-                          recomputing fp8 weight transpose. Recommend set to `False` when
-                          enable FSDP parallel.
+    keep_fp8_weight_transpose_cache: bool, default = True
+                          Controls whether to cache the FP8 weight transpose buffer during training.
+
+                          - If set to `True` (default), the FP8 weight transpose buffer is cached to avoid recomputation,
+                            which can improve performance but significantly increases memory usage.
+                          - If set to `False`, the buffer is not cached and the FP8 weight transpose is recomputed as needed.
+                            This reduces memory consumption, especially during checkpoint loading and runtime.
+
+                          **Recommendation**: Set this to `False` when using Fully Sharded Data Parallel (FSDP) training.
+                          Caching FP8 weight transposes can double memory usage for modules such as `Linear`,
+                          `LayerNormLinear`, and `LayerNormMLP`, which may lead to excessive memory pressure and
+                          reduced efficiency of PyTorch's caching allocator.
+
+                          Use this setting to balance memory usage and performance based on your training configuration.
 
    """
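
As a usage reference for the option documented above, here is a minimal sketch of opting out of the transpose cache on this module. It assumes a CUDA device with FP8 support and the standard transformer_engine.pytorch import path; the sizes are illustrative.

import torch
import transformer_engine.pytorch as te

hidden = 1024  # illustrative
layer = te.LayerNormLinear(
    hidden,
    3 * hidden,
    params_dtype=torch.bfloat16,
    device="cuda",
    keep_fp8_weight_transpose_cache=False,  # recompute the transpose, save memory
)

x = torch.randn(8, hidden, dtype=torch.bfloat16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True):
    y = layer(x)
y.sum().backward()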

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 16 additions & 6 deletions
@@ -426,8 +426,8 @@ def forward(
             extra_output=rs_out,
         )
 
-        # Weight with column-wise usage is needed for dgrad GEMM.
-        if is_grad_enabled:
+        # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache.
+        if is_grad_enabled and inp.requires_grad and keep_fp8_weight_transpose_cache:
             if isinstance(fc1_weight_final, QuantizedTensor):
                 fc1_weight_final.update_usage(columnwise_usage=True)
             if isinstance(fc2_weight_final, QuantizedTensor):
@@ -1219,10 +1219,20 @@ class LayerNormMLP(TransformerEngineBaseModule):
                           batch size per training step. Needed for JIT Warmup, a technique where jit
                           fused functions are warmed up before training to ensure same kernels are
                           used for forward propogation and activation recompute phase.
-    keep_fp8_weight_transpose_cache: bool, default = 'True'
-                          if set to `False`, it will not cache fp8 weight buffer instead of
-                          recomputing fp8 weight transpose. Recommend set to `False` when
-                          enable FSDP parallel.
+    keep_fp8_weight_transpose_cache: bool, default = True
+                          Controls whether to cache the FP8 weight transpose buffer during training.
+
+                          - If set to `True` (default), the FP8 weight transpose buffer is cached to avoid recomputation,
+                            which can improve performance but significantly increases memory usage.
+                          - If set to `False`, the buffer is not cached and the FP8 weight transpose is recomputed as needed.
+                            This reduces memory consumption, especially during checkpoint loading and runtime.
+
+                          **Recommendation**: Set this to `False` when using Fully Sharded Data Parallel (FSDP) training.
+                          Caching FP8 weight transposes can double memory usage for modules such as `Linear`,
+                          `LayerNormLinear`, and `LayerNormMLP`, which may lead to excessive memory pressure and
+                          reduced efficiency of PyTorch's caching allocator.
+
+                          Use this setting to balance memory usage and performance based on your training configuration.
 
    """

transformer_engine/pytorch/module/linear.py

Lines changed: 18 additions & 6 deletions
@@ -1,3 +1,5 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -282,8 +284,8 @@ def forward(
             inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
             saved_inputmat = inputmat
 
-        # Weight with column-wise usage is needed for dgrad GEMM.
-        if inp.requires_grad:
+        # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache.
+        if inp.requires_grad and keep_fp8_weight_transpose_cache:
             if isinstance(weightmat, QuantizedTensor):
                 weightmat.update_usage(columnwise_usage=True)
 
@@ -828,10 +830,20 @@ class Linear(TransformerEngineBaseModule):
                           it controls the type used to allocate the initial parameters. Useful when
                           the model is trained with lower precision and the original FP32 parameters
                           would not fit in GPU memory.
-    keep_fp8_weight_transpose_cache: bool, default = 'True'
-                          if set to `False`, it will not cache fp8 weight buffer instead of
-                          recomputing fp8 weight transpose. Recommend set to `False` when
-                          enable FSDP parallel.
+    keep_fp8_weight_transpose_cache: bool, default = True
+                          Controls whether to cache the FP8 weight transpose buffer during training.
+
+                          - If set to `True` (default), the FP8 weight transpose buffer is cached to avoid recomputation,
+                            which can improve performance but significantly increases memory usage.
+                          - If set to `False`, the buffer is not cached and the FP8 weight transpose is recomputed as needed.
+                            This reduces memory consumption, especially during checkpoint loading and runtime.
+
+                          **Recommendation**: Set this to `False` when using Fully Sharded Data Parallel (FSDP) training.
+                          Caching FP8 weight transposes can double memory usage for modules such as `Linear`,
+                          `LayerNormLinear`, and `LayerNormMLP`, which may lead to excessive memory pressure and
+                          reduced efficiency of PyTorch's caching allocator.
+
+                          Use this setting to balance memory usage and performance based on your training configuration.
 
    """
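
To gauge the memory trade-off described in the docstring on a particular setup, one rough check is to run a single FP8 forward/backward with and without the cache and compare peak allocations. This is only a sketch: it assumes a CUDA device with FP8 support, and the absolute numbers depend on GPU, shapes, and allocator state.

import torch
import transformer_engine.pytorch as te


def peak_mem_bytes(keep_cache: bool) -> int:
    """One FP8 forward/backward; return peak allocated bytes (illustrative)."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    layer = te.Linear(
        4096,
        4096,
        params_dtype=torch.bfloat16,
        device="cuda",
        keep_fp8_weight_transpose_cache=keep_cache,
    )
    x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
    with te.fp8_autocast(enabled=True):
        out = layer(x)
    out.sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()


print("with cache   :", peak_mem_bytes(True))
print("without cache:", peak_mem_bytes(False))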

transformer_engine/pytorch/ops/op.py

Lines changed: 6 additions & 3 deletions
@@ -547,7 +547,10 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor:
             # Get state for a given FP8 tensor
             if self.num_quantizers(mode) == 0:
                 continue
-            fp8_meta = self.get_fp8_meta(mode)
+            # Skip if op has no quantizer state
+            if self._fp8_metas is None or self._fp8_metas.get(mode, None) is None:
+                continue
+            fp8_meta = self._fp8_metas.get(mode, None)
             state[mode] = {}
 
             # Store tensors
@@ -603,7 +606,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
                 continue
             if self.num_quantizers(mode) == 0:
                 continue
-            fp8_meta = self.get_fp8_meta(mode)
+            fp8_meta = self._fp8_metas.get(mode, None)
             if fp8_meta is None:
                 continue
 
@@ -617,7 +620,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
                 del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
 
             # Load tensors
-            fp8_meta = self.get_fp8_meta(mode)
+            fp8_meta = self._fp8_metas.get(mode, None)
             if "scaling_fwd" in fp8_meta:
                 fp8_meta_fwd = fp8_meta["scaling_fwd"]
                 copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale)
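
The bugfix pattern here is defensive access to per-op FP8 metadata that may never have been allocated (for example, an op that has quantizers configured but has not yet run under FP8). A minimal standalone sketch of the same guard, with hypothetical names standing in for the op's _fp8_metas attribute:

def capture_fp8_state(fp8_metas, modes=("forward", "backward")):
    """Collect FP8 metadata per mode, skipping modes with no allocated state.

    `fp8_metas` mirrors the op's `_fp8_metas` attribute: either None or a
    dict mapping mode name -> metadata dict (hypothetical stand-in).
    """
    state = {}
    for mode in modes:
        # Skip if the op has no quantizer state for this mode.
        if fp8_metas is None or fp8_metas.get(mode, None) is None:
            continue
        state[mode] = dict(fp8_metas[mode])
    return state


# No FP8 state allocated yet -> nothing captured, and no AttributeError/KeyError.
assert capture_fp8_state(None) == {}
assert capture_fp8_state({"forward": None}) == {}
assert capture_fp8_state({"forward": {"scale_fwd": 1.0}}) == {"forward": {"scale_fwd": 1.0}}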
