def build_trainer(cfg):
    """Instantiate the trainer class named by ``cfg.TRAINER.NAME``.

    The name is validated against the registry via ``check_availability``
    (which raises when it is unknown), then the registered class is
    constructed with ``cfg``.
    """
    available = TRAINER_REGISTRY.registered_names()
    check_availability(cfg.TRAINER.NAME, available)
    return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)
3836
3937class TrainerBase :
@@ -98,21 +96,17 @@ def system_init(self):
9896 ## cuda setting
9997 if torch .cuda .is_available () and self .cfg .ENV .USE_CUDA :
10098 torch .backends .cudnn .benchmark = True
101- gpu_ids = self .cfg .ENV .GPU
102- if not gpu_ids :
103- raise ValueError ("ENV.GPU must contain at least one gpu id when USE_CUDA=True" )
104-
10599 if self .is_distributed :
106100 # In distributed mode, use local_rank to determine GPU
107- target_gpu = gpu_ids [self .local_rank % len (gpu_ids )]
101+ target_gpu = self . cfg . ENV . GPU [self .local_rank % len (self . cfg . ENV . GPU )]
108102 else :
109- target_gpu = gpu_ids [0 ]
110- if len (gpu_ids ) > 1 and torch .distributed .is_available ():
103+ target_gpu = self . cfg . ENV . GPU [0 ]
104+ if len (self . cfg . ENV . GPU ) > 1 and torch .distributed .is_available ():
111105 # assume torchrun/launch supplies LOCAL_RANK; fallback to rank % len(gpu_ids)
112106 local_rank = int (os .environ .get ("LOCAL_RANK" , 0 ))
113107 if torch .distributed .is_initialized ():
114- local_rank = torch .distributed .get_rank () % len (gpu_ids )
115- target_gpu = gpu_ids [local_rank % len (gpu_ids )]
108+ local_rank = torch .distributed .get_rank () % len (self . cfg . ENV . GPU )
109+ target_gpu = self . cfg . ENV . GPU [local_rank % len (self . cfg . ENV . GPU )]
116110
117111 self .device = torch .device (f"cuda:{ target_gpu } " )
118112 torch .cuda .set_device (self .device )
@@ -126,17 +120,11 @@ def _init_distributed(self):
126120 # Get local rank from environment variable (set by torchrun)
127121 self .local_rank = int (os .environ .get ('LOCAL_RANK' , - 1 ))
128122
129- if self .local_rank == - 1 :
130- print ("LOCAL_RANK not found in environment. Falling back to non-distributed mode." )
131- self .cfg .ENV .DISTRIBUTED = False
132- return
133-
134123 # Initialize process group
135124 dist .init_process_group (
136125 backend = self .cfg .ENV .DIST_BACKEND ,
137126 init_method = self .cfg .ENV .DIST_URL
138127 )
139-
140128 self .rank = dist .get_rank ()
141129 self .world_size = dist .get_world_size ()
142130 self .is_distributed = True
0 commit comments