Commit 7554afb
Fix mutable default in Megatron init and IndexError on empty ModuleList (#3944)
* Fix mutable default argument in Megatron init and IndexError on empty ModuleList

  In megatron_lm.py, initialize() used args_defaults={} as a default parameter; the same dict object is shared across calls and can cause unexpected state to persist between them. In other.py, is_repeated_blocks() accessed module[0] without checking whether the ModuleList was empty, raising IndexError for empty ModuleLists.

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 23f2ab3 commit 7554afb

File tree

2 files changed

+8
-2
lines changed


src/accelerate/utils/megatron_lm.py

Lines changed: 3 additions & 1 deletion
@@ -873,7 +873,9 @@ def finish_mpu_init():

 # initialize megatron setup
-def initialize(accelerator, extra_args_provider=None, args_defaults={}):
+def initialize(accelerator, extra_args_provider=None, args_defaults=None):
+    if args_defaults is None:
+        args_defaults = {}
     accelerator.print("Initializing Megatron-LM")
     assert torch.cuda.is_available(), "Megatron requires CUDA."

src/accelerate/utils/other.py

Lines changed: 5 additions & 1 deletion
@@ -82,7 +82,11 @@ def is_repeated_blocks(module: torch.nn.Module) -> bool:
     is useful to determine whether we should apply regional compilation to the module.
     """
-    return isinstance(module, torch.nn.ModuleList) and all(isinstance(m, module[0].__class__) for m in module)
+    return (
+        isinstance(module, torch.nn.ModuleList)
+        and len(module) > 0
+        and all(isinstance(m, module[0].__class__) for m in module)
+    )


 def has_repeated_blocks(module: torch.nn.Module) -> bool:
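The guard works because `and` short-circuits: when the length check fails, `module[0]` is never evaluated, so an empty ModuleList returns False instead of raising. A torch-free sketch of the same pattern (`all_same_class` is a hypothetical analogue of `is_repeated_blocks`, using a plain sequence in place of a ModuleList):

```python
def all_same_class(items) -> bool:
    # The len(items) > 0 guard short-circuits before items[0] is touched,
    # so an empty sequence returns False instead of raising IndexError.
    return len(items) > 0 and all(isinstance(x, items[0].__class__) for x in items)

# Without the guard, indexing an empty sequence raises:
# >>> [][0]  ->  IndexError: list index out of range
print(all_same_class([]))         # False (no crash)
print(all_same_class([1, 2, 3]))  # True  (all ints)
print(all_same_class([1, "a"]))   # False (mixed classes)
```

Ordering matters: the length check must come before the `all(...)` term, since both operands of the original one-liner touched `module[0]` via the generator.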
