
Commit d8defd2

Add support for CPU, single-GPU, and multi-GPU training
- Detect device type (CUDA/MPS/CPU) and count at startup
- Support Apple Metal Performance Shaders (MPS) for M1/M2 Macs
- Support CPU-only training (with warning about speed)
- Support multi-GPU parallel training (up to 4 GPUs)
- Conditionally use mixed precision only on CUDA devices
- Update logging to show device type instead of assuming GPU
- Add device detection to run_llm_stylometry.sh
- Show available devices in generate_figures.py training prompt

The code now automatically adapts to available hardware:
- Multi-GPU systems: parallel training across GPUs
- Single GPU/MPS: single-device training
- CPU only: falls back to CPU with a performance warning
1 parent e2e16f9 commit d8defd2

File tree

3 files changed: +117, -43 lines
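For a quick sanity check of the CUDA → MPS → CPU fallback order described in the commit message, the following standalone snippet (not part of the commit; it only uses stock PyTorch queries) prints what the new detection logic would report on the current machine:

# Standalone preview of the detection order this commit introduces.
# Not part of the commit itself; it only reports what would be detected.
import torch

if torch.cuda.is_available():
    print("cuda", torch.cuda.device_count())
elif torch.backends.mps.is_available():
    print("mps", 1)
else:
    print("cpu", 1)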

code/generate_figures.py

Lines changed: 13 additions & 1 deletion
@@ -27,8 +27,20 @@ def train_models():
     safe_print("Training Models from Scratch")
     safe_print("=" * 60)
     warning = "[WARNING]" if is_windows() else "⚠️"
+    # Check device availability
+    import torch
+    device_info = ""
+    if torch.cuda.is_available():
+        gpu_count = torch.cuda.device_count()
+        device_info = f"CUDA GPUs available: {gpu_count}"
+    elif torch.backends.mps.is_available():
+        device_info = "Apple Metal Performance Shaders (MPS) available"
+    else:
+        device_info = "CPU only (training will be slow)"
+
     safe_print(f"\n{warning} Warning: This will train 80 models (8 authors × 10 seeds)")
-    safe_print(" This requires a CUDA GPU and will take several hours.")
+    safe_print(f" Device: {device_info}")
+    safe_print(" Training time depends on hardware (hours on GPU, days on CPU)")

     response = input("\nProceed with training? [y/N]: ")
     if response.lower() != 'y':

code/main.py

Lines changed: 88 additions & 42 deletions
@@ -36,8 +36,20 @@ def tqdm(iterable, *args, **kwargs):
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-if not torch.cuda.is_available():
-    raise Exception("No GPU available")
+# Detect available devices
+def get_device_info():
+    """Detect and return device configuration."""
+    if torch.cuda.is_available():
+        device_count = torch.cuda.device_count()
+        return "cuda", device_count
+    elif torch.backends.mps.is_available():
+        # Apple Metal Performance Shaders (MPS) backend
+        return "mps", 1
+    else:
+        return "cpu", 1
+
+device_type, device_count = get_device_info()
+logger.info(f"Device type: {device_type}, Count: {device_count}")

 experiments = []
 for seed in range(10):
@@ -51,16 +63,26 @@ def tqdm(iterable, *args, **kwargs):
     )


-def run_experiment(exp: Experiment, gpu_queue):
+def run_experiment(exp: Experiment, device_queue, device_type="cuda"):
     try:
         logging.basicConfig(level=logging.INFO)
         logger = logging.getLogger(__name__)

-        # Get an available GPU id
-        gpu_id = gpu_queue.get()
+        # Get an available device id
+        device_id = device_queue.get() if device_queue else 0
         logger.info(f"Starting experiment: {exp.name}")
-        torch.cuda.set_device(gpu_id)
-        device = torch.device("cuda", index=gpu_id)
+
+        # Set up device based on type
+        if device_type == "cuda":
+            torch.cuda.set_device(device_id)
+            device = torch.device("cuda", index=device_id)
+            device_label = f"GPU {device_id}"
+        elif device_type == "mps":
+            device = torch.device("mps")
+            device_label = "MPS"
+        else:
+            device = torch.device("cpu")
+            device_label = "CPU"

         # Initialize tokenizer directly using get_tokenizer
         tokenizer = get_tokenizer(exp.tokenizer_name)
@@ -82,7 +104,7 @@ def run_experiment(exp: Experiment, gpu_queue):
             excluded_train_path=exp.excluded_train_path,
         )
         logger.info(
-            f"[GPU {gpu_id}] Number of training batches: {len(train_dataloader)}"
+            f"[{device_label}] Number of training batches: {len(train_dataloader)}"
         )

         # Set up eval dataloaders
@@ -130,7 +152,7 @@ def run_experiment(exp: Experiment, gpu_queue):
             start_epoch = 0

         logger.info(
-            f"[GPU {gpu_id}] Total number of non-embedding parameters: {count_non_embedding_params(model)}"
+            f"[{device_label}] Total number of non-embedding parameters: {count_non_embedding_params(model)}"
         )

         # Initial evaluation (epochs_complete = 0)
@@ -151,15 +173,16 @@ def run_experiment(exp: Experiment, gpu_queue):
             train_author=exp.train_author,
         )

-        # Set up mixed precision training for memory efficiency
-        scaler = torch.amp.GradScaler('cuda')
+        # Set up mixed precision training if supported
+        use_amp = device_type == "cuda"
+        scaler = torch.amp.GradScaler('cuda') if use_amp else None

         # Enable gradient checkpointing to save memory (if supported)
         try:
             model.gradient_checkpointing_enable()
-            logger.info(f"[GPU {gpu_id}] Gradient checkpointing enabled for memory efficiency")
+            logger.info(f"[{device_label}] Gradient checkpointing enabled for memory efficiency")
         except AttributeError:
-            logger.info(f"[GPU {gpu_id}] Model does not support gradient checkpointing")
+            logger.info(f"[{device_label}] Model does not support gradient checkpointing")

         # Training loop
         for epoch in tqdm(range(start_epoch, max_epochs)):
@@ -171,16 +194,24 @@ def run_experiment(exp: Experiment, gpu_queue):

                 input_ids = batch["input_ids"].to(device)

-                # Forward pass with mixed precision
-                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                # Forward pass with or without mixed precision
+                if use_amp:
+                    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                        outputs = model(input_ids=input_ids, labels=input_ids)
+                        loss = outputs.loss
+                else:
                     outputs = model(input_ids=input_ids, labels=input_ids)
                     loss = outputs.loss

-                # Backward pass with scaled gradients
+                # Backward pass with or without mixed precision
                 optimizer.zero_grad()
-                scaler.scale(loss).backward()
-                scaler.step(optimizer)
-                scaler.update()
+                if use_amp:
+                    scaler.scale(loss).backward()
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    loss.backward()
+                    optimizer.step()

                 # Accumulate training loss
                 total_train_loss += loss.item()
@@ -230,11 +261,12 @@ def run_experiment(exp: Experiment, gpu_queue):
                 train_author=exp.train_author,
             )

-            # Force memory cleanup between evaluations
-            torch.cuda.empty_cache()
+            # Force memory cleanup between evaluations (CUDA only)
+            if device_type == "cuda":
+                torch.cuda.empty_cache()

             # Build log message for console output
-            log_message = f"[GPU {gpu_id}] Epoch {epochs_completed}/{max_epochs}: training loss = {train_loss:.4f}"
+            log_message = f"[{device_label}] Epoch {epochs_completed}/{max_epochs}: training loss = {train_loss:.4f}"
             for name, loss in eval_losses.items():
                 log_message += f", {name}: {loss:.4f}"
             logger.info(log_message)
@@ -249,13 +281,14 @@ def run_experiment(exp: Experiment, gpu_queue):
             # Early stopping after completing epoch (retain logs and checkpoints)
             if train_loss <= stop_train_loss and min_epochs <= epochs_completed:
                 logger.info(
-                    f"[GPU {gpu_id}] Training loss {train_loss:.4f} below threshold {stop_train_loss}. Stopping training."
+                    f"[{device_label}] Training loss {train_loss:.4f} below threshold {stop_train_loss}. Stopping training."
                 )
                 break
-        logger.info(f"[GPU {gpu_id}] Training complete for {modelname}")
+        logger.info(f"[{device_label}] Training complete for {modelname}")

         # Return the GPU id to the queue
-        gpu_queue.put(gpu_id)
+        if device_queue:
+            device_queue.put(device_id)
     except Exception:
         logger.exception(f"Error in experiment {exp.name}")
         raise
@@ -265,16 +298,24 @@ def run_experiment(exp: Experiment, gpu_queue):
 # Check if we should run sequentially (for subprocess compatibility)
 USE_MULTIPROCESSING = os.environ.get('NO_MULTIPROCESSING', '0') != '1'

-device_count = torch.cuda.device_count()
-gpu_count = min(device_count, 4)
-print(f"Using {gpu_count} GPUs out of {device_count} available")
+# Use already detected device configuration
+if device_type == "cuda":
+    gpu_count = min(device_count, 4)
+    print(f"Using {gpu_count} GPUs out of {device_count} available")
+elif device_type == "mps":
+    gpu_count = 1
+    print("Using Apple Metal Performance Shaders (MPS)")
+else:
+    gpu_count = 1
+    print("Using CPU for training (this will be slow)")

-if USE_MULTIPROCESSING:
+if USE_MULTIPROCESSING and device_type == "cuda" and gpu_count > 1:
+    # Only use multiprocessing for multiple CUDA GPUs
     mp.set_start_method("spawn", force=True)
     manager = mp.Manager()
-    gpu_queue = manager.Queue()
+    device_queue = manager.Queue()
     for gpu in range(gpu_count):
-        gpu_queue.put(gpu)
+        device_queue.put(gpu)

     pool = mp.Pool(processes=gpu_count)
     logger = logging.getLogger(__name__)
@@ -286,22 +327,27 @@ def error_callback(e):

     for exp in experiments:
         pool.apply_async(
-            run_experiment, (exp, gpu_queue), error_callback=error_callback
+            run_experiment, (exp, device_queue, device_type), error_callback=error_callback
         )
     pool.close()
     pool.join()
 else:
-    # Sequential mode for subprocess compatibility
+    # Sequential mode for subprocess compatibility or single device
     print("Running in sequential mode (multiprocessing disabled)")
-    import queue
-    gpu_queue = queue.Queue()
-    for gpu in range(gpu_count):
-        gpu_queue.put(gpu)
+    if device_type == "cuda" and gpu_count > 1:
+        # Multiple GPUs but running sequentially
+        import queue
+        device_queue = queue.Queue()
+        for gpu in range(gpu_count):
+            device_queue.put(gpu)
+    else:
+        # Single device or non-CUDA
+        device_queue = None

     for i, exp in enumerate(experiments):
         print(f"Training model {i+1}/{len(experiments)}: {exp.name}")
-        run_experiment(exp, gpu_queue)
-        # Put GPU back in queue for next experiment
-        if not gpu_queue.empty():
-            gpu_id = gpu_queue.get()
-            gpu_queue.put(gpu_id)
+        run_experiment(exp, device_queue, device_type)
+        # For multi-GPU sequential mode, rotate through GPUs
+        if device_queue and not device_queue.empty():
+            device_id = device_queue.get()
+            device_queue.put(device_id)
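The most reusable piece of the main.py changes is the "mixed precision only on CUDA" branching. The toy training step below isolates that pattern; the linear model, SGD optimizer, and random tensors are placeholders invented for illustration, and only the use_amp branching mirrors the diff above.

# Toy training step isolating the "AMP only on CUDA" branching used above.
# The linear model, SGD optimizer, and random data are placeholders.
import torch
import torch.nn.functional as F

device_type = "cuda" if torch.cuda.is_available() else (
    "mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device(device_type)

model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

use_amp = device_type == "cuda"
scaler = torch.amp.GradScaler('cuda') if use_amp else None

x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)

optimizer.zero_grad()
if use_amp:
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        loss = F.mse_loss(model(x), y)
    scaler.scale(loss).backward()   # scaled backward pass
    scaler.step(optimizer)
    scaler.update()
else:
    loss = F.mse_loss(model(x), y)  # plain FP32 path on MPS/CPU
    loss.backward()
    optimizer.step()

print(f"loss: {loss.item():.4f}")

The point carried over from the diff is that GradScaler and autocast are touched only when use_amp is true, so the identical loop body runs unchanged on MPS or CPU.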

run_llm_stylometry.sh

Lines changed: 16 additions & 0 deletions
@@ -381,6 +381,22 @@ if [ "$SETUP_ONLY" = true ]; then
     exit 0
 fi

+# Detect available compute devices
+print_info "Detecting available compute devices..."
+DEVICE_INFO=$(python -c "
+import torch
+if torch.cuda.is_available():
+    n = torch.cuda.device_count()
+    names = [torch.cuda.get_device_name(i) for i in range(n)]
+    print(f'CUDA GPUs: {n} device(s) - {names[0] if n > 0 else \"Unknown\"}')
+elif torch.backends.mps.is_available():
+    print('Apple Metal Performance Shaders (MPS)')
+else:
+    import multiprocessing
+    print(f'CPU only ({multiprocessing.cpu_count()} cores)')
+" 2>/dev/null || echo "Could not detect device")
+print_info "Device: $DEVICE_INFO"
+
 # Build the Python command
 PYTHON_CMD="python code/generate_figures.py"

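Stepping back from the per-file changes, the multi-GPU dispatch in main.py comes down to a Manager queue that hands device ids to a worker pool. The self-contained toy below sketches that pattern under the assumption of two pretend devices and a worker that does no real training:

# Toy version of the queue-plus-pool dispatch used in main.py:
# a Manager queue hands out device ids, each worker returns its id when done.
import multiprocessing as mp

def worker(task_id, device_queue):
    device_id = device_queue.get()      # claim a free (pretend) device
    try:
        print(f"task {task_id} -> device {device_id}")
    finally:
        device_queue.put(device_id)     # hand it back for the next task

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    device_queue = manager.Queue()
    for gpu in range(2):                # pretend two devices are available
        device_queue.put(gpu)

    pool = mp.Pool(processes=2)
    for task_id in range(6):
        pool.apply_async(worker, (task_id, device_queue))
    pool.close()
    pool.join()

A plain multiprocessing.Queue cannot be handed to pool workers, which is why both main.py and this sketch go through multiprocessing.Manager(): the manager's queue proxy is picklable and can be shared via apply_async.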