Commit e8ed6b9

Fix multi-GPU parallel training being disabled
- Only disable multiprocessing for single GPU or non-CUDA devices
- Enable parallel training when multiple GPUs are detected
- Add clear logging to show sequential vs parallel mode

The issue was that NO_MULTIPROCESSING=1 was always being set, forcing sequential training even on multi-GPU systems. Now:

- Multiple GPUs: parallel training enabled
- Single GPU: sequential mode (avoids overhead)
- CPU/MPS: sequential mode (required)

This fixes the issue where 8-GPU systems were only using 1 GPU.
1 parent f6bfaee commit e8ed6b9
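
The fix relies on an environment-variable contract: generate_figures.py sets NO_MULTIPROCESSING=1 and the training subprocess is expected to fall back to sequential training when it sees the flag. A minimal sketch of how such a consumer might branch on it is below; launch_models, train_one, and the spawn-based pool are illustrative assumptions, not code from this repository.

# Hedged sketch: how a training script might honor NO_MULTIPROCESSING.
# Only the variable name comes from this commit; everything else here
# (launch_models, train_one, the spawn pool) is assumed for illustration.
import os
import multiprocessing as mp

def launch_models(configs, train_one, gpu_count):
    if os.environ.get('NO_MULTIPROCESSING') == '1':
        # Sequential mode: train each model in the current process.
        for cfg in configs:
            train_one(cfg)
    else:
        # Parallel mode: one worker per model, capped by the GPU count.
        # CUDA generally requires the 'spawn' start method in workers.
        ctx = mp.get_context('spawn')
        with ctx.Pool(processes=min(gpu_count, len(configs))) as pool:
            pool.map(train_one, configs)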

File tree

1 file changed: +15 -3 lines changed

code/generate_figures.py

Lines changed: 15 additions & 3 deletions
@@ -72,10 +72,22 @@ def train_models(max_gpus=None):
     # Train models
     safe_print("\nTraining models...")
     try:
-        # Set environment to disable tqdm and multiprocessing (which can hang in subprocess)
+        # Set environment variables for training
         env = os.environ.copy()
-        env['DISABLE_TQDM'] = '1'
-        env['NO_MULTIPROCESSING'] = '1'
+        env['DISABLE_TQDM'] = '1'  # Disable progress bars in subprocess
+        # Only disable multiprocessing if we have a single GPU or non-GPU device
+        # With multiple GPUs, we want parallel training
+        if torch.cuda.is_available():
+            gpu_count = torch.cuda.device_count()
+            if gpu_count <= 1:
+                env['NO_MULTIPROCESSING'] = '1'
+                safe_print("Single GPU detected - using sequential mode")
+            else:
+                safe_print(f"Multiple GPUs detected ({gpu_count}) - using parallel training")
+        else:
+            # Non-CUDA device (CPU or MPS)
+            env['NO_MULTIPROCESSING'] = '1'
+            safe_print("Non-CUDA device - using sequential mode")
         # Set PyTorch memory management for better GPU memory usage
         env['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
         # Pass through max GPUs limit if specified
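
The hunk is cut off before the max-GPUs passthrough, so this diff does not show how that limit reaches the subprocess. One plausible sketch, assuming the limit is forwarded through CUDA_VISIBLE_DEVICES (an assumption, not confirmed by this commit):

# Hypothetical only: restrict the training subprocess to the first max_gpus
# devices. run_training, the script name, and the CUDA_VISIBLE_DEVICES
# mechanism are assumptions; the real passthrough is truncated out of the diff.
import subprocess

def run_training(env, max_gpus=None, script='train_models.py'):
    if max_gpus is not None:
        env['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in range(max_gpus))
    return subprocess.run(['python', script], env=env, check=True)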
