Skip to content

Commit 1f85172

Browse files
authored
Remote storage pipeline (#899)
* PYTHON: updated remote storage example with pipeline and better README Signed-off-by: Timothy Stamler <[email protected]>
1 parent 5ab9950 commit 1f85172

File tree

6 files changed

+261
-61
lines changed

6 files changed

+261
-61
lines changed

examples/python/remote_storage_example/README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,21 @@ The system automatically selects the best available storage backend:
108108
1. Initiator sends memory descriptors to target
109109
2. Target performs storage-to-memory or memory-to-storage operations
110110
3. Data is transferred between initiator and target memory
111+
112+
Remote reads are implemented as a read from storage followed by a network write.
113+
114+
Remote writes are implemented as a read from network followed by a storage write.
115+
116+
### Pipelining
117+
118+
To improve performance of the remote storage server, we can pipeline operations to network and storage. This pipelining allows multiple threads to handle each request. However, to maintain correctness, the network and storage steps must still happen in order for each individual remote storage operation. To do this, we implemented a simple pipelining scheme. The pipeline for remote writes is implemented as a read into NIXL descriptors from the network, followed by a write to storage (also through NIXL, but via a different plugin). A remote read is similar, just reading into NIXL descriptors from storage and then writing to the network.
119+
120+
![Remote Operation Pipelines](storage_pipelines.png)
121+
122+
### Performance Tips
123+
124+
For high-speed storage and network hardware, you may need to tweak performance with a couple of environment variables.
125+
126+
First, for optimal GDS performance, ensure you are using the GDS_MT backend with default concurrency. Additionally, you can use the cufile options described in the [GDS README](https://github.com/ai-dynamo/nixl/blob/main/src/plugins/cuda_gds/README.md). Also, remember to check that your GDS setup is running true GPU-direct IO and not in compatibility mode.
127+
128+
On the network side, remote reads from VRAM to DRAM can be limited by UCX rail selection. This can be tweaked by setting UCX_MAX_RMA_RAILS=1. However, with larger batch or message sizes, this might limit bandwidth and a higher number of rails might be needed.
16.8 KB
Loading

examples/python/remote_storage_example/nixl_p2p_storage_example.py

Lines changed: 218 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Demonstrates peer-to-peer storage transfers using NIXL with initiator and target modes.
1919
"""
2020

21+
import concurrent.futures
2122
import time
2223

2324
import nixl_storage_utils as nsu
@@ -27,14 +28,20 @@
2728
logger = get_logger(__name__)
2829

2930

30-
def execute_transfer(my_agent, local_descs, remote_descs, remote_name, operation):
31-
handle = my_agent.initialize_xfer(operation, local_descs, remote_descs, remote_name)
31+
def execute_transfer(
32+
my_agent, local_descs, remote_descs, remote_name, operation, use_backends=[]
33+
):
34+
handle = my_agent.initialize_xfer(
35+
operation, local_descs, remote_descs, remote_name, backends=use_backends
36+
)
3237
my_agent.transfer(handle)
3338
nsu.wait_for_transfer(my_agent, handle)
3439
my_agent.release_xfer_handle(handle)
3540

3641

37-
def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name):
42+
def remote_storage_transfer(
43+
my_agent, my_mem_descs, operation, remote_agent_name, iterations
44+
):
3845
"""Initiate remote memory transfer."""
3946
if operation != "READ" and operation != "WRITE":
4047
logger.error("Invalid operation, exiting")
@@ -45,14 +52,24 @@ def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name
4552
else:
4653
operation = b"READ"
4754

55+
iterations_str = bytes(f"{iterations:04d}", "utf-8")
4856
# Send the descriptors that you want to read into or write from
49-
logger.info(f"Sending {operation} request to {remote_agent_name}")
57+
logger.info(
58+
"Sending %s request to %s", operation.decode("utf-8"), remote_agent_name
59+
)
5060
test_descs_str = my_agent.get_serialized_descs(my_mem_descs)
51-
my_agent.send_notif(remote_agent_name, operation + test_descs_str)
61+
62+
start_time = time.time()
63+
64+
my_agent.send_notif(remote_agent_name, operation + iterations_str + test_descs_str)
5265

5366
while not my_agent.check_remote_xfer_done(remote_agent_name, b"COMPLETE"):
5467
continue
5568

69+
elapsed = time.time() - start_time
70+
71+
logger.info("Time for %d iterations: %f seconds", iterations, elapsed)
72+
5673

5774
def connect_to_agents(my_agent, agents_file):
5875
target_agents = []
@@ -66,26 +83,154 @@ def connect_to_agents(my_agent, agents_file):
6683
my_agent.fetch_remote_metadata(parts[0], parts[1], int(parts[2]))
6784

6885
while my_agent.check_remote_metadata(parts[0]) is False:
69-
logger.info(f"Waiting for remote metadata for {parts[0]}...")
86+
logger.info("Waiting for remote metadata for %s...", parts[0])
7087
time.sleep(0.2)
7188

72-
logger.info(f"Remote metadata for {parts[0]} fetched")
89+
logger.info("Remote metadata for %s fetched", parts[0])
7390
else:
74-
logger.error(f"Invalid line in {agents_file}: {line}")
91+
logger.error("Invalid line in %s: %s", agents_file, line)
7592
exit(-1)
7693

7794
logger.info("All remote metadata fetched")
7895

7996
return target_agents
8097

8198

99+
def pipeline_reads(
100+
my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
101+
):
102+
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
103+
n = 0
104+
s = 0
105+
futures = []
106+
107+
while n < iterations or s < iterations:
108+
if s == 0:
109+
futures.append(
110+
executor.submit(
111+
execute_transfer,
112+
my_agent,
113+
my_mem_descs,
114+
my_file_descs,
115+
my_agent.name,
116+
"READ",
117+
)
118+
)
119+
s += 1
120+
continue
121+
122+
if s == iterations:
123+
futures.append(
124+
executor.submit(
125+
execute_transfer,
126+
my_agent,
127+
my_mem_descs,
128+
sent_descs,
129+
req_agent,
130+
"WRITE",
131+
)
132+
)
133+
n += 1
134+
continue
135+
136+
# Do two storage and network in parallel
137+
futures.append(
138+
executor.submit(
139+
execute_transfer,
140+
my_agent,
141+
my_mem_descs,
142+
my_file_descs,
143+
my_agent.name,
144+
"READ",
145+
)
146+
)
147+
futures.append(
148+
executor.submit(
149+
execute_transfer,
150+
my_agent,
151+
my_mem_descs,
152+
sent_descs,
153+
req_agent,
154+
"WRITE",
155+
)
156+
)
157+
s += 1
158+
n += 1
159+
160+
_, not_done = concurrent.futures.wait(
161+
futures, return_when=concurrent.futures.ALL_COMPLETED
162+
)
163+
assert not not_done
164+
165+
166+
def pipeline_writes(
167+
my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
168+
):
169+
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
170+
n = 0
171+
s = 1
172+
futures = []
173+
174+
futures.append(
175+
executor.submit(
176+
execute_transfer,
177+
my_agent,
178+
my_mem_descs,
179+
sent_descs,
180+
req_agent,
181+
"READ",
182+
)
183+
)
184+
while n < iterations or s < iterations:
185+
if s == iterations:
186+
futures.append(
187+
executor.submit(
188+
execute_transfer,
189+
my_agent,
190+
my_mem_descs,
191+
my_file_descs,
192+
my_agent.name,
193+
"WRITE",
194+
)
195+
)
196+
n += 1
197+
continue
198+
199+
# Do two storage and network in parallel
200+
futures.append(
201+
executor.submit(
202+
execute_transfer,
203+
my_agent,
204+
my_mem_descs,
205+
sent_descs,
206+
req_agent,
207+
"READ",
208+
)
209+
)
210+
futures.append(
211+
executor.submit(
212+
execute_transfer,
213+
my_agent,
214+
my_mem_descs,
215+
my_file_descs,
216+
my_agent.name,
217+
"WRITE",
218+
)
219+
)
220+
s += 1
221+
n += 1
222+
223+
_, not_done = concurrent.futures.wait(
224+
futures, return_when=concurrent.futures.ALL_COMPLETED
225+
)
226+
assert not not_done
227+
228+
82229
def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
83230
"""Handle remote memory and storage transfers as target."""
84231
# Wait for initiator to send list of memory descriptors
85232
notifs = my_agent.get_new_notifs()
86233

87-
logger.info("Waiting for a remote transfer request...")
88-
89234
while len(notifs) == 0:
90235
notifs = my_agent.get_new_notifs()
91236

@@ -101,57 +246,65 @@ def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
101246
logger.error("Invalid operation, exiting")
102247
exit(-1)
103248

104-
sent_descs = my_agent.deserialize_descs(recv_msg[4:])
249+
iterations = int(recv_msg[4:8])
105250

106-
logger.info("Checking to ensure metadata is loaded...")
107-
while my_agent.check_remote_metadata(req_agent, sent_descs) is False:
108-
continue
251+
logger.info("Performing %s with %d iterations", operation, iterations)
109252

110-
if operation == "READ":
111-
logger.info("Starting READ operation")
253+
sent_descs = my_agent.deserialize_descs(recv_msg[8:])
112254

113-
# Read from file first
114-
execute_transfer(
115-
my_agent, my_mem_descs, my_file_descs, my_agent.name, "READ"
255+
if operation == "READ":
256+
pipeline_reads(
257+
my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
116258
)
117-
# Send to client
118-
execute_transfer(my_agent, my_mem_descs, sent_descs, req_agent, "WRITE")
119-
120259
elif operation == "WRITE":
121-
logger.info("Starting WRITE operation")
122-
123-
# Read from client first
124-
execute_transfer(my_agent, my_mem_descs, sent_descs, req_agent, "READ")
125-
# Write to storage
126-
execute_transfer(
127-
my_agent, my_mem_descs, my_file_descs, my_agent.name, "WRITE"
260+
pipeline_writes(
261+
my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
128262
)
129263

130264
# Send completion notification to initiator
131265
my_agent.send_notif(req_agent, b"COMPLETE")
132266

133-
logger.info("One transfer test complete.")
134267

135-
136-
def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file):
268+
def run_client(
269+
my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file, iterations
270+
):
137271
logger.info("Client initialized, ready for local transfer test...")
138272

139273
# For sample purposes, write to and then read from local storage
140274
logger.info("Starting local transfer test...")
141-
execute_transfer(
142-
my_agent,
143-
nixl_mem_reg_descs.trim(),
144-
nixl_file_reg_descs.trim(),
145-
my_agent.name,
146-
"WRITE",
147-
)
148-
execute_transfer(
149-
my_agent,
150-
nixl_mem_reg_descs.trim(),
151-
nixl_file_reg_descs.trim(),
152-
my_agent.name,
153-
"READ",
154-
)
275+
276+
start_time = time.time()
277+
278+
for i in range(1, iterations):
279+
execute_transfer(
280+
my_agent,
281+
nixl_mem_reg_descs.trim(),
282+
nixl_file_reg_descs.trim(),
283+
my_agent.name,
284+
"WRITE",
285+
["GDS_MT"],
286+
)
287+
288+
elapsed = time.time() - start_time
289+
290+
logger.info("Time for %d WRITE iterations: %f seconds", iterations, elapsed)
291+
292+
start_time = time.time()
293+
294+
for i in range(1, iterations):
295+
execute_transfer(
296+
my_agent,
297+
nixl_mem_reg_descs.trim(),
298+
nixl_file_reg_descs.trim(),
299+
my_agent.name,
300+
"READ",
301+
["GDS_MT"],
302+
)
303+
304+
elapsed = time.time() - start_time
305+
306+
logger.info("Time for %d READ iterations: %f seconds", iterations, elapsed)
307+
155308
logger.info("Local transfer test complete")
156309

157310
logger.info("Starting remote transfer test...")
@@ -161,10 +314,10 @@ def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file):
161314
# For sample purposes, write to and then read from each target agent
162315
for target_agent in target_agents:
163316
remote_storage_transfer(
164-
my_agent, nixl_mem_reg_descs.trim(), "WRITE", target_agent
317+
my_agent, nixl_mem_reg_descs.trim(), "WRITE", target_agent, iterations
165318
)
166319
remote_storage_transfer(
167-
my_agent, nixl_mem_reg_descs.trim(), "READ", target_agent
320+
my_agent, nixl_mem_reg_descs.trim(), "READ", target_agent, iterations
168321
)
169322

170323
logger.info("Remote transfer test complete")
@@ -199,8 +352,19 @@ def run_storage_server(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs):
199352
type=str,
200353
help="File containing list of target agents (only needed for client)",
201354
)
355+
parser.add_argument(
356+
"--iterations",
357+
type=int,
358+
default=100,
359+
help="Number of iterations for each transfer",
360+
)
202361
args = parser.parse_args()
203362

363+
mem = "DRAM"
364+
365+
if args.role == "client":
366+
mem = "VRAM"
367+
204368
my_agent = nsu.create_agent_with_plugins(args.name, args.port)
205369

206370
(
@@ -209,15 +373,19 @@ def run_storage_server(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs):
209373
nixl_mem_reg_descs,
210374
nixl_file_reg_descs,
211375
) = nsu.setup_memory_and_files(
212-
my_agent, args.batch_size, args.buf_size, args.fileprefix
376+
my_agent, args.batch_size, args.buf_size, args.fileprefix, mem
213377
)
214378

215379
if args.role == "client":
216380
if not args.agents_file:
217381
parser.error("--agents_file is required when role is client")
218382
try:
219383
run_client(
220-
my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, args.agents_file
384+
my_agent,
385+
nixl_mem_reg_descs,
386+
nixl_file_reg_descs,
387+
args.agents_file,
388+
args.iterations,
221389
)
222390
finally:
223391
nsu.cleanup_resources(

0 commit comments

Comments
 (0)