Adding torch accelerator to ddp-tutorial-series example #1376
base: main
@@ -17,9 +17,20 @@ def ddp_setup(rank, world_size):
         world_size: Total number of processes
     """
     os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "12355"
-    torch.cuda.set_device(rank)
-    init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    os.environ["MASTER_PORT"] = "12455"
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device_type = torch.accelerator.current_accelerator()
+        device = torch.device(f"{device_type}:{rank}")
+        torch.accelerator.device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+
+    backend = torch.distributed.get_default_backend_for_device(device)

 class Trainer:
     def __init__(

Review thread on the torch.accelerator.device_index(rank) line:
- There is no such API. What is it doing? You did set the index 2 lines above...
- Ok,
- It does not make sense to call a context manager without `with`.
- yes, I'm making the changes, thanks

@@ -100,5 +111,5 @@ def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
     parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
     args = parser.parse_args()

-    world_size = torch.cuda.device_count()
+    world_size = torch.accelerator.device_count()
     mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
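The thread above questions the bare torch.accelerator.device_index(rank) call. As a point of comparison, here is a minimal sketch of the mp.spawn variant of ddp_setup with an explicit setter; torch.accelerator.set_device_index() is an assumed helper name from recent PyTorch, not something taken from this PR, and the rank passed in by mp.spawn is used directly since mp.spawn does not set LOCAL_RANK.

import os

import torch
from torch.distributed import init_process_group


def ddp_setup(rank: int, world_size: int) -> torch.device:
    """Accelerator-aware setup sketch for the mp.spawn example."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12455"
    if torch.accelerator.is_available():
        acc = torch.accelerator.current_accelerator()   # e.g. device(type='cuda')
        device = torch.device(f"{acc.type}:{rank}")
        torch.accelerator.set_device_index(rank)        # assumed setter; binds this process to its device
    else:
        device = torch.device("cpu")
    backend = torch.distributed.get_default_backend_for_device(device)
    init_process_group(backend=backend, rank=rank, world_size=world_size)
    print(f"Running on rank {rank} on device {device}")
    return device

Returning the device keeps the call site free of CUDA-specific assumptions, which is the point of the change.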
@@ -11,8 +11,20 @@
 def ddp_setup():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    init_process_group(backend="nccl")
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device_type = torch.accelerator.current_accelerator()
+        device = torch.device(f"{device_type}:{rank}")
+        torch.accelerator.device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+
+    backend = torch.distributed.get_default_backend_for_device(device)
+    torch.distributed.init_process_group(backend=backend, device_id=device)
+    return device

 class Trainer:
     def __init__(

Review comment on the torch.accelerator.device_index(rank) line: same as above.
@@ -22,6 +34,7 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         save_every: int,
         snapshot_path: str,
+        device
     ) -> None:
         self.gpu_id = int(os.environ["LOCAL_RANK"])
         self.model = model.to(self.gpu_id)

Review comment on the new device parameter: would be nice to have a type designation here.
Suggested change: device -> device: torch.device
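Taking the suggestion above, the annotated constructor could look like the following condensed sketch; the parameter list and attribute assignments mirror the diff, while the torch.nn.Module and DataLoader annotations are assumptions added for illustration.

import os

import torch
from torch.utils.data import DataLoader


class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
        device: torch.device,  # annotated, as the reviewer suggests
    ) -> None:
        self.gpu_id = int(os.environ["LOCAL_RANK"])
        self.device = device
        self.model = model.to(self.gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.snapshot_path = snapshot_path
        self.epochs_run = 0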
@@ -30,14 +43,15 @@ def __init__(
         self.save_every = save_every
         self.epochs_run = 0
         self.snapshot_path = snapshot_path
+        self.device = device
         if os.path.exists(snapshot_path):
             print("Loading snapshot")
             self._load_snapshot(snapshot_path)

         self.model = DDP(self.model, device_ids=[self.gpu_id])

     def _load_snapshot(self, snapshot_path):
-        loc = f"cuda:{self.gpu_id}"
+        loc = str(self.device)
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]

@@ -92,10 +106,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):

 def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
-    ddp_setup()
+    device = ddp_setup()
     dataset, model, optimizer = load_train_objs()
     train_data = prepare_dataloader(dataset, batch_size)
-    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
     trainer.train(total_epochs)
     destroy_process_group()
@@ -11,8 +11,20 @@
 def ddp_setup():
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
-    init_process_group(backend="nccl")
+    rank = int(os.environ["LOCAL_RANK"])
+    if torch.accelerator.is_available():
+        device_type = torch.accelerator.current_accelerator()
+        device: torch.device = torch.device(f"{device_type}:{rank}")
+        torch.accelerator.device_index(rank)
+        print(f"Running on rank {rank} on device {device}")
+        backend = torch.distributed.get_default_backend_for_device(device)
+        torch.distributed.init_process_group(backend=backend)
+        return device_type
+    else:
+        device = torch.device("cpu")
+        print(f"Running on device {device}")
+        torch.distributed.init_process_group(backend="gloo")
+        return device

 class Trainer:
     def __init__(

Review thread on the accelerator branch:
- same comments as above
- not addressed.

@@ -22,6 +34,7 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         save_every: int,
         snapshot_path: str,
+        device
     ) -> None:
         self.local_rank = int(os.environ["LOCAL_RANK"])
         self.global_rank = int(os.environ["RANK"])

@@ -31,14 +44,15 @@ def __init__(
         self.save_every = save_every
         self.epochs_run = 0
         self.snapshot_path = snapshot_path
+        self.device = device
         if os.path.exists(snapshot_path):
             print("Loading snapshot")
             self._load_snapshot(snapshot_path)

         self.model = DDP(self.model, device_ids=[self.local_rank])

     def _load_snapshot(self, snapshot_path):
-        loc = f"cuda:{self.local_rank}"
+        loc = str(self.device)
         snapshot = torch.load(snapshot_path, map_location=loc)
         self.model.load_state_dict(snapshot["MODEL_STATE"])
         self.epochs_run = snapshot["EPOCHS_RUN"]

@@ -93,10 +107,10 @@ def prepare_dataloader(dataset: Dataset, batch_size: int):

 def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
-    ddp_setup()
+    device = ddp_setup()
     dataset, model, optimizer = load_train_objs()
     train_data = prepare_dataloader(dataset, batch_size)
-    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, device)
     trainer.train(total_epochs)
     destroy_process_group()
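The reviewer notes that the earlier comments were not addressed in this file: the accelerator branch still uses the bare device_index(rank) call and returns device_type, while the CPU branch returns device. A consistent sketch under the same assumption as before (set_device_index() is an assumed helper name), which also passes device_id so the process group binds to this process's device, could look like this:

import os

import torch
from torch.distributed import init_process_group


def ddp_setup() -> torch.device:
    """Torchrun-friendly setup sketch that returns a torch.device in both branches."""
    rank = int(os.environ["LOCAL_RANK"])
    if torch.accelerator.is_available():
        acc = torch.accelerator.current_accelerator()
        device = torch.device(f"{acc.type}:{rank}")
        torch.accelerator.set_device_index(rank)  # assumed setter, as in the earlier sketch
        print(f"Running on rank {rank} on device {device}")
    else:
        device = torch.device("cpu")
        print(f"Running on device {device}")

    backend = torch.distributed.get_default_backend_for_device(device)
    if device.type == "cpu":
        init_process_group(backend=backend)
    else:
        # device_id lets the process group bind eagerly to this process's device
        init_process_group(backend=backend, device_id=device)
    return device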
@@ -1 +1 @@
-torch>=1.11.0
+torch>=2.7
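The bump to torch>=2.7 is what makes the torch.accelerator calls above available. If the example ever needs to tolerate older installs, a small guard is enough; the CUDA/CPU fallback below is an illustration only and not part of this PR.

import torch

# torch.accelerator only exists in recent PyTorch releases, hence torch>=2.7.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Selected device: {device}")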
@@ -0,0 +1,10 @@
+# /bin/bash
+# bash run_example.sh {file_to_run.py} {num_gpus}
+# where file_to_run = example to run. Default = 'example.py'
+# num_gpus = num local gpus to use (must be at least 2). Default = 2
+
+# samples to run include:
+# example.py
+
+echo "Launching ${1:-example.py} with ${2:-2} gpus"
+torchrun --nnodes=1 --nproc_per_node=${2:-2} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-example.py}
Review comment: It's still a different port number.