@@ -20,7 +20,7 @@
     TorchFullyShardedDataParallel,
 )
 from megatron.core.enums import ModelType
-from megatron.core.fp8_utils import is_float8tensor
+from megatron.core.fp8_utils import correct_amax_history_if_needed
 from megatron.core.transformer.module import Float16Module, MegatronModule
 
 from megatron.hub.models.gpt import GPTConfig
@@ -178,18 +178,14 @@ def get_distributed_model(
     # Fp16 conversion.
     if model_config.fp16 or model_config.bf16:
         model = [Float16Module(model_config, model_module) for model_module in model]
-        # The model_module.bfloat16()/model_module.half() above will call the inplace copy of TE's
-        # Float8Tensor, which will write an unwanted value (amax calculated from the current fp8
-        # param) to its amax_history. The following logic will correct the amax_history back.
-        for model_module in model:
-            for param in model_module.parameters():
-                if is_float8tensor(param) and param._fp8_meta is not None:
-                    fp8_meta = param._fp8_meta["scaling_fwd"]
-                    fp8_meta_index = param._fp8_meta_index
-                    if hasattr(param, "get_high_precision_init_val"):
-                        fp8_meta.amax_history[0][fp8_meta_index].copy_(param.get_high_precision_init_val().abs().max())
-                    else:
-                        fp8_meta.amax_history[0][fp8_meta_index] = 0
+
+        # Before TE 2.x: the model_module.bfloat16()/model_module.half() calls above trigger the
+        # in-place copy of TE's Float8Tensor, which writes an unwanted value (the amax computed
+        # from the current fp8 param) into its amax_history. The function below corrects the
+        # amax_history back.
+        # With TE 2.x and later, the function is a no-op.
+        correct_amax_history_if_needed(model)
+
     if wrap_with_ddp:
         if use_torch_fsdp2:
             DP = TorchFullyShardedDataParallel
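
For reference, here is a minimal sketch of what correct_amax_history_if_needed is expected to do, reconstructed from the inline loop removed above. The sketch's function name and signature, and the continued availability of is_float8tensor in megatron.core.fp8_utils, are assumptions; the actual helper may be implemented differently, and per the comment above it simply does nothing on TE 2.x and later.

from typing import List

import torch

# Assumed still exported for pre-TE 2.x builds; the removed code imported it from here.
from megatron.core.fp8_utils import is_float8tensor


def correct_amax_history_sketch(model: List[torch.nn.Module]) -> None:
    # Undo the amax_history entry that was clobbered when .bfloat16()/.half() performed an
    # in-place copy of a TE Float8Tensor parameter.
    for model_module in model:
        for param in model_module.parameters():
            if is_float8tensor(param) and param._fp8_meta is not None:
                fp8_meta = param._fp8_meta["scaling_fwd"]
                fp8_meta_index = param._fp8_meta_index
                if hasattr(param, "get_high_precision_init_val"):
                    # Recompute the amax from the saved high-precision initial value.
                    fp8_meta.amax_history[0][fp8_meta_index].copy_(
                        param.get_high_precision_init_val().abs().max()
                    )
                else:
                    # No high-precision copy is kept; reset the stale entry instead.
                    fp8_meta.amax_history[0][fp8_meta_index] = 0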