Commit 40920c2

Remove 4-GPU limit and support unlimited GPUs with optional limiting
- Remove arbitrary limit of 4 GPUs - now uses ALL available GPUs by default
- Add --max-gpus/-g flag to optionally limit the number of GPUs used
- Update both run_llm_stylometry.sh and generate_figures.py with the new flag
- Pass MAX_GPUS environment variable through to main.py
- Show an appropriate message when GPU usage is limited vs. using all GPUs

Usage examples:
- ./run_llm_stylometry.sh -t                             # Use all available GPUs
- ./run_llm_stylometry.sh -t -g 2                        # Limit to 2 GPUs
- python code/generate_figures.py --train --max-gpus 4   # Limit to 4 GPUs

This allows better scalability on large GPU clusters while still letting users limit GPU usage when they need to share resources.
1 parent d8defd2 commit 40920c2

3 files changed: 32 additions & 5 deletions

code/generate_figures.py

Lines changed: 13 additions & 2 deletions
@@ -21,7 +21,7 @@
 from llm_stylometry.cli_utils import safe_print, format_header, is_windows
 
 
-def train_models():
+def train_models(max_gpus=None):
     """Train all models from scratch."""
     safe_print("\n" + "=" * 60)
     safe_print("Training Models from Scratch")
@@ -78,6 +78,10 @@ def train_models():
     env['NO_MULTIPROCESSING'] = '1'
     # Set PyTorch memory management for better GPU memory usage
     env['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+    # Pass through max GPUs limit if specified
+    if max_gpus:
+        env['MAX_GPUS'] = str(max_gpus)
+        safe_print(f"Limiting to {max_gpus} GPU(s)")
     # Run without capturing output so we can see progress
     result = subprocess.run([sys.executable, 'code/main.py'], env=env, check=False)
     if result.returncode != 0:
@@ -194,6 +198,13 @@ def main():
         help='List available figures'
     )
 
+    parser.add_argument(
+        '--max-gpus', '-g',
+        type=int,
+        help='Maximum number of GPUs to use for training (default: all available)',
+        default=None
+    )
+
     args = parser.parse_args()
 
     if args.list:
@@ -211,7 +222,7 @@ def main():
 
     # Train models if requested
     if args.train:
-        if not train_models():
+        if not train_models(max_gpus=args.max_gpus):
             return 1
         # Update data path to use newly generated results
         args.data = 'data/model_results.pkl'
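For context, the hunk above only shows env=env being handed to subprocess.run; a minimal, self-contained sketch of that pass-through pattern is below. The hard-coded max_gpus = 2 is hypothetical; in generate_figures.py the value comes from the --max-gpus argument.

import os
import subprocess
import sys

# Hypothetical stand-in for the value parsed from --max-gpus.
max_gpus = 2

# Copy the parent environment and add MAX_GPUS so the child process can see it.
env = os.environ.copy()
if max_gpus:
    env['MAX_GPUS'] = str(max_gpus)

# Launch main.py with the augmented environment; check=False mirrors the diff,
# where the return code is inspected instead of raising on failure.
result = subprocess.run([sys.executable, 'code/main.py'], env=env, check=False)
print(f"main.py exited with {result.returncode}")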

code/main.py

Lines changed: 7 additions & 2 deletions
@@ -300,8 +300,13 @@ def run_experiment(exp: Experiment, device_queue, device_type="cuda"):
 
     # Use already detected device configuration
     if device_type == "cuda":
-        gpu_count = min(device_count, 4)
-        print(f"Using {gpu_count} GPUs out of {device_count} available")
+        # Check for MAX_GPUS environment variable to optionally limit GPU usage
+        max_gpus = int(os.environ.get('MAX_GPUS', '0')) or device_count
+        gpu_count = min(device_count, max_gpus)
+        if gpu_count < device_count:
+            print(f"Using {gpu_count} GPUs (limited by MAX_GPUS) out of {device_count} available")
+        else:
+            print(f"Using all {gpu_count} available GPUs")
     elif device_type == "mps":
         gpu_count = 1
         print("Using Apple Metal Performance Shaders (MPS)")

run_llm_stylometry.sh

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ OPTIONS:
3535
-h, --help Show this help message
3636
-f, --figure FIGURE Generate specific figure (1a, 1b, 2a, 2b, 3, 4, 5)
3737
-t, --train Train models from scratch before generating figures
38+
-g, --max-gpus NUM Maximum number of GPUs to use for training (default: all)
3839
-d, --data PATH Path to model_results.pkl (default: data/model_results.pkl)
3940
-o, --output DIR Output directory for figures (default: paper/figs/source)
4041
-l, --list List available figures
@@ -48,7 +49,8 @@ EXAMPLES:
4849
$0 # Setup environment and generate all figures
4950
$0 -f 1a # Generate only Figure 1A
5051
$0 -f 4 # Generate only Figure 4 (MDS plot)
51-
$0 -t # Train models from scratch, then generate figures
52+
$0 -t # Train models from scratch using all GPUs
53+
$0 -t -g 2 # Train models using only 2 GPUs
5254
$0 -l # List available figures
5355
$0 --setup-only # Only setup the environment
5456
$0 --clean # Remove environment and reinstall from scratch
@@ -278,6 +280,7 @@ setup_environment() {
278280
# Parse command line arguments
279281
FIGURE=""
280282
TRAIN=false
283+
MAX_GPUS=""
281284
DATA_PATH="data/model_results.pkl"
282285
OUTPUT_DIR="paper/figs/source"
283286
LIST_FIGURES=false
@@ -301,6 +304,10 @@ while [[ $# -gt 0 ]]; do
301304
TRAIN=true
302305
shift
303306
;;
307+
-g|--max-gpus)
308+
MAX_GPUS="$2"
309+
shift 2
310+
;;
304311
-d|--data)
305312
DATA_PATH="$2"
306313
shift 2
@@ -410,6 +417,10 @@ if [ "$TRAIN" = true ]; then
410417
PYTHON_CMD="$PYTHON_CMD --train"
411418
fi
412419

420+
if [ -n "$MAX_GPUS" ]; then
421+
PYTHON_CMD="$PYTHON_CMD --max-gpus $MAX_GPUS"
422+
fi
423+
413424
if [ "$DATA_PATH" != "data/model_results.pkl" ]; then
414425
PYTHON_CMD="$PYTHON_CMD --data $DATA_PATH"
415426
fi
