
Commit 0c0ada6

Author: dmoi
Commit message: ddp training working
Parent: 1529042

File tree: 7 files changed (+2765, −4148 lines)


foldtree2/learn_monodecoder.py

Lines changed: 25 additions & 15 deletions
@@ -512,7 +512,30 @@ def has_modular_structure(model):
         dict(params=hidden_gains_biases+nonhidden_params, use_muon=False,
              lr=args.adamw_lr, betas=(0.9, 0.95), weight_decay=0.01),
     ]
-    optimizer = MuonWithAuxAdam(param_groups)
+
+    # Initialize process group for Muon optimizer (required even for single-GPU)
+    import torch.distributed as dist
+    if not dist.is_available() or not dist.is_initialized():
+        try:
+            import os as dist_os
+            dist_os.environ.setdefault('MASTER_ADDR', 'localhost')
+            dist_os.environ.setdefault('MASTER_PORT', '12355')
+            dist_os.environ.setdefault('RANK', '0')
+            dist_os.environ.setdefault('WORLD_SIZE', '1')
+            dist.init_process_group(backend='gloo', init_method='env://')
+            print("Initialized single-process group for Muon optimizer")
+            optimizer = MuonWithAuxAdam(param_groups)
+        except Exception as e:
+            print(f"Warning: Could not initialize process group for Muon: {e}")
+            print("Falling back to AdamW optimizer")
+            optimizer = torch.optim.AdamW(
+                list(encoder.parameters()) + list(decoder.parameters()),
+                lr=args.learning_rate,
+                weight_decay=0.000001
+            )
+
+    else:
+        optimizer = MuonWithAuxAdam(param_groups)
 else:
     print("Using AdamW optimizer")
     optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=args.learning_rate, weight_decay=0.000001)
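
For reference, the added branch is the standard torch.distributed pattern: create a one-process "gloo" group via env:// initialization so that an optimizer built on collective ops (MuonWithAuxAdam here) can run without a distributed launcher. Below is a minimal, self-contained sketch of that pattern; the helper name ensure_single_process_group is ours for illustration, not from the repository.

import os
import torch.distributed as dist

def ensure_single_process_group(port: str = '12355') -> None:
    """Create a 1-process gloo group if none exists (same pattern as the hunk above)."""
    if dist.is_available() and not dist.is_initialized():
        # env:// initialization reads these variables; torchrun sets them automatically in real DDP runs.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', port)
        os.environ.setdefault('RANK', '0')
        os.environ.setdefault('WORLD_SIZE', '1')
        dist.init_process_group(backend='gloo', init_method='env://')

if __name__ == '__main__':
    ensure_single_process_group()
    print(dist.get_rank(), dist.get_world_size())   # expected: 0 1
    dist.destroy_process_group()

For an actual multi-process DDP run, these environment variables would come from the launcher rather than hard-coded defaults, e.g. (arguments after the script path are placeholders, and GPUs would typically use backend='nccl'):

torchrun --standalone --nproc_per_node=2 foldtree2/learn_monodecoder.py <script args>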
@@ -528,20 +551,7 @@ def get_scheduler(optimizer, scheduler_type, num_warmup_steps, num_training_step
     elif scheduler_type == 'polynomial' and TRANSFORMERS_AVAILABLE:
         return get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, lr_end=0.0, power=1.0), 'step'
     elif scheduler_type == 'plateau':
-        # Initialize process group if not already initialized (required for ReduceLROnPlateau)
-        import torch.distributed as dist
-        if not dist.is_available() or not dist.is_initialized():
-            try:
-                # Initialize single-process group for non-distributed training
-                import os as dist_os
-                dist_os.environ.setdefault('MASTER_ADDR', 'localhost')
-                dist_os.environ.setdefault('MASTER_PORT', '12355')
-                dist_os.environ.setdefault('RANK', '0')
-                dist_os.environ.setdefault('WORLD_SIZE', '1')
-                dist.init_process_group(backend='gloo', init_method='env://')
-                print("Initialized process group for ReduceLROnPlateau scheduler")
-            except Exception as e:
-                print(f"Note: Could not initialize process group (running in single-process mode): {e}")
+        # ReduceLROnPlateau doesn't require distributed process groups - it only monitors loss values
         return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10), 'epoch'
     elif scheduler_type == 'none':
         return None, None
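
The replacement comment is accurate: ReduceLROnPlateau only compares the metric passed to step() against its best value and rescales the optimizer's learning rates locally, with no collective communication. A toy, self-contained sketch of how the scheduler returned with the 'epoch' cadence is stepped; the parameter and loss values here are made up for illustration and stand in for the script's encoder/decoder training loop.

import torch

# Toy parameter/optimizer; stands in for the encoder/decoder parameters in the script.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

for epoch in range(30):
    val_loss = 1.0                 # stand-in for a validation loss that has stopped improving
    scheduler.step(val_loss)       # pass the monitored metric; no process group involved

print(optimizer.param_groups[0]['lr'])   # lr reduced (x0.1) after each plateau longer than `patience` epochs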
