File tree Expand file tree Collapse file tree 5 files changed +19
-7
lines changed
src/megatron/bridge/training
tests/unit_tests/training Expand file tree Collapse file tree 5 files changed +19
-7
lines changed Original file line number Diff line number Diff line change @@ -441,6 +441,7 @@ config.rerun_state_machine = RerunStateMachineConfig(
441441 rerun_mode = " validate_results" , # or "report_determinism_stats" or "disabled"
442442 check_for_nan_in_loss = True ,
443443 check_for_spiky_loss = False ,
444+ spiky_loss_factor = 10.0 , # Adjust for your model architecture
444445 error_injection_rate = 0 , # For testing only
445446 error_injection_type = " transient_error" ,
446447)
@@ -453,6 +454,7 @@ config.rerun_state_machine = RerunStateMachineConfig(
453454| ` rerun_mode ` | ` str ` | ` "disabled" ` | Operating mode: ` "disabled" ` , ` "validate_results" ` , or ` "report_determinism_stats" ` |
454455| ` check_for_nan_in_loss ` | ` bool ` | ` True ` | Check for NaN values in loss |
455456| ` check_for_spiky_loss ` | ` bool ` | ` False ` | Check for unexpectedly large loss values |
457+ | ` spiky_loss_factor ` | ` float ` | ` 10.0 ` | Factor for spiky loss detection. Loss is flagged if it exceeds this multiple of max observed loss. Larger models may need higher values (e.g., 15-20 for 70B+). |
456458| ` error_injection_rate ` | ` int ` | ` 0 ` | Rate for injecting test errors (testing only) |
457459| ` error_injection_type ` | ` str ` | ` "transient_error" ` | Type of error to inject for testing |
458460
Original file line number Diff line number Diff line change @@ -223,6 +223,10 @@ class RerunStateMachineConfig:
223223 check_for_spiky_loss : bool = False
224224 """Check for spiky loss."""
225225
226+ spiky_loss_factor : float = 10.0
227+ """Factor for detecting spiky loss. A loss is considered spiky if it exceeds
228+ this multiple of the max observed loss over the sample window."""
229+
226230
227231@dataclass (kw_only = True )
228232class DataloaderConfig :
Original file line number Diff line number Diff line change @@ -222,6 +222,7 @@ def init_rerun_state(rerun_state_machine_config: RerunStateMachineConfig) -> Non
222222 RerunDiagnostic ,
223223 RerunErrorInjector ,
224224 RerunMode ,
225+ get_rerun_state_machine ,
225226 initialize_rerun_state_machine ,
226227 )
227228
@@ -242,6 +243,10 @@ def state_restore_func(state_dict):
242243 ),
243244 )
244245
246+ # Store config on the singleton for use in loss validation
247+ rsm = get_rerun_state_machine ()
248+ rsm .spiky_loss_factor = rerun_state_machine_config .spiky_loss_factor
249+
245250
246251def set_jit_fusion_options (model_config : GPTModelProvider | T5ModelProvider , micro_batch_size : int ) -> None :
247252 """Set PyTorch JIT layer fusion options and warmup JIT functions.
Original file line number Diff line number Diff line change 1919from megatron .core .rerun_state_machine import get_rerun_state_machine
2020
2121
22- SPIKY_LOSS_FACTOR : int = 10
22+ _DEFAULT_SPIKY_LOSS_FACTOR : float = 10.0
2323
2424
2525def create_masked_next_token_loss_function (
@@ -86,11 +86,12 @@ def masked_next_token_loss(
8686 )
8787 # Check for spiky loss
8888 if check_for_spiky_loss :
89+ spiky_loss_factor = getattr (rerun_state_machine , "spiky_loss_factor" , _DEFAULT_SPIKY_LOSS_FACTOR )
8990 rerun_state_machine .validate_result (
9091 result = loss ,
9192 rejection_func = partial (
9293 rerun_state_machine .is_unexpectedly_large ,
93- threshold = SPIKY_LOSS_FACTOR ,
94+ threshold = spiky_loss_factor ,
9495 context = "loss" ,
9596 ),
9697 message = "Spiky loss" ,
Original file line number Diff line number Diff line change 2020import torch
2121
2222from megatron .bridge .training .losses import (
23- SPIKY_LOSS_FACTOR ,
23+ _DEFAULT_SPIKY_LOSS_FACTOR ,
2424 create_masked_next_token_loss_function ,
2525 masked_next_token_loss ,
2626)
@@ -312,7 +312,7 @@ def test_partial_function_execution(self):
312312class TestConstants :
313313 """Test module constants."""
314314
315- def test_spiky_loss_factor (self ):
316- """Test that SPIKY_LOSS_FACTOR has expected value."""
317- assert SPIKY_LOSS_FACTOR == 10 , "SPIKY_LOSS_FACTOR should be 10"
318- assert isinstance (SPIKY_LOSS_FACTOR , int ), "SPIKY_LOSS_FACTOR should be an integer "
315+ def test_default_spiky_loss_factor (self ):
316+ """Test that _DEFAULT_SPIKY_LOSS_FACTOR has expected value."""
317+ assert _DEFAULT_SPIKY_LOSS_FACTOR == 10.0 , "_DEFAULT_SPIKY_LOSS_FACTOR should be 10.0 "
318+ assert isinstance (_DEFAULT_SPIKY_LOSS_FACTOR , float ), "_DEFAULT_SPIKY_LOSS_FACTOR should be a float "
You can’t perform that action at this time.
0 commit comments