@@ -512,7 +512,30 @@ def has_modular_structure(model):
         dict(params=hidden_gains_biases + nonhidden_params, use_muon=False,
              lr=args.adamw_lr, betas=(0.9, 0.95), weight_decay=0.01),
     ]
-    optimizer = MuonWithAuxAdam(param_groups)
+
+    # Initialize process group for Muon optimizer (required even for single-GPU)
+    import torch.distributed as dist
+    if not dist.is_available() or not dist.is_initialized():
+        try:
+            import os as dist_os
+            dist_os.environ.setdefault('MASTER_ADDR', 'localhost')
+            dist_os.environ.setdefault('MASTER_PORT', '12355')
+            dist_os.environ.setdefault('RANK', '0')
+            dist_os.environ.setdefault('WORLD_SIZE', '1')
+            dist.init_process_group(backend='gloo', init_method='env://')
+            print("Initialized single-process group for Muon optimizer")
+            optimizer = MuonWithAuxAdam(param_groups)
+        except Exception as e:
+            print(f"Warning: Could not initialize process group for Muon: {e}")
+            print("Falling back to AdamW optimizer")
+            optimizer = torch.optim.AdamW(
+                list(encoder.parameters()) + list(decoder.parameters()),
+                lr=args.learning_rate,
+                weight_decay=0.000001
+            )
+
+    else:
+        optimizer = MuonWithAuxAdam(param_groups)
 else:
     print("Using AdamW optimizer")
     optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=args.learning_rate, weight_decay=0.000001)
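As the new comment notes, MuonWithAuxAdam needs an initialized process group even for single-GPU runs, so the hunk above bootstraps a one-rank gloo group before constructing it. The sketch below factors that bootstrap into a reusable helper; the helper name ensure_single_process_group is illustrative and not part of this repo, and it assumes the same env:// rendezvous and port 12355 as the diff:

    import os
    import torch.distributed as dist

    def ensure_single_process_group(port='12355'):
        """Best-effort init of a 1-rank gloo group; returns True on success."""
        if not dist.is_available():
            return False
        if dist.is_initialized():
            return True
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', port)
        os.environ.setdefault('RANK', '0')
        os.environ.setdefault('WORLD_SIZE', '1')
        try:
            # gloo also runs on CPU-only hosts; nccl would need a GPU per rank
            dist.init_process_group(backend='gloo', init_method='env://')
            return True
        except Exception:
            return False

With a helper like this, the try/except above reduces to one condition: build MuonWithAuxAdam when ensure_single_process_group() returns True, otherwise fall back to AdamW. Note that the hard-coded port can collide if two jobs share a host; making it configurable would be a small follow-up.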
@@ -528,20 +551,7 @@ def get_scheduler(optimizer, scheduler_type, num_warmup_steps, num_training_steps
     elif scheduler_type == 'polynomial' and TRANSFORMERS_AVAILABLE:
         return get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, lr_end=0.0, power=1.0), 'step'
     elif scheduler_type == 'plateau':
-        # Initialize process group if not already initialized (required for ReduceLROnPlateau)
-        import torch.distributed as dist
-        if not dist.is_available() or not dist.is_initialized():
-            try:
-                # Initialize single-process group for non-distributed training
-                import os as dist_os
-                dist_os.environ.setdefault('MASTER_ADDR', 'localhost')
-                dist_os.environ.setdefault('MASTER_PORT', '12355')
-                dist_os.environ.setdefault('RANK', '0')
-                dist_os.environ.setdefault('WORLD_SIZE', '1')
-                dist.init_process_group(backend='gloo', init_method='env://')
-                print("Initialized process group for ReduceLROnPlateau scheduler")
-            except Exception as e:
-                print(f"Note: Could not initialize process group (running in single-process mode): {e}")
+        # ReduceLROnPlateau doesn't require distributed process groups - it only monitors loss values
         return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10), 'epoch'
     elif scheduler_type == 'none':
         return None, None
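The scheduler-side copy of the bootstrap was safe to delete: torch.optim.lr_scheduler.ReduceLROnPlateau only reads the scalar metric handed to step() and makes no torch.distributed calls. A minimal self-contained check, where the toy model and constant loss are placeholders:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    # Same settings as the 'plateau' branch in get_scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=10)

    for epoch in range(3):
        val_loss = 1.0  # placeholder; in training this is the epoch's validation loss
        scheduler.step(val_loss)  # the scheduler only consumes this scalar
    print(optimizer.param_groups[0]['lr'])  # runs without init_process_group

The 'epoch' tag returned alongside the scheduler matches this usage: callers are expected to call step(val_loss) once per epoch, unlike the 'step' schedulers earlier in the function.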