
Commit 542a37a

checking in changes from tensorbook
1 parent 1c91dc7 commit 542a37a

File tree

5 files changed: +315, -93 lines changed


code/eval_utils.py

Lines changed: 20 additions & 8 deletions
@@ -13,14 +13,26 @@ def evaluate_model(model, eval_dataloader, device):
     total_loss = 0.0

     with torch.no_grad():
-        for batch in eval_dataloader:
-            input_ids = batch["input_ids"].to(device)
-            attention_mask = batch["attention_mask"].to(device)
-            outputs = model(
-                input_ids=input_ids, attention_mask=attention_mask, labels=input_ids
-            )
-            loss = outputs.loss.item()
-            total_loss += loss
+        for batch_idx, batch in enumerate(eval_dataloader):
+            # Use mixed precision for evaluation too
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                input_ids = batch["input_ids"].to(device)
+                # Only use attention_mask if it exists
+                if "attention_mask" in batch:
+                    attention_mask = batch["attention_mask"].to(device)
+                    outputs = model(
+                        input_ids=input_ids, attention_mask=attention_mask, labels=input_ids
+                    )
+                else:
+                    outputs = model(input_ids=input_ids, labels=input_ids)
+
+            loss = outputs.loss.item()
+            total_loss += loss
+
+            # Clean up memory after each batch
+            del outputs
+            if batch_idx % 5 == 0:
+                torch.cuda.empty_cache()

     return total_loss / len(eval_dataloader)
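
For reference, a minimal way to exercise the updated evaluate_model on dummy data. The DummyDataset and toy GPT-2 config below are illustrative stand-ins for the repo's real dataloaders, and the batches carry only input_ids, so the no-attention_mask branch is taken:

import torch
from torch.utils.data import DataLoader, Dataset
from eval_utils import evaluate_model  # assumes code/ is on the path

class DummyDataset(Dataset):
    # Mimics the real tokenized batches: dicts with "input_ids" only.
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return {"input_ids": torch.randint(0, 50257, (32,))}

if __name__ == "__main__":
    from transformers import GPT2Config, GPT2LMHeadModel
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64)).to(device).eval()
    loader = DataLoader(DummyDataset(), batch_size=4)
    print("avg eval loss:", evaluate_model(model, loader, device))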

code/generate_figures.py

Lines changed: 35 additions & 6 deletions
@@ -35,6 +35,20 @@ def train_models():
         safe_print("Training cancelled.")
         return False

+    # Remove existing models directory to train from scratch
+    import shutil
+    models_dir = Path('models')
+    if models_dir.exists():
+        safe_print("\nRemoving existing models directory...")
+        shutil.rmtree(models_dir)
+        safe_print("Existing models removed.")
+
+    # Also remove existing model results file
+    model_results_path = Path('data/model_results.pkl')
+    if model_results_path.exists():
+        safe_print("Removing existing model_results.pkl...")
+        model_results_path.unlink()
+
     # Prepare data if needed
     if not Path('data/cleaned').exists():
         safe_print("\nCleaning data first...")
@@ -45,16 +59,31 @@ def train_models():

     # Train models
     safe_print("\nTraining models...")
-    result = subprocess.run([sys.executable, 'code/main.py'], capture_output=True)
-    if result.returncode != 0:
-        safe_print(f"Error training models: {result.stderr.decode()}")
+    try:
+        # Set environment to disable tqdm and multiprocessing (which can hang in subprocess)
+        env = os.environ.copy()
+        env['DISABLE_TQDM'] = '1'
+        env['NO_MULTIPROCESSING'] = '1'
+        # Set PyTorch memory management for better GPU memory usage
+        env['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+        # Run without capturing output so we can see progress
+        result = subprocess.run([sys.executable, 'code/main.py'], env=env, check=False)
+        if result.returncode != 0:
+            safe_print(f"Error: Training script exited with code {result.returncode}")
+            return False
+    except Exception as e:
+        safe_print(f"Error running training script: {e}")
         return False

     # Consolidate results
     safe_print("\nConsolidating model results...")
-    result = subprocess.run([sys.executable, 'code/consolidate_model_results.py'], capture_output=True)
-    if result.returncode != 0:
-        safe_print(f"Error consolidating results: {result.stderr.decode()}")
+    try:
+        result = subprocess.run([sys.executable, 'code/consolidate_model_results.py'], check=False)
+        if result.returncode != 0:
+            safe_print(f"Error: Consolidation script exited with code {result.returncode}")
+            return False
+    except Exception as e:
+        safe_print(f"Error running consolidation script: {e}")
         return False

     checkmark = "[OK]" if is_windows() else "✓"
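
A standalone sketch of the environment hand-off introduced above: the parent copies its environment, sets the flags, and the child process reads them. The inline -c child here is purely illustrative; in the repo the child is code/main.py.

import os
import subprocess
import sys

env = os.environ.copy()
env['DISABLE_TQDM'] = '1'            # main.py swaps tqdm for a no-op wrapper
env['NO_MULTIPROCESSING'] = '1'      # main.py runs experiments sequentially
env['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# check=False: inspect returncode ourselves rather than letting
# subprocess raise CalledProcessError.
result = subprocess.run(
    [sys.executable, '-c', "import os; print('DISABLE_TQDM =', os.environ.get('DISABLE_TQDM'))"],
    env=env,
    check=False,
)
print("child exited with code", result.returncode)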

code/main.py

Lines changed: 75 additions & 24 deletions
@@ -1,8 +1,12 @@
 import torch
 import logging
 import sys
+import warnings
 from transformers import GPT2Config, GPT2LMHeadModel
 from data_utils import get_train_data_loader, get_eval_data_loader
+
+# Suppress the loss_type warning from transformers
+warnings.filterwarnings("ignore", message=".*loss_type.*unrecognized.*")
 from model_utils import (
     save_checkpoint,
     load_checkpoint,
@@ -17,8 +21,17 @@
 import torch.backends.cudnn as cudnn
 from experiment import Experiment
 import torch.multiprocessing as mp
-from tqdm import tqdm
 from constants import MODELS_DIR, AUTHORS, CLEANED_DATA_DIR
+import os
+
+# Disable tqdm if running in subprocess or if explicitly disabled
+USE_TQDM = os.environ.get('DISABLE_TQDM', '0') != '1' and sys.stdout.isatty()
+if USE_TQDM:
+    from tqdm import tqdm
+else:
+    # Simple replacement that just returns the iterable
+    def tqdm(iterable, *args, **kwargs):
+        return iterable

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
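
The tqdm fallback can be checked in isolation; a small sketch, assuming only that the shim should silently accept tqdm-style keywords such as desc= and total=:

import os
import sys

os.environ['DISABLE_TQDM'] = '1'

# Same guard as in main.py: fall back when the flag is set or stdout is not a TTY.
USE_TQDM = os.environ.get('DISABLE_TQDM', '0') != '1' and sys.stdout.isatty()
if USE_TQDM:
    from tqdm import tqdm
else:
    def tqdm(iterable, *args, **kwargs):
        # Ignores desc=, total=, etc. and just returns the iterable unchanged.
        return iterable

total = 0
for x in tqdm(range(5), desc="epochs", total=5):
    total += x
print(total)  # 10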
@@ -138,6 +151,16 @@ def run_experiment(exp: Experiment, gpu_queue):
         train_author=exp.train_author,
     )

+    # Set up mixed precision training for memory efficiency
+    scaler = torch.amp.GradScaler('cuda')
+
+    # Enable gradient checkpointing to save memory (if supported)
+    try:
+        model.gradient_checkpointing_enable()
+        logger.info(f"[GPU {gpu_id}] Gradient checkpointing enabled for memory efficiency")
+    except AttributeError:
+        logger.info(f"[GPU {gpu_id}] Model does not support gradient checkpointing")
+
     # Training loop
     for epoch in tqdm(range(start_epoch, max_epochs)):
         total_train_loss = 0.0
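
A toy-sized sketch of the memory-saving setup added here; the small GPT2Config is illustrative rather than the repo's actual configuration:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64)).to(device)
model.train()

# Hugging Face PreTrainedModel exposes gradient_checkpointing_enable(); the
# try/except in the diff only matters for models without it.
model.gradient_checkpointing_enable()
model.config.use_cache = False  # checkpointing and the generation cache don't mix

# GradScaler warns and disables itself when CUDA is unavailable.
scaler = torch.amp.GradScaler('cuda')
print("scaler enabled:", scaler.is_enabled())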
@@ -148,18 +171,27 @@ def run_experiment(exp: Experiment, gpu_queue):

             input_ids = batch["input_ids"].to(device)

-            # Forward pass - use input_ids as labels (HF handles shifting)
-            outputs = model(input_ids=input_ids, labels=input_ids)
-            loss = outputs.loss
+            # Forward pass with mixed precision
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                outputs = model(input_ids=input_ids, labels=input_ids)
+                loss = outputs.loss

-            # Backward pass and optimization step
-            loss.backward()
-            optimizer.step()
+            # Backward pass with scaled gradients
             optimizer.zero_grad()
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()

             # Accumulate training loss
             total_train_loss += loss.item()

+            # Delete intermediate tensors to free memory
+            del outputs, loss
+
+            # Clear CUDA cache periodically
+            if (batch_idx + 1) % 5 == 0:
+                torch.cuda.empty_cache()
+
         epochs_completed = epoch + 1

         # Calculate average training loss
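
For context, the same autocast + GradScaler step pattern in isolation. The Linear model and random data are placeholders, and both autocast and the scaler disable themselves (with a warning) on CPU-only machines:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler('cuda')

x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)

for step in range(3):
    # Forward in fp16 where safe; the loss stays differentiable.
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)

    optimizer.zero_grad()
    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.step(optimizer)          # unscales grads, skips the step on inf/nan
    scaler.update()                 # adjusts the scale factor
    print(f"step {step}: loss={loss.item():.4f}")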
@@ -230,27 +262,46 @@ def run_experiment(exp: Experiment, gpu_queue):


 if __name__ == "__main__":
-    mp.set_start_method("spawn", force=True)
+    # Check if we should run sequentially (for subprocess compatibility)
+    USE_MULTIPROCESSING = os.environ.get('NO_MULTIPROCESSING', '0') != '1'
+
     device_count = torch.cuda.device_count()
     gpu_count = min(device_count, 4)
     print(f"Using {gpu_count} GPUs out of {device_count} available")

-    manager = mp.Manager()
-    gpu_queue = manager.Queue()
-    for gpu in range(gpu_count):
-        gpu_queue.put(gpu)
+    if USE_MULTIPROCESSING:
+        mp.set_start_method("spawn", force=True)
+        manager = mp.Manager()
+        gpu_queue = manager.Queue()
+        for gpu in range(gpu_count):
+            gpu_queue.put(gpu)

-    pool = mp.Pool(processes=gpu_count)
-    logger = logging.getLogger(__name__)
+        pool = mp.Pool(processes=gpu_count)
+        logger = logging.getLogger(__name__)

-    def error_callback(e):
-        logger.exception("Unhandled error in worker, shutting down all processes")
-        pool.terminate()
-        sys.exit(1)
+        def error_callback(e):
+            logger.exception("Unhandled error in worker, shutting down all processes")
+            pool.terminate()
+            sys.exit(1)

-    for exp in experiments:
-        pool.apply_async(
-            run_experiment, (exp, gpu_queue), error_callback=error_callback
-        )
-    pool.close()
-    pool.join()
+        for exp in experiments:
+            pool.apply_async(
+                run_experiment, (exp, gpu_queue), error_callback=error_callback
+            )
+        pool.close()
+        pool.join()
+    else:
+        # Sequential mode for subprocess compatibility
+        print("Running in sequential mode (multiprocessing disabled)")
+        import queue
+        gpu_queue = queue.Queue()
+        for gpu in range(gpu_count):
+            gpu_queue.put(gpu)
+
+        for i, exp in enumerate(experiments):
+            print(f"Training model {i+1}/{len(experiments)}: {exp.name}")
+            run_experiment(exp, gpu_queue)
+            # Put GPU back in queue for next experiment
+            if not gpu_queue.empty():
+                gpu_id = gpu_queue.get()
+                gpu_queue.put(gpu_id)
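
The multiprocessing branch follows the usual spawn + Manager queue + Pool.apply_async shape. A stripped-down, runnable sketch of that pattern, where worker is only a stand-in for run_experiment and the borrow/return of a GPU id is an assumption about how the queue is used:

import multiprocessing as mp

def worker(task_id, gpu_queue):
    # Assumption: borrow a GPU id for the duration of the task, then return it.
    gpu_id = gpu_queue.get()
    try:
        print(f"task {task_id} running on GPU {gpu_id}")
    finally:
        gpu_queue.put(gpu_id)

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    gpu_queue = manager.Queue()
    for gpu in range(2):
        gpu_queue.put(gpu)

    def error_callback(e):
        print("worker failed:", e)

    pool = mp.Pool(processes=2)
    for task_id in range(4):
        pool.apply_async(worker, (task_id, gpu_queue), error_callback=error_callback)
    pool.close()
    pool.join()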
