
Commit e9a4e75

Add accelerator API to RPC distributed examples: ddp_rpc, parameter_server, rnn (#1371)
* Add rpc/ddp_rpc and rpc/rnn examples to CI

* Add accelerator API to RPC distributed examples:
  - ddp_rpc
  - parameter_server
  - rnn

* Update requirements for RPC examples to include numpy

* Enhance GPU verification and cleanup in DDP RPC example
  - Added a function to verify minimum GPU count before execution.
  - Updated HybridModel initialization to use rank instead of device.
  - Ensured proper cleanup of the process group to avoid resource leaks.
  - Added exit message if insufficient GPUs are detected.

* Update torch version in requirements.txt
  - Remove CPU execution option since DDP requires 2 GPUs for this example.
  - Refine README.md for DDP RPC example clarity and detail

Signed-off-by: jafraustro <[email protected]>
1 parent 99f5c4e commit e9a4e75

9 files changed, +72 -40 lines changed


distributed/rpc/ddp_rpc/README.md

Lines changed: 4 additions & 12 deletions
@@ -1,18 +1,10 @@
 Distributed DataParallel + Distributed RPC Framework Example
 
-The example shows how to combine Distributed DataParallel with the Distributed
-RPC Framework. There are two trainer nodes, 1 master node and 1 parameter
-server in the example.
+This example demonstrates how to combine Distributed DataParallel (DDP) with the Distributed RPC Framework. It requires two trainer nodes (each with a GPU), one master node, and one parameter server.
 
-The master node creates an embedding table on the parameter server and drives
-the training loop on the trainers. The model consists of a dense part
-(nn.Linear) replicated on the trainers via Distributed DataParallel and a
-sparse part (nn.EmbeddingBag) which resides on the parameter server. Each
-trainer performs an embedding lookup on the parameter server (using the
-Distributed RPC Framework) and then executes its local nn.Linear module.
-During the backward pass, the gradients for the dense part are aggregated via
-allreduce by DDP and the distributed backward pass updates the parameters for
-the embedding table on the parameter server.
+The master node initializes an embedding table on the parameter server and orchestrates the training loop across the trainers. The model is composed of a dense component (`nn.Linear`), which is replicated on the trainers using DDP, and a sparse component (`nn.EmbeddingBag`), which resides on the parameter server.
+
+Each trainer performs embedding lookups on the parameter server via RPC, then processes the results through its local `nn.Linear` module. During the backward pass, DDP aggregates gradients for the dense part using allreduce, while the distributed backward pass updates the embedding table parameters on the parameter server.
 
 
 ```
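For context, a minimal sketch of how the master side described in this README wires the pieces together: the sparse `nn.EmbeddingBag` is created remotely on the parameter server and each trainer gets a reference to it. This is not the example's exact code; the `RemoteModule` usage and worker names (`"ps"`, `"trainer0"`, `"trainer1"`) follow the pattern described above, and `_run_trainer` is the trainer entry point that appears in the `main.py` diff below.

```python
import torch
import torch.distributed.rpc as rpc
from torch.distributed.nn.api.remote_module import RemoteModule


def _run_trainer(remote_emb_module, rank):
    ...  # trainer loop; see the main.py diff below


def run_master():
    # Sparse part: an EmbeddingBag owned by the parameter server worker ("ps").
    remote_emb_module = RemoteModule(
        "ps",                    # worker that hosts the module
        torch.nn.EmbeddingBag,
        args=(100, 16),          # NUM_EMBEDDINGS, EMBEDDING_DIM
        kwargs={"mode": "sum"},
    )
    # Drive the training loop on both trainers and wait for them to finish.
    futs = [
        rpc.rpc_async(f"trainer{rank}", _run_trainer, args=(remote_emb_module, rank))
        for rank in (0, 1)
    ]
    for fut in futs:
        fut.wait()
```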

distributed/rpc/ddp_rpc/main.py

Lines changed: 24 additions & 6 deletions
@@ -15,6 +15,11 @@
 NUM_EMBEDDINGS = 100
 EMBEDDING_DIM = 16
 
+def verify_min_gpu_count(min_gpus: int = 2) -> bool:
+    """ verification that we have at least 2 gpus to run dist examples """
+    has_gpu = torch.accelerator.is_available()
+    gpu_count = torch.accelerator.device_count()
+    return has_gpu and gpu_count >= min_gpus
 
 class HybridModel(torch.nn.Module):
     r"""
@@ -24,15 +29,15 @@ class HybridModel(torch.nn.Module):
     This remote model can get a Remote Reference to the embedding table on the parameter server.
     """
 
-    def __init__(self, remote_emb_module, device):
+    def __init__(self, remote_emb_module, rank):
         super(HybridModel, self).__init__()
         self.remote_emb_module = remote_emb_module
-        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
-        self.device = device
+        self.fc = DDP(torch.nn.Linear(16, 8).to(rank))
+        self.rank = rank
 
     def forward(self, indices, offsets):
         emb_lookup = self.remote_emb_module.forward(indices, offsets)
-        return self.fc(emb_lookup.cuda(self.device))
+        return self.fc(emb_lookup.to(self.rank))
 
 
 def _run_trainer(remote_emb_module, rank):
@@ -83,7 +88,7 @@ def get_next_batch(rank):
                 batch_size += 1
 
             offsets_tensor = torch.LongTensor(offsets)
-            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
+            target = torch.LongTensor(batch_size).random_(8).to(rank)
             yield indices, offsets_tensor, target
 
     # Train for 100 epochs
@@ -145,9 +150,13 @@ def run_worker(rank, world_size):
         for fut in futs:
             fut.wait()
     elif rank <= 1:
+        acc = torch.accelerator.current_accelerator()
+        device = torch.device(acc)
+        backend = torch.distributed.get_default_backend_for_device(device)
+        torch.accelerator.device_index(rank)
         # Initialize process group for Distributed DataParallel on trainers.
         dist.init_process_group(
-            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
+            backend=backend, rank=rank, world_size=2, init_method="tcp://localhost:29500"
         )
 
         # Initialize RPC.
@@ -172,9 +181,18 @@ def run_worker(rank, world_size):
 
     # block until all rpcs finish
     rpc.shutdown()
+
+    # Clean up process group for trainers to avoid resource leaks
+    if rank <= 1:
+        dist.destroy_process_group()
 
 
 if __name__ == "__main__":
     # 2 trainers, 1 parameter server, 1 master.
     world_size = 4
+    _min_gpu_count = 2
+    if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+        print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
+        exit()
     mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
+    print("Distributed RPC example completed successfully.")
Lines changed: 2 additions & 1 deletion
@@ -1 +1,2 @@
-torch>=1.6.0
+torch>=2.7.0
+numpy
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+torch>=2.7.1
+numpy

distributed/rpc/parameter_server/rpc_parameter_server.py

Lines changed: 21 additions & 15 deletions
@@ -20,15 +20,19 @@ def __init__(self, num_gpus=0):
         super(Net, self).__init__()
         print(f"Using {num_gpus} GPUs to train")
         self.num_gpus = num_gpus
-        device = torch.device(
-            "cuda:0" if torch.cuda.is_available() and self.num_gpus > 0 else "cpu")
+        if torch.accelerator.is_available() and self.num_gpus > 0:
+            acc = torch.accelerator.current_accelerator()
+            device = torch.device(f'{acc}:0')
+        else:
+            device = torch.device("cpu")
         print(f"Putting first 2 convs on {str(device)}")
-        # Put conv layers on the first cuda device
+        # Put conv layers on the first accelerator device
         self.conv1 = nn.Conv2d(1, 32, 3, 1).to(device)
         self.conv2 = nn.Conv2d(32, 64, 3, 1).to(device)
-        # Put rest of the network on the 2nd cuda device, if there is one
-        if "cuda" in str(device) and num_gpus > 1:
-            device = torch.device("cuda:1")
+        # Put rest of the network on the 2nd accelerator device, if there is one
+        if torch.accelerator.is_available() and self.num_gpus > 0:
+            acc = torch.accelerator.current_accelerator()
+            device = torch.device(f'{acc}:1')
 
         print(f"Putting rest of layers on {str(device)}")
         self.dropout1 = nn.Dropout2d(0.25).to(device)
@@ -72,21 +76,22 @@ def call_method(method, rref, *args, **kwargs):
 # <foo_instance>.bar(arg1, arg2) on the remote node and getting the result
 # back.
 
-
 def remote_method(method, rref, *args, **kwargs):
     args = [method, rref] + list(args)
     return rpc.rpc_sync(rref.owner(), call_method, args=args, kwargs=kwargs)
 
-
 # --------- Parameter Server --------------------
 class ParameterServer(nn.Module):
     def __init__(self, num_gpus=0):
         super().__init__()
         model = Net(num_gpus=num_gpus)
         self.model = model
-        self.input_device = torch.device(
-            "cuda:0" if torch.cuda.is_available() and num_gpus > 0 else "cpu")
-
+        if torch.accelerator.is_available() and num_gpus > 0:
+            acc = torch.accelerator.current_accelerator()
+            self.input_device = torch.device(f'{acc}:0')
+        else:
+            self.input_device = torch.device("cpu")
+
     def forward(self, inp):
         inp = inp.to(self.input_device)
         out = self.model(inp)
@@ -113,11 +118,9 @@ def get_param_rrefs(self):
         param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
         return param_rrefs
 
-
 param_server = None
 global_lock = Lock()
 
-
 def get_parameter_server(num_gpus=0):
     global param_server
     # Ensure that we get only one handle to the ParameterServer.
@@ -197,8 +200,11 @@ def get_accuracy(test_loader, model):
     model.eval()
     correct_sum = 0
     # Use GPU to evaluate if possible
-    device = torch.device("cuda:0" if model.num_gpus > 0
-                          and torch.cuda.is_available() else "cpu")
+    if torch.accelerator.is_available() and model.num_gpus > 0:
+        acc = torch.accelerator.current_accelerator()
+        device = torch.device(f'{acc}:0')
+    else:
+        device = torch.device("cpu")
     with torch.no_grad():
         for i, (data, target) in enumerate(test_loader):
             out = model(data)
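The same accelerator-or-CPU selection now appears three times in this file (`Net.__init__`, `ParameterServer.__init__`, and `get_accuracy`). A small sketch of a helper that captures the pattern; the function name `pick_device` is ours, not part of the example:

```python
import torch


def pick_device(num_gpus: int, index: int = 0) -> torch.device:
    """Return accelerator device `index` when one is available and requested, else CPU."""
    if torch.accelerator.is_available() and num_gpus > 0:
        acc = torch.accelerator.current_accelerator()  # e.g. device(type='cuda')
        return torch.device(acc.type, index)
    return torch.device("cpu")

# Mirrors the three call sites above:
# first_device = pick_device(num_gpus, 0)    # conv1 / conv2
# second_device = pick_device(num_gpus, 1)   # rest of the network
# eval_device = pick_device(model.num_gpus)  # get_accuracy
```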

distributed/rpc/rnn/requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -1 +1,2 @@
-torch
+torch>=2.7.1
+numpy

distributed/rpc/rnn/rnn.py

Lines changed: 6 additions & 4 deletions
@@ -43,13 +43,15 @@ def __init__(self, ntoken, ninp, dropout):
         super(EmbeddingTable, self).__init__()
         self.drop = nn.Dropout(dropout)
         self.encoder = nn.Embedding(ntoken, ninp)
-        if torch.cuda.is_available():
-            self.encoder = self.encoder.cuda()
+        if torch.accelerator.is_available():
+            device = torch.accelerator.current_accelerator()
+            self.encoder = self.encoder.to(device)
         nn.init.uniform_(self.encoder.weight, -0.1, 0.1)
 
     def forward(self, input):
-        if torch.cuda.is_available():
-            input = input.cuda()
+        if torch.accelerator.is_available():
+            device = torch.accelerator.current_accelerator()
+            input = input.to(device)
         return self.drop(self.encoder(input)).cpu()
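A quick local check of the behaviour this change preserves (vocabulary size, embedding dimension, and tensor shapes are illustrative): the embedding weights live on the accelerator when one is available, but the lookup is always handed back as a CPU tensor, which is what the `.cpu()` at the end of `forward` is for when the result crosses the RPC boundary.

```python
import torch
import torch.nn as nn

drop, encoder = nn.Dropout(0.5), nn.Embedding(1000, 64)
tokens = torch.randint(0, 1000, (35, 8))  # (seq_len, batch)
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
    encoder = encoder.to(device)
    tokens = tokens.to(device)
out = drop(encoder(tokens)).cpu()  # mirrors EmbeddingTable.forward
print(encoder.weight.device, out.device)  # e.g. "cuda:0 cpu"; on CPU-only machines "cpu cpu"
```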

run_distributed_examples.sh

Lines changed: 10 additions & 0 deletions
@@ -58,10 +58,20 @@ function distributed_minGPT-ddp() {
     uv run bash run_example.sh mingpt/main.py || error "minGPT example failed"
 }
 
+function distributed_rpc_ddp_rpc() {
+    uv run main.py || error "ddp_rpc example failed"
+}
+
+function distributed_rpc_rnn() {
+    uv run main.py || error "rpc_rnn example failed"
+}
+
 function run_all() {
     run distributed/tensor_parallelism
     run distributed/ddp
     run distributed/minGPT-ddp
+    run distributed/rpc/ddp_rpc
+    run distributed/rpc/rnn
 }
 
 # by default, run all examples

utils.sh

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ function run() {
     if start $EXAMPLE; then
         # drop trailing slash (occurs due to auto completion in bash interactive mode)
         # replace slashes with underscores: this allows to call nested examples
-        EXAMPLE_FN=$(echo $EXAMPLE | sed "s@/\$@@" | sed 's@/@_@')
+        EXAMPLE_FN=$(echo $EXAMPLE | sed "s@/\$@@" | sed 's@/@_@g')
         $EXAMPLE_FN
     fi
     stop $EXAMPLE
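The added `g` flag makes the slash-to-underscore substitution global, which is what lets the new `run distributed/rpc/ddp_rpc` and `run distributed/rpc/rnn` entries resolve to the `distributed_rpc_ddp_rpc` and `distributed_rpc_rnn` functions added above. A Python restatement of the old and new sed behaviour, for illustration only:

```python
import re

path = "distributed/rpc/ddp_rpc"
print(re.sub("/", "_", path, count=1))  # "distributed_rpc/ddp_rpc"  -- old: sed 's@/@_@'
print(re.sub("/", "_", path))           # "distributed_rpc_ddp_rpc" -- new: sed 's@/@_@g'
```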
