5656 default = 10 ,
5757 help = "Number of warmup iterations (default: 5)" ,
5858 )
59+ parser .add_argument (
60+ "--concurrency" ,
61+ type = int ,
62+ default = 1 ,
63+ help = "Number of concurrent RDMA operations (default: 1)" ,
64+ )
5965
6066 args = parser .parse_args ()
6167
@@ -85,13 +91,38 @@ def __init__(
8591 # Timing data storage
8692 self .timing_data = []
8793 self .size_data = []
94+ self .batch_timing_data = []
95+ self .batch_size_data = []
8896
@endpoint
async def set_other_actor(self, other_actor):
    """Remember the other actor by storing it on ``self.other_actor``."""
    self.other_actor = other_actor
92100
@endpoint
async def send(self, is_warmup=False, concurrency: int = 1) -> None:
    """Run one batch of concurrent ``_send_single`` operations and time it.

    ``concurrency`` ``_send_single(is_warmup)`` coroutines are awaited
    together with ``asyncio.gather``. For non-warmup batches the elapsed
    time of the whole batch and the number of bytes recorded during it
    are appended to ``self.batch_timing_data`` and
    ``self.batch_size_data``. ``self.i`` is advanced once per batch,
    whether or not it is a warmup.

    Args:
        is_warmup: When true, the batch is executed but no batch-level
            timing or size entry is recorded.
        concurrency: Number of ``_send_single`` coroutines to run
            concurrently in this batch.
    """
    # Remember where this batch's entries will start in ``size_data`` so
    # the batch size is taken by position. Slicing with ``[-concurrency:]``
    # would select the *entire* list when ``concurrency`` is 0 (``-0 == 0``)
    # and forces a silent 0 whenever fewer entries than expected exist.
    # NOTE(review): assumes sizes for this batch are appended to
    # ``self.size_data`` while the batch is in flight -- confirm against
    # ``_send_single``.
    first_new_entry = len(self.size_data)

    # ``perf_counter`` is monotonic, so the measured interval cannot be
    # distorted by system clock adjustments during the batch.
    batch_start = time.perf_counter()
    await asyncio.gather(
        *(self._send_single(is_warmup) for _ in range(concurrency))
    )
    batch_elapsed = time.perf_counter() - batch_start

    if not is_warmup:
        self.batch_timing_data.append(batch_elapsed)
        self.batch_size_data.append(sum(self.size_data[first_new_entry:]))

    self.i += 1
124+
125+ async def _send_single (self , is_warmup = False ) -> None :
95126 shape = int (
96127 1024 * 1024 * self .size_mb / 4 * (0.5 * random .randint (1 , 3 ))
97128 ) # Random size with +/- 50% variation based on user size
@@ -104,7 +135,7 @@ async def send(self, is_warmup=False) -> None:
104135 # Critical validation - this should catch the null pointer issue
105136 assert (
106137 tensor_addr != 0
107- ), f"CRITICAL: Tensor has null pointer! Device: { device } , Shape: { shape } "
138+ ), f"CRITICAL: Tensor has null pointer! Device: { self . device } , Shape: { shape } "
108139 assert size_elem > 0 , f"CRITICAL: Tensor has zero size! Size: { size_elem } "
109140
110141 byte_view = tensor .view (torch .uint8 ).flatten ()
@@ -136,8 +167,6 @@ async def send(self, is_warmup=False) -> None:
136167 # cleanup
137168 await buffer .drop ()
138169
139- self .i += 1
140-
141170 @endpoint
142171 async def recv (self , rdma_buffer , shape , dtype , is_warmup ):
143172 # Create receiving tensor on the same device
@@ -167,15 +196,15 @@ async def recv(self, rdma_buffer, shape, dtype, is_warmup):
167196
168197 @endpoint
169198 async def print_statistics (self , calc_bwd : bool = False ):
170- """Calculate and print timing statistics"""
199+ """Calculate and print timing statistics for individual operations """
171200 if not self .timing_data :
172201 print ("No timing data collected!" )
173202 return
174203
175204 timings = self .timing_data
176205 sizes = self .size_data
177206
178- # Calculate statistics
207+ # Calculate statistics for individual operations
179208 avg_time = statistics .mean (timings )
180209 min_time = min (timings )
181210 max_time = max (timings )
@@ -184,7 +213,12 @@ async def print_statistics(self, calc_bwd: bool = False):
184213 avg_size = statistics .mean (sizes )
185214 total_data = sum (sizes )
186215
187- print ("TIMING RESULTS:" )
216+ device_type = self .device .upper () if self .device != "cpu" else "CPU"
217+ print ("\n " + "=" * 60 )
218+ print (f"RDMA { self .operation .upper ()} LOAD TEST RESULTS ({ device_type } )" )
219+ print ("=" * 60 )
220+
221+ print ("INDIVIDUAL OPERATION TIMING:" )
188222 print (f" Average time per operation: { avg_time * 1000 :.3f} ms" )
189223 print (f" Minimum time per operation: { min_time * 1000 :.3f} ms" )
190224 print (f" Maximum time per operation: { max_time * 1000 :.3f} ms" )
@@ -202,31 +236,78 @@ def calc_bandwidth_gbps(size_bytes: int, time_seconds: float) -> float:
202236 max_bandwidth = calc_bandwidth_gbps (avg_size , min_time )
203237 min_bandwidth = calc_bandwidth_gbps (avg_size , max_time )
204238
205- device_type = self .device .upper () if self .device != "cpu" else "CPU"
206-
207239 # Print results
208- print ("\n " + "=" * 60 )
209- print (f"RDMA { self .operation .upper ()} LOAD TEST RESULTS ({ device_type } )" )
210- print ("=" * 60 )
211240 print (f"Total iterations completed: { len (timings )} " )
212241 print (f"Average data per operation: { avg_size / (1024 * 1024 ):.1f} MB" )
213242 print (f"Total data transferred: { total_data / (1024 * 1024 ):.1f} MB" )
214243 print ()
215244
216- print ()
217- print ("BANDWIDTH RESULTS:" )
245+ print ("INDIVIDUAL OPERATION BANDWIDTH:" )
218246 print (f" Average bandwidth: { avg_bandwidth :.2f} Gbps" )
219247 print (f" Maximum bandwidth: { max_bandwidth :.2f} Gbps" )
220248 print (f" Minimum bandwidth: { min_bandwidth :.2f} Gbps" )
221249 print ("=" * 60 )
222250
@endpoint
async def print_batch_statistics(
    self, concurrency: int = 1, total_elapsed_time: float = 0.0
):
    """Print batch-level timing, aggregate bandwidth and sustained throughput.

    Reads ``self.batch_timing_data``, ``self.batch_size_data`` and
    ``self.size_data``. Prints a notice and returns early when no batch
    timings were recorded.

    Args:
        concurrency: Concurrency level shown in the report headers.
        total_elapsed_time: Wall-clock seconds used as the denominator of
            the sustained-throughput figure; a value of 0 yields 0.0 Gbps.
    """
    if not self.batch_timing_data:
        print("No batch timing data collected!")
        return

    def gbps(size_bytes: int, time_seconds: float) -> float:
        # Bits divided by (seconds * 1e9); a zero duration reports 0.0
        # instead of dividing by zero.
        return 0.0 if time_seconds == 0 else (size_bytes * 8) / (time_seconds * 1e9)

    bytes_per_mb = 1024 * 1024
    durations = self.batch_timing_data
    payloads = self.batch_size_data
    grand_total_bytes = sum(self.size_data)

    # All statistics are computed up front, before anything is printed.
    mean_duration = statistics.mean(durations)
    shortest = min(durations)
    longest = max(durations)
    if len(durations) > 1:
        spread = statistics.stdev(durations)
    else:
        spread = 0.0
    mean_payload = statistics.mean(payloads)

    print("\n CONCURRENT BATCH TIMING (wall-clock for all concurrent ops):")
    timing_rows = (
        ("Average batch time", mean_duration),
        ("Minimum batch time", shortest),
        ("Maximum batch time", longest),
        ("Standard deviation", spread),
    )
    for label, seconds in timing_rows:
        print(f" {label}: {seconds * 1000:.3f} ms")
    print(f" Average data per batch: {mean_payload / bytes_per_mb:.1f} MB")

    # The fastest batch gives the "maximum" figure and the slowest the
    # "minimum"; every row uses the average batch size as the numerator.
    bandwidth_rows = (
        ("Average", gbps(mean_payload, mean_duration)),
        ("Maximum", gbps(mean_payload, shortest)),
        ("Minimum", gbps(mean_payload, longest)),
    )
    print(f"\n AGGREGATE BANDWIDTH (concurrency={concurrency}):")
    for label, rate in bandwidth_rows:
        print(f" {label} aggregate bandwidth: {rate:.2f} Gbps")

    sustained = gbps(grand_total_bytes, total_elapsed_time)
    print("\n TOTAL SUSTAINED THROUGHPUT:")
    print(f" Total wall-clock time: {total_elapsed_time:.3f} s")
    print(f" Total data transferred: {grand_total_bytes / bytes_per_mb:.1f} MB")
    print(f" Sustained throughput: {sustained:.2f} Gbps")
    if concurrency > 1:
        print(f" (Accounts for {concurrency} x concurrent overlapping operations)")
302+
223303
224304async def main (
225305 devices : list [str ],
226306 iterations : int = 100 ,
227307 operation : str = "write" ,
228308 size_mb : int = 64 ,
229309 warmup_iterations : int = 10 ,
310+ concurrency : int = 1 ,
230311):
231312 # Adjust GPU allocation based on the device types
232313 device_0 , device_1 = devices [0 ], devices [1 ]
@@ -245,16 +326,20 @@ async def main(
245326 await actor_0 .set_other_actor .call (actor_1 )
246327
247328 for i in range (warmup_iterations ):
248- await actor_0 .send .call (is_warmup = True )
329+ await actor_0 .send .call (is_warmup = True , concurrency = concurrency )
249330
331+ total_start_time = time .time ()
250332 for i in range (iterations ):
251- await actor_0 .send .call ()
333+ await actor_0 .send .call (concurrency = concurrency )
334+ total_end_time = time .time ()
335+ total_elapsed_time = total_end_time - total_start_time
252336
253- # Have both actors print their statistics
254- print ("\n === ACTOR 0 (Create Buffer) STATISTICS ===" )
255- await actor_0 .print_statistics .call ()
337+ # Actor 0: Print batch statistics (concurrency orchestration)
338+ await actor_0 .print_batch_statistics .call (
339+ concurrency = concurrency , total_elapsed_time = total_elapsed_time
340+ )
256341
257- print ( " \n === ACTOR 1 (Create Buffer+Transmit) STATISTICS ===" )
342+ # Actor 1: Print individual RDMA transfer statistics
258343 await actor_1 .print_statistics .call (calc_bwd = True )
259344
260345 await mesh_0 .stop ()
@@ -313,5 +398,6 @@ async def main(
313398 args .operation ,
314399 args .size ,
315400 args .warmup_iterations ,
401+ args .concurrency ,
316402 )
317403 )
0 commit comments