1818Demonstrates peer-to-peer storage transfers using NIXL with initiator and target modes.
1919"""
2020
21+ import concurrent .futures
2122import time
2223
2324import nixl_storage_utils as nsu
2728logger = get_logger (__name__ )
2829
2930
30- def execute_transfer (my_agent , local_descs , remote_descs , remote_name , operation ):
31- handle = my_agent .initialize_xfer (operation , local_descs , remote_descs , remote_name )
31+ def execute_transfer (
32+ my_agent , local_descs , remote_descs , remote_name , operation , use_backends = []
33+ ):
34+ handle = my_agent .initialize_xfer (
35+ operation , local_descs , remote_descs , remote_name , backends = use_backends
36+ )
3237 my_agent .transfer (handle )
3338 nsu .wait_for_transfer (my_agent , handle )
3439 my_agent .release_xfer_handle (handle )
3540
3641
37- def remote_storage_transfer (my_agent , my_mem_descs , operation , remote_agent_name ):
42+ def remote_storage_transfer (
43+ my_agent , my_mem_descs , operation , remote_agent_name , iterations
44+ ):
3845 """Initiate remote memory transfer."""
3946 if operation != "READ" and operation != "WRITE" :
4047 logger .error ("Invalid operation, exiting" )
@@ -45,14 +52,24 @@ def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name
4552 else :
4653 operation = b"READ"
4754
55+ iterations_str = bytes (f"{ iterations :04d} " , "utf-8" )
4856 # Send the descriptors that you want to read into or write from
49- logger .info (f"Sending { operation } request to { remote_agent_name } " )
57+ logger .info (
58+ "Sending %s request to %s" , operation .decode ("utf-8" ), remote_agent_name
59+ )
5060 test_descs_str = my_agent .get_serialized_descs (my_mem_descs )
51- my_agent .send_notif (remote_agent_name , operation + test_descs_str )
61+
62+ start_time = time .time ()
63+
64+ my_agent .send_notif (remote_agent_name , operation + iterations_str + test_descs_str )
5265
5366 while not my_agent .check_remote_xfer_done (remote_agent_name , b"COMPLETE" ):
5467 continue
5568
69+ elapsed = time .time () - start_time
70+
71+ logger .info ("Time for %d iterations: %f seconds" , iterations , elapsed )
72+
5673
5774def connect_to_agents (my_agent , agents_file ):
5875 target_agents = []
@@ -66,26 +83,154 @@ def connect_to_agents(my_agent, agents_file):
6683 my_agent .fetch_remote_metadata (parts [0 ], parts [1 ], int (parts [2 ]))
6784
6885 while my_agent .check_remote_metadata (parts [0 ]) is False :
69- logger .info (f "Waiting for remote metadata for { parts [ 0 ] } ..." )
86+ logger .info ("Waiting for remote metadata for %s ..." , parts [ 0 ] )
7087 time .sleep (0.2 )
7188
72- logger .info (f "Remote metadata for { parts [0 ]} fetched" )
89+ logger .info ("Remote metadata for %s fetched" , parts [0 ])
7390 else :
74- logger .error (f "Invalid line in { agents_file } : { line } " )
91+ logger .error ("Invalid line in %s: %s" , agents_file , line )
7592 exit (- 1 )
7693
7794 logger .info ("All remote metadata fetched" )
7895
7996 return target_agents
8097
8198
99+ def pipeline_reads (
100+ my_agent , req_agent , my_mem_descs , my_file_descs , sent_descs , iterations
101+ ):
102+ with concurrent .futures .ThreadPoolExecutor (max_workers = 2 ) as executor :
103+ n = 0
104+ s = 0
105+ futures = []
106+
107+ while n < iterations or s < iterations :
108+ if s == 0 :
109+ futures .append (
110+ executor .submit (
111+ execute_transfer ,
112+ my_agent ,
113+ my_mem_descs ,
114+ my_file_descs ,
115+ my_agent .name ,
116+ "READ" ,
117+ )
118+ )
119+ s += 1
120+ continue
121+
122+ if s == iterations :
123+ futures .append (
124+ executor .submit (
125+ execute_transfer ,
126+ my_agent ,
127+ my_mem_descs ,
128+ sent_descs ,
129+ req_agent ,
130+ "WRITE" ,
131+ )
132+ )
133+ n += 1
134+ continue
135+
136+ # Do two storage and network in parallel
137+ futures .append (
138+ executor .submit (
139+ execute_transfer ,
140+ my_agent ,
141+ my_mem_descs ,
142+ my_file_descs ,
143+ my_agent .name ,
144+ "READ" ,
145+ )
146+ )
147+ futures .append (
148+ executor .submit (
149+ execute_transfer ,
150+ my_agent ,
151+ my_mem_descs ,
152+ sent_descs ,
153+ req_agent ,
154+ "WRITE" ,
155+ )
156+ )
157+ s += 1
158+ n += 1
159+
160+ _ , not_done = concurrent .futures .wait (
161+ futures , return_when = concurrent .futures .ALL_COMPLETED
162+ )
163+ assert not not_done
164+
165+
166+ def pipeline_writes (
167+ my_agent , req_agent , my_mem_descs , my_file_descs , sent_descs , iterations
168+ ):
169+ with concurrent .futures .ThreadPoolExecutor (max_workers = 2 ) as executor :
170+ n = 0
171+ s = 1
172+ futures = []
173+
174+ futures .append (
175+ executor .submit (
176+ execute_transfer ,
177+ my_agent ,
178+ my_mem_descs ,
179+ sent_descs ,
180+ req_agent ,
181+ "READ" ,
182+ )
183+ )
184+ while n < iterations or s < iterations :
185+ if s == iterations :
186+ futures .append (
187+ executor .submit (
188+ execute_transfer ,
189+ my_agent ,
190+ my_mem_descs ,
191+ my_file_descs ,
192+ my_agent .name ,
193+ "WRITE" ,
194+ )
195+ )
196+ n += 1
197+ continue
198+
199+ # Do two storage and network in parallel
200+ futures .append (
201+ executor .submit (
202+ execute_transfer ,
203+ my_agent ,
204+ my_mem_descs ,
205+ sent_descs ,
206+ req_agent ,
207+ "READ" ,
208+ )
209+ )
210+ futures .append (
211+ executor .submit (
212+ execute_transfer ,
213+ my_agent ,
214+ my_mem_descs ,
215+ my_file_descs ,
216+ my_agent .name ,
217+ "WRITE" ,
218+ )
219+ )
220+ s += 1
221+ n += 1
222+
223+ _ , not_done = concurrent .futures .wait (
224+ futures , return_when = concurrent .futures .ALL_COMPLETED
225+ )
226+ assert not not_done
227+
228+
82229def handle_remote_transfer_request (my_agent , my_mem_descs , my_file_descs ):
83230 """Handle remote memory and storage transfers as target."""
84231 # Wait for initiator to send list of memory descriptors
85232 notifs = my_agent .get_new_notifs ()
86233
87- logger .info ("Waiting for a remote transfer request..." )
88-
89234 while len (notifs ) == 0 :
90235 notifs = my_agent .get_new_notifs ()
91236
@@ -101,57 +246,65 @@ def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
101246 logger .error ("Invalid operation, exiting" )
102247 exit (- 1 )
103248
104- sent_descs = my_agent . deserialize_descs (recv_msg [4 :])
249+ iterations = int (recv_msg [4 :8 ])
105250
106- logger .info ("Checking to ensure metadata is loaded..." )
107- while my_agent .check_remote_metadata (req_agent , sent_descs ) is False :
108- continue
251+ logger .info ("Performing %s with %d iterations" , operation , iterations )
109252
110- if operation == "READ" :
111- logger .info ("Starting READ operation" )
253+ sent_descs = my_agent .deserialize_descs (recv_msg [8 :])
112254
113- # Read from file first
114- execute_transfer (
115- my_agent , my_mem_descs , my_file_descs , my_agent . name , "READ"
255+ if operation == "READ" :
256+ pipeline_reads (
257+ my_agent , req_agent , my_mem_descs , my_file_descs , sent_descs , iterations
116258 )
117- # Send to client
118- execute_transfer (my_agent , my_mem_descs , sent_descs , req_agent , "WRITE" )
119-
120259 elif operation == "WRITE" :
121- logger .info ("Starting WRITE operation" )
122-
123- # Read from client first
124- execute_transfer (my_agent , my_mem_descs , sent_descs , req_agent , "READ" )
125- # Write to storage
126- execute_transfer (
127- my_agent , my_mem_descs , my_file_descs , my_agent .name , "WRITE"
260+ pipeline_writes (
261+ my_agent , req_agent , my_mem_descs , my_file_descs , sent_descs , iterations
128262 )
129263
130264 # Send completion notification to initiator
131265 my_agent .send_notif (req_agent , b"COMPLETE" )
132266
133- logger .info ("One transfer test complete." )
134267
135-
136- def run_client (my_agent , nixl_mem_reg_descs , nixl_file_reg_descs , agents_file ):
268+ def run_client (
269+ my_agent , nixl_mem_reg_descs , nixl_file_reg_descs , agents_file , iterations
270+ ):
137271 logger .info ("Client initialized, ready for local transfer test..." )
138272
139273 # For sample purposes, write to and then read from local storage
140274 logger .info ("Starting local transfer test..." )
141- execute_transfer (
142- my_agent ,
143- nixl_mem_reg_descs .trim (),
144- nixl_file_reg_descs .trim (),
145- my_agent .name ,
146- "WRITE" ,
147- )
148- execute_transfer (
149- my_agent ,
150- nixl_mem_reg_descs .trim (),
151- nixl_file_reg_descs .trim (),
152- my_agent .name ,
153- "READ" ,
154- )
275+
276+ start_time = time .time ()
277+
278+ for i in range (1 , iterations ):
279+ execute_transfer (
280+ my_agent ,
281+ nixl_mem_reg_descs .trim (),
282+ nixl_file_reg_descs .trim (),
283+ my_agent .name ,
284+ "WRITE" ,
285+ ["GDS_MT" ],
286+ )
287+
288+ elapsed = time .time () - start_time
289+
290+ logger .info ("Time for %d WRITE iterations: %f seconds" , iterations , elapsed )
291+
292+ start_time = time .time ()
293+
294+ for i in range (1 , iterations ):
295+ execute_transfer (
296+ my_agent ,
297+ nixl_mem_reg_descs .trim (),
298+ nixl_file_reg_descs .trim (),
299+ my_agent .name ,
300+ "READ" ,
301+ ["GDS_MT" ],
302+ )
303+
304+ elapsed = time .time () - start_time
305+
306+ logger .info ("Time for %d READ iterations: %f seconds" , iterations , elapsed )
307+
155308 logger .info ("Local transfer test complete" )
156309
157310 logger .info ("Starting remote transfer test..." )
@@ -161,10 +314,10 @@ def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file):
161314 # For sample purposes, write to and then read from each target agent
162315 for target_agent in target_agents :
163316 remote_storage_transfer (
164- my_agent , nixl_mem_reg_descs .trim (), "WRITE" , target_agent
317+ my_agent , nixl_mem_reg_descs .trim (), "WRITE" , target_agent , iterations
165318 )
166319 remote_storage_transfer (
167- my_agent , nixl_mem_reg_descs .trim (), "READ" , target_agent
320+ my_agent , nixl_mem_reg_descs .trim (), "READ" , target_agent , iterations
168321 )
169322
170323 logger .info ("Remote transfer test complete" )
@@ -199,8 +352,19 @@ def run_storage_server(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs):
199352 type = str ,
200353 help = "File containing list of target agents (only needed for client)" ,
201354 )
355+ parser .add_argument (
356+ "--iterations" ,
357+ type = int ,
358+ default = 100 ,
359+ help = "Number of iterations for each transfer" ,
360+ )
202361 args = parser .parse_args ()
203362
363+ mem = "DRAM"
364+
365+ if args .role == "client" :
366+ mem = "VRAM"
367+
204368 my_agent = nsu .create_agent_with_plugins (args .name , args .port )
205369
206370 (
@@ -209,15 +373,19 @@ def run_storage_server(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs):
209373 nixl_mem_reg_descs ,
210374 nixl_file_reg_descs ,
211375 ) = nsu .setup_memory_and_files (
212- my_agent , args .batch_size , args .buf_size , args .fileprefix
376+ my_agent , args .batch_size , args .buf_size , args .fileprefix , mem
213377 )
214378
215379 if args .role == "client" :
216380 if not args .agents_file :
217381 parser .error ("--agents_file is required when role is client" )
218382 try :
219383 run_client (
220- my_agent , nixl_mem_reg_descs , nixl_file_reg_descs , args .agents_file
384+ my_agent ,
385+ nixl_mem_reg_descs ,
386+ nixl_file_reg_descs ,
387+ args .agents_file ,
388+ args .iterations ,
221389 )
222390 finally :
223391 nsu .cleanup_resources (
0 commit comments