From 7c781d793cdb421cda4331ff13952900ccff9a37 Mon Sep 17 00:00:00 2001
From: Adam Lee
Date: Fri, 19 May 2023 17:14:00 +0900
Subject: [PATCH] Multi-GPU training

---
 finetune/adapter.py | 2 +-
 finetune/full.py    | 2 +-
 finetune/lora.py    | 3 ++-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/finetune/adapter.py b/finetune/adapter.py
index f4bf266e..3ae26eaa 100644
--- a/finetune/adapter.py
+++ b/finetune/adapter.py
@@ -36,7 +36,7 @@
 save_interval = 1000
 eval_iters = 100
 log_interval = 1
-devices = 1
+devices = torch.cuda.device_count()
 
 # Hyperparameters
 learning_rate = 9e-3
diff --git a/finetune/full.py b/finetune/full.py
index 58932967..3ab673ce 100644
--- a/finetune/full.py
+++ b/finetune/full.py
@@ -31,7 +31,7 @@
 save_interval = 1000
 eval_iters = 100
 log_interval = 100
-devices = 4
+devices = torch.cuda.device_count()
 
 # Hyperparameters
 learning_rate = 3e-5
diff --git a/finetune/lora.py b/finetune/lora.py
index e00e438a..b370e836 100644
--- a/finetune/lora.py
+++ b/finetune/lora.py
@@ -50,7 +50,8 @@ def main(
     out_dir: str = "out/lora/alpaca",
 ):
 
-    fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-true")
+    devices = torch.cuda.device_count()
+    fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-true")
     fabric.launch()
     fabric.seed_everything(1337 + fabric.global_rank)
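
Note (not part of the commit): below is a minimal, self-contained sketch of the Fabric setup this patch converges on across all three scripts. The model and optimizer are illustrative placeholders, not code from finetune/lora.py. One caveat worth flagging: torch.cuda.device_count() returns 0 on a machine with no visible GPUs, so after this patch the scripts assume at least one CUDA device is available.

import torch
import lightning as L


def main():
    # Mirror the patch: use every CUDA device PyTorch can see.
    # torch.cuda.device_count() is 0 on CPU-only machines, so this
    # setup assumes at least one visible GPU.
    devices = torch.cuda.device_count()
    fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-true")
    fabric.launch()  # with devices > 1, runs one process per device
    fabric.seed_everything(1337 + fabric.global_rank)  # distinct seed per rank

    # Placeholder model/optimizer (hypothetical, for illustration only).
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    model, optimizer = fabric.setup(model, optimizer)  # wraps both for distributed training


if __name__ == "__main__":
    main()

Deriving devices from device_count() means the same command works unchanged on a 1-GPU and an 8-GPU box, at the cost of always claiming every visible GPU.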