Skip to content

Commit 014391b

Browse files
committed
support configurable spiky loss threshold on re-run state machine
Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent 7a50d2e commit 014391b

File tree

5 files changed

+19
-7
lines changed

5 files changed

+19
-7
lines changed

docs/training/resiliency.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ config.rerun_state_machine = RerunStateMachineConfig(
441441
rerun_mode="validate_results", # or "report_determinism_stats" or "disabled"
442442
check_for_nan_in_loss=True,
443443
check_for_spiky_loss=False,
444+
spiky_loss_factor=10.0, # Adjust for your model architecture
444445
error_injection_rate=0, # For testing only
445446
error_injection_type="transient_error",
446447
)
@@ -453,6 +454,7 @@ config.rerun_state_machine = RerunStateMachineConfig(
453454
| `rerun_mode` | `str` | `"disabled"` | Operating mode: `"disabled"`, `"validate_results"`, or `"report_determinism_stats"` |
454455
| `check_for_nan_in_loss` | `bool` | `True` | Check for NaN values in loss |
455456
| `check_for_spiky_loss` | `bool` | `False` | Check for unexpectedly large loss values |
457+
| `spiky_loss_factor` | `float` | `10.0` | Factor for spiky loss detection. Loss is flagged if it exceeds this multiple of max observed loss. Larger models may need higher values (e.g., 15-20 for 70B+). |
456458
| `error_injection_rate` | `int` | `0` | Rate for injecting test errors (testing only) |
457459
| `error_injection_type` | `str` | `"transient_error"` | Type of error to inject for testing |
458460

src/megatron/bridge/training/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ class RerunStateMachineConfig:
223223
check_for_spiky_loss: bool = False
224224
"""Check for spiky loss."""
225225

226+
spiky_loss_factor: float = 10.0
227+
"""Factor for detecting spiky loss. A loss is considered spiky if it exceeds
228+
this multiple of the max observed loss over the sample window."""
229+
226230

227231
@dataclass(kw_only=True)
228232
class DataloaderConfig:

src/megatron/bridge/training/initialize.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def init_rerun_state(rerun_state_machine_config: RerunStateMachineConfig) -> Non
222222
RerunDiagnostic,
223223
RerunErrorInjector,
224224
RerunMode,
225+
get_rerun_state_machine,
225226
initialize_rerun_state_machine,
226227
)
227228

@@ -242,6 +243,10 @@ def state_restore_func(state_dict):
242243
),
243244
)
244245

246+
# Store config on the singleton for use in loss validation
247+
rsm = get_rerun_state_machine()
248+
rsm.spiky_loss_factor = rerun_state_machine_config.spiky_loss_factor
249+
245250

246251
def set_jit_fusion_options(model_config: GPTModelProvider | T5ModelProvider, micro_batch_size: int) -> None:
247252
"""Set PyTorch JIT layer fusion options and warmup JIT functions.

src/megatron/bridge/training/losses.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from megatron.core.rerun_state_machine import get_rerun_state_machine
2020

2121

22-
SPIKY_LOSS_FACTOR: int = 10
22+
_DEFAULT_SPIKY_LOSS_FACTOR: float = 10.0
2323

2424

2525
def create_masked_next_token_loss_function(
@@ -86,11 +86,12 @@ def masked_next_token_loss(
8686
)
8787
# Check for spiky loss
8888
if check_for_spiky_loss:
89+
spiky_loss_factor = getattr(rerun_state_machine, "spiky_loss_factor", _DEFAULT_SPIKY_LOSS_FACTOR)
8990
rerun_state_machine.validate_result(
9091
result=loss,
9192
rejection_func=partial(
9293
rerun_state_machine.is_unexpectedly_large,
93-
threshold=SPIKY_LOSS_FACTOR,
94+
threshold=spiky_loss_factor,
9495
context="loss",
9596
),
9697
message="Spiky loss",

tests/unit_tests/training/test_losses.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch
2121

2222
from megatron.bridge.training.losses import (
23-
SPIKY_LOSS_FACTOR,
23+
_DEFAULT_SPIKY_LOSS_FACTOR,
2424
create_masked_next_token_loss_function,
2525
masked_next_token_loss,
2626
)
@@ -312,7 +312,7 @@ def test_partial_function_execution(self):
312312
class TestConstants:
313313
"""Test module constants."""
314314

315-
def test_spiky_loss_factor(self):
316-
"""Test that SPIKY_LOSS_FACTOR has expected value."""
317-
assert SPIKY_LOSS_FACTOR == 10, "SPIKY_LOSS_FACTOR should be 10"
318-
assert isinstance(SPIKY_LOSS_FACTOR, int), "SPIKY_LOSS_FACTOR should be an integer"
315+
def test_default_spiky_loss_factor(self):
316+
"""Test that _DEFAULT_SPIKY_LOSS_FACTOR has expected value."""
317+
assert _DEFAULT_SPIKY_LOSS_FACTOR == 10.0, "_DEFAULT_SPIKY_LOSS_FACTOR should be 10.0"
318+
assert isinstance(_DEFAULT_SPIKY_LOSS_FACTOR, float), "_DEFAULT_SPIKY_LOSS_FACTOR should be a float"

0 commit comments

Comments
 (0)