Commit 57895c6
Model fp8 initialization sync (#38)
* model fp8 initialization sync
* right imports

Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent ff1bb40 commit 57895c6

File tree: 1 file changed (+9, -13 lines)

src/megatron/hub/models/__init__.py

Lines changed: 9 additions & 13 deletions

@@ -20,7 +20,7 @@
     TorchFullyShardedDataParallel,
 )
 from megatron.core.enums import ModelType
-from megatron.core.fp8_utils import is_float8tensor
+from megatron.core.fp8_utils import correct_amax_history_if_needed
 from megatron.core.transformer.module import Float16Module, MegatronModule
 
 from megatron.hub.models.gpt import GPTConfig
@@ -178,18 +178,14 @@ def get_distributed_model(
     # Fp16 conversion.
     if model_config.fp16 or model_config.bf16:
         model = [Float16Module(model_config, model_module) for model_module in model]
-        # The model_module.bfloat16()/model_module.half() above will call the inplace copy of TE's
-        # Float8Tensor, which will write an unwanted value (amax calculated from the current fp8
-        # param) to its amax_history. The following logic will correct the amax_history back.
-        for model_module in model:
-            for param in model_module.parameters():
-                if is_float8tensor(param) and param._fp8_meta is not None:
-                    fp8_meta = param._fp8_meta["scaling_fwd"]
-                    fp8_meta_index = param._fp8_meta_index
-                    if hasattr(param, "get_high_precision_init_val"):
-                        fp8_meta.amax_history[0][fp8_meta_index].copy_(param.get_high_precision_init_val().abs().max())
-                    else:
-                        fp8_meta.amax_history[0][fp8_meta_index] = 0
+
+        # Before TE 2.x: the model_module.bfloat16()/model_module.half() above will call the
+        # in-place copy of TE's Float8Tensor, which will write an unwanted value (amax calculated
+        # from the current fp8 param) to its amax_history. The function below corrects the
+        # amax_history back.
+        # After TE 2.x: the function below is a no-op.
+        correct_amax_history_if_needed(model)
+
     if wrap_with_ddp:
         if use_torch_fsdp2:
             DP = TorchFullyShardedDataParallel
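For reference, here is a minimal sketch of what correct_amax_history_if_needed presumably does on pre-2.x Transformer Engine, reconstructed from the inline loop this commit removes. The underscore-prefixed attributes (param._fp8_meta, param._fp8_meta_index) are TE internals as used by the removed code, and the helper name _correct_amax_history is hypothetical; treat this as an illustration, not the library's actual implementation.

from megatron.core.fp8_utils import is_float8tensor


def _correct_amax_history(model):
    # Undo the amax that the in-place .half()/.bfloat16() copy wrote into
    # each Float8Tensor's amax_history during Float16Module conversion.
    for model_module in model:
        for param in model_module.parameters():
            if is_float8tensor(param) and param._fp8_meta is not None:
                fp8_meta = param._fp8_meta["scaling_fwd"]
                fp8_meta_index = param._fp8_meta_index
                if hasattr(param, "get_high_precision_init_val"):
                    # Restore the amax of the high-precision initial value.
                    fp8_meta.amax_history[0][fp8_meta_index].copy_(
                        param.get_high_precision_init_val().abs().max()
                    )
                else:
                    # No high-precision copy was kept; zero out the stale entry.
                    fp8_meta.amax_history[0][fp8_meta_index] = 0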
