@@ -1,8 +1,12 @@
 import torch
 import logging
 import sys
+import warnings
 from transformers import GPT2Config, GPT2LMHeadModel
 from data_utils import get_train_data_loader, get_eval_data_loader
+
+# Suppress the loss_type warning from transformers
+warnings.filterwarnings("ignore", message=".*loss_type.*unrecognized.*")
 from model_utils import (
     save_checkpoint,
     load_checkpoint,
@@ -17,8 +21,17 @@
 import torch.backends.cudnn as cudnn
 from experiment import Experiment
 import torch.multiprocessing as mp
-from tqdm import tqdm
 from constants import MODELS_DIR, AUTHORS, CLEANED_DATA_DIR
+import os
+
+# Disable tqdm if running in subprocess or if explicitly disabled
+USE_TQDM = os.environ.get('DISABLE_TQDM', '0') != '1' and sys.stdout.isatty()
+if USE_TQDM:
+    from tqdm import tqdm
+else:
+    # Simple replacement that just returns the iterable
+    def tqdm(iterable, *args, **kwargs):
+        return iterable
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
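A note on the tqdm shim introduced above: the fallback accepts and discards tqdm-style keyword arguments, so call sites such as tqdm(range(start_epoch, max_epochs)) behave identically whether or not the real library is loaded. A minimal standalone sketch of that behavior, assuming the non-TTY / DISABLE_TQDM=1 branch is taken (no names beyond those in the diff):

    # No-op stand-in used when USE_TQDM is False
    def tqdm(iterable, *args, **kwargs):
        return iterable

    # Existing call sites need no changes; desc/total are simply ignored
    for step in tqdm(range(3), desc="epochs", total=3):
        print(step)  # prints 0, 1, 2 with no progress bar drawn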
@@ -138,6 +151,16 @@ def run_experiment(exp: Experiment, gpu_queue):
         train_author=exp.train_author,
     )
 
+    # Set up mixed precision training for memory efficiency
+    scaler = torch.amp.GradScaler('cuda')
+
+    # Enable gradient checkpointing to save memory (if supported)
+    try:
+        model.gradient_checkpointing_enable()
+        logger.info(f"[GPU {gpu_id}] Gradient checkpointing enabled for memory efficiency")
+    except AttributeError:
+        logger.info(f"[GPU {gpu_id}] Model does not support gradient checkpointing")
+
     # Training loop
     for epoch in tqdm(range(start_epoch, max_epochs)):
         total_train_loss = 0.0
@@ -148,18 +171,27 @@ def run_experiment(exp: Experiment, gpu_queue):
 
             input_ids = batch["input_ids"].to(device)
 
-            # Forward pass - use input_ids as labels (HF handles shifting)
-            outputs = model(input_ids=input_ids, labels=input_ids)
-            loss = outputs.loss
+            # Forward pass with mixed precision
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                outputs = model(input_ids=input_ids, labels=input_ids)
+                loss = outputs.loss
 
-            # Backward pass and optimization step
-            loss.backward()
-            optimizer.step()
+            # Backward pass with scaled gradients
             optimizer.zero_grad()
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
 
             # Accumulate training loss
             total_train_loss += loss.item()
 
+            # Delete intermediate tensors to free memory
+            del outputs, loss
+
+            # Clear CUDA cache periodically
+            if (batch_idx + 1) % 5 == 0:
+                torch.cuda.empty_cache()
+
         epochs_completed = epoch + 1
 
         # Calculate average training loss
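The training-step rewrite above pairs torch.amp.autocast with a GradScaler so the forward pass runs in float16 while gradients are protected from underflow. A self-contained sketch of the same pattern, assuming a recent PyTorch (2.3 or newer, where torch.amp.GradScaler accepts a device string) and an available CUDA device; the linear model and random data are placeholders, not this script's GPT-2 setup:

    import torch

    device = "cuda"
    model = torch.nn.Linear(16, 1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = torch.amp.GradScaler("cuda")

    x = torch.randn(8, 16, device=device)
    y = torch.randn(8, 1, device=device)

    for _ in range(3):
        optimizer.zero_grad()
        # Forward pass and loss computed under autocast (float16 where safe)
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            loss = torch.nn.functional.mse_loss(model(x), y)
        # Scale the loss before backward, then step/update through the scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

scaler.step(optimizer) skips the optimizer update when the scaled gradients contain infs or NaNs, and scaler.update() then lowers the scale, which is why these three calls replace the plain loss.backward() / optimizer.step() pair removed in the diff.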
@@ -230,27 +262,46 @@ def run_experiment(exp: Experiment, gpu_queue):
 
 
 if __name__ == "__main__":
-    mp.set_start_method("spawn", force=True)
+    # Check if we should run sequentially (for subprocess compatibility)
+    USE_MULTIPROCESSING = os.environ.get('NO_MULTIPROCESSING', '0') != '1'
+
     device_count = torch.cuda.device_count()
     gpu_count = min(device_count, 4)
     print(f"Using {gpu_count} GPUs out of {device_count} available")
 
-    manager = mp.Manager()
-    gpu_queue = manager.Queue()
-    for gpu in range(gpu_count):
-        gpu_queue.put(gpu)
+    if USE_MULTIPROCESSING:
+        mp.set_start_method("spawn", force=True)
+        manager = mp.Manager()
+        gpu_queue = manager.Queue()
+        for gpu in range(gpu_count):
+            gpu_queue.put(gpu)
 
-    pool = mp.Pool(processes=gpu_count)
-    logger = logging.getLogger(__name__)
+        pool = mp.Pool(processes=gpu_count)
+        logger = logging.getLogger(__name__)
 
-    def error_callback(e):
-        logger.exception("Unhandled error in worker, shutting down all processes")
-        pool.terminate()
-        sys.exit(1)
+        def error_callback(e):
+            logger.exception("Unhandled error in worker, shutting down all processes")
+            pool.terminate()
+            sys.exit(1)
 
-    for exp in experiments:
-        pool.apply_async(
-            run_experiment, (exp, gpu_queue), error_callback=error_callback
-        )
-    pool.close()
-    pool.join()
+        for exp in experiments:
+            pool.apply_async(
+                run_experiment, (exp, gpu_queue), error_callback=error_callback
+            )
+        pool.close()
+        pool.join()
+    else:
+        # Sequential mode for subprocess compatibility
+        print("Running in sequential mode (multiprocessing disabled)")
+        import queue
+        gpu_queue = queue.Queue()
+        for gpu in range(gpu_count):
+            gpu_queue.put(gpu)
+
+        for i, exp in enumerate(experiments):
+            print(f"Training model {i + 1}/{len(experiments)}: {exp.name}")
+            run_experiment(exp, gpu_queue)
+            # Put GPU back in queue for next experiment
+            if not gpu_queue.empty():
+                gpu_id = gpu_queue.get()
+                gpu_queue.put(gpu_id)
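Both of the new switches are plain environment variables, so a parent process can force the sequential branch and suppress progress bars without touching the code. A hypothetical launcher follows; the script file name train.py is an assumption, since the diff does not show it:

    import os
    import subprocess

    # NO_MULTIPROCESSING=1 selects the sequential branch above;
    # DISABLE_TQDM=1 swaps in the no-op tqdm shim from the import block.
    env = dict(os.environ, NO_MULTIPROCESSING="1", DISABLE_TQDM="1")
    subprocess.run(["python", "train.py"], env=env, check=True)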