
Commit 93b653a

dstaay-fb authored and meta-codesync[bot] committed
Update load test to support concurrency (#1944)
Summary:
Pull Request resolved: #1944

Update script to support concurrency, with relevant benchmarks:

```
buck run @//mode/dev-nosan //monarch/python/tests:rdma_load_test -- --device cuda:0 cuda:1 --operation write --iterations 5 --size 500 --expandable-segments true --concurrency 4
```

Sample output:

```
============================================================
CONCURRENT BATCH TIMING (wall-clock for all concurrent ops):
  Average batch time: 48.681 ms
  Minimum batch time: 25.463 ms
  Maximum batch time: 230.379 ms
  Standard deviation: 20.382 ms
  Average data per batch: 1982.5 MB

AGGREGATE BANDWIDTH (concurrency=4):
  Average aggregate bandwidth: 341.62 Gbps
  Maximum aggregate bandwidth: 653.13 Gbps
  Minimum aggregate bandwidth: 72.19 Gbps

TOTAL SUSTAINED THROUGHPUT:
  Total wall-clock time: 5.094 s
  Total data transferred: 198250.0 MB
  Sustained throughput: 326.47 Gbps
  (Accounts for 4x concurrent overlapping operations)

============================================================
RDMA WRITE LOAD TEST RESULTS (CUDA:1)
============================================================
INDIVIDUAL OPERATION TIMING:
  Average time per operation: 29.031 ms
  Minimum time per operation: 6.103 ms
  Maximum time per operation: 191.391 ms
  Standard deviation: 19.436 ms
Total iterations completed: 400
Average data per operation: 495.6 MB
Total data transferred: 198250.0 MB

INDIVIDUAL OPERATION BANDWIDTH:
  Average bandwidth: 143.21 Gbps
  Maximum bandwidth: 681.26 Gbps
  Minimum bandwidth: 21.72 Gbps
```

Reviewed By: casteryh

Differential Revision: D87475053

fbshipit-source-id: 6e810b8623f29150025c31d21071afd9853f88eb
1 parent: 12ce58c · commit: 93b653a
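The core of the change (see the diff below) is the new `send` endpoint: it launches `concurrency` copies of `_send_single` under `asyncio.gather` and times the whole batch with wall-clock timestamps. A minimal, self-contained sketch of that pattern, with a hypothetical `do_rdma_op` standing in for the script's `_send_single`:

```python
import asyncio
import time


async def do_rdma_op() -> None:
    # Placeholder for one RDMA transfer (the script's _send_single).
    await asyncio.sleep(0.01)


async def timed_batch(concurrency: int = 4) -> float:
    # Wall-clock time for the whole batch, not per-op latency:
    # the ops overlap, so batch time < concurrency * per-op time
    # whenever the link has headroom.
    batch_start = time.time()
    await asyncio.gather(*(do_rdma_op() for _ in range(concurrency)))
    return time.time() - batch_start


print(f"batch took {asyncio.run(timed_batch()) * 1000:.1f} ms")
```

Because the operations overlap, batch wall-clock time grows sublinearly in `concurrency` while the link has headroom; that headroom is exactly what the aggregate-bandwidth numbers in the summary measure.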

1 file changed: python/tests/rdma_load_test.py (+106 −20 lines)
```diff
@@ -56,6 +56,12 @@
         default=10,
         help="Number of warmup iterations (default: 5)",
     )
+    parser.add_argument(
+        "--concurrency",
+        type=int,
+        default=1,
+        help="Number of concurrent RDMA operations (default: 1)",
+    )
 
     args = parser.parse_args()
 
@@ -85,13 +91,38 @@ def __init__(
         # Timing data storage
         self.timing_data = []
         self.size_data = []
+        self.batch_timing_data = []
+        self.batch_size_data = []
 
     @endpoint
     async def set_other_actor(self, other_actor):
         self.other_actor = other_actor
 
     @endpoint
-    async def send(self, is_warmup=False) -> None:
+    async def send(self, is_warmup=False, concurrency: int = 1) -> None:
+        # Track wall-clock time for the entire concurrent batch
+        batch_start = time.time()
+
+        tasks = []
+        for _ in range(concurrency):
+            tasks.append(self._send_single(is_warmup))
+        await asyncio.gather(*tasks)
+
+        batch_end = time.time()
+        batch_elapsed = batch_end - batch_start
+
+        if not is_warmup:
+            batch_size = (
+                sum(self.size_data[-concurrency:])
+                if len(self.size_data) >= concurrency
+                else 0
+            )
+            self.batch_timing_data.append(batch_elapsed)
+            self.batch_size_data.append(batch_size)
+
+        self.i += 1
+
+    async def _send_single(self, is_warmup=False) -> None:
         shape = int(
             1024 * 1024 * self.size_mb / 4 * (0.5 * random.randint(1, 3))
         )  # Random size with +/- 50% variation based on user size
```
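Worth noting for the numbers in the summary: `_send_single` (end of the hunk above) scales the requested size by a random factor of 0.5, 1.0, or 1.5, counted in float32 elements (4 bytes each). A quick check of that formula for `--size 500`, which explains the roughly 500 MB (495.6 MB) average per operation in the sample output:

```python
# Reproduces the tensor-size formula from _send_single for --size 500.
size_mb = 500
for k in (1, 2, 3):  # random.randint(1, 3) picks the 0.5x / 1.0x / 1.5x factor
    elems = int(1024 * 1024 * size_mb / 4 * (0.5 * k))
    print(f"{elems} float32 elements -> {elems * 4 / (1024 * 1024):.0f} MB")
# 65536000 float32 elements -> 250 MB
# 131072000 float32 elements -> 500 MB
# 196608000 float32 elements -> 750 MB
```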
```diff
@@ -104,7 +135,7 @@ async def send(self, is_warmup=False) -> None:
         # Critical validation - this should catch the null pointer issue
         assert (
             tensor_addr != 0
-        ), f"CRITICAL: Tensor has null pointer! Device: {device}, Shape: {shape}"
+        ), f"CRITICAL: Tensor has null pointer! Device: {self.device}, Shape: {shape}"
         assert size_elem > 0, f"CRITICAL: Tensor has zero size! Size: {size_elem}"
 
         byte_view = tensor.view(torch.uint8).flatten()
@@ -136,8 +167,6 @@ async def send(self, is_warmup=False) -> None:
         # cleanup
         await buffer.drop()
 
-        self.i += 1
-
     @endpoint
     async def recv(self, rdma_buffer, shape, dtype, is_warmup):
         # Create receiving tensor on the same device
@@ -167,15 +196,15 @@ async def recv(self, rdma_buffer, shape, dtype, is_warmup):
 
     @endpoint
     async def print_statistics(self, calc_bwd: bool = False):
-        """Calculate and print timing statistics"""
+        """Calculate and print timing statistics for individual operations"""
         if not self.timing_data:
             print("No timing data collected!")
             return
 
         timings = self.timing_data
         sizes = self.size_data
 
-        # Calculate statistics
+        # Calculate statistics for individual operations
         avg_time = statistics.mean(timings)
         min_time = min(timings)
         max_time = max(timings)
@@ -184,7 +213,12 @@ async def print_statistics(self, calc_bwd: bool = False):
         avg_size = statistics.mean(sizes)
         total_data = sum(sizes)
 
-        print("TIMING RESULTS:")
+        device_type = self.device.upper() if self.device != "cpu" else "CPU"
+        print("\n" + "=" * 60)
+        print(f"RDMA {self.operation.upper()} LOAD TEST RESULTS ({device_type})")
+        print("=" * 60)
+
+        print("INDIVIDUAL OPERATION TIMING:")
         print(f"  Average time per operation: {avg_time * 1000:.3f} ms")
         print(f"  Minimum time per operation: {min_time * 1000:.3f} ms")
         print(f"  Maximum time per operation: {max_time * 1000:.3f} ms")
@@ -202,31 +236,78 @@ def calc_bandwidth_gbps(size_bytes: int, time_seconds: float) -> float:
         max_bandwidth = calc_bandwidth_gbps(avg_size, min_time)
         min_bandwidth = calc_bandwidth_gbps(avg_size, max_time)
 
-        device_type = self.device.upper() if self.device != "cpu" else "CPU"
-
         # Print results
-        print("\n" + "=" * 60)
-        print(f"RDMA {self.operation.upper()} LOAD TEST RESULTS ({device_type})")
-        print("=" * 60)
         print(f"Total iterations completed: {len(timings)}")
         print(f"Average data per operation: {avg_size / (1024*1024):.1f} MB")
         print(f"Total data transferred: {total_data / (1024*1024):.1f} MB")
         print()
 
-        print()
-        print("BANDWIDTH RESULTS:")
+        print("INDIVIDUAL OPERATION BANDWIDTH:")
         print(f"  Average bandwidth: {avg_bandwidth:.2f} Gbps")
         print(f"  Maximum bandwidth: {max_bandwidth:.2f} Gbps")
         print(f"  Minimum bandwidth: {min_bandwidth:.2f} Gbps")
         print("=" * 60)
 
+    @endpoint
+    async def print_batch_statistics(
+        self, concurrency: int = 1, total_elapsed_time: float = 0.0
+    ):
+        """Calculate and print batch-level statistics for concurrent operations"""
+        if not self.batch_timing_data:
+            print("No batch timing data collected!")
+            return
+
+        batch_timings = self.batch_timing_data
+        batch_sizes = self.batch_size_data
+        total_data = sum(self.size_data)
+
+        avg_batch_time = statistics.mean(batch_timings)
+        min_batch_time = min(batch_timings)
+        max_batch_time = max(batch_timings)
+        std_batch_time = (
+            statistics.stdev(batch_timings) if len(batch_timings) > 1 else 0.0
+        )
+        avg_batch_size = statistics.mean(batch_sizes)
+
+        print("\nCONCURRENT BATCH TIMING (wall-clock for all concurrent ops):")
+        print(f"  Average batch time: {avg_batch_time * 1000:.3f} ms")
+        print(f"  Minimum batch time: {min_batch_time * 1000:.3f} ms")
+        print(f"  Maximum batch time: {max_batch_time * 1000:.3f} ms")
+        print(f"  Standard deviation: {std_batch_time * 1000:.3f} ms")
+        print(f"  Average data per batch: {avg_batch_size / (1024*1024):.1f} MB")
+
+        # Calculate bandwidth (Gbps)
+        def calc_bandwidth_gbps(size_bytes: int, time_seconds: float) -> float:
+            if time_seconds == 0:
+                return 0.0
+            bits_transferred = size_bytes * 8
+            return bits_transferred / (time_seconds * 1e9)
+
+        avg_aggregate_bw = calc_bandwidth_gbps(avg_batch_size, avg_batch_time)
+        max_aggregate_bw = calc_bandwidth_gbps(avg_batch_size, min_batch_time)
+        min_aggregate_bw = calc_bandwidth_gbps(avg_batch_size, max_batch_time)
+
+        print(f"\nAGGREGATE BANDWIDTH (concurrency={concurrency}):")
+        print(f"  Average aggregate bandwidth: {avg_aggregate_bw:.2f} Gbps")
+        print(f"  Maximum aggregate bandwidth: {max_aggregate_bw:.2f} Gbps")
+        print(f"  Minimum aggregate bandwidth: {min_aggregate_bw:.2f} Gbps")
+
+        total_throughput = calc_bandwidth_gbps(total_data, total_elapsed_time)
+        print("\nTOTAL SUSTAINED THROUGHPUT:")
+        print(f"  Total wall-clock time: {total_elapsed_time:.3f} s")
+        print(f"  Total data transferred: {total_data / (1024*1024):.1f} MB")
+        print(f"  Sustained throughput: {total_throughput:.2f} Gbps")
+        if concurrency > 1:
+            print(f"  (Accounts for {concurrency}x concurrent overlapping operations)")
+
 
 async def main(
     devices: list[str],
     iterations: int = 100,
     operation: str = "write",
     size_mb: int = 64,
     warmup_iterations: int = 10,
+    concurrency: int = 1,
 ):
     # Adjust GPU allocation based on the device types
     device_0, device_1 = devices[0], devices[1]
```
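Plugging the sample-output numbers back into the nested `calc_bandwidth_gbps` helper above reproduces the reported figures (MB here means 1024 * 1024 bytes, matching the script's conversions):

```python
def calc_bandwidth_gbps(size_bytes: float, time_seconds: float) -> float:
    # Same arithmetic as the helper nested in print_batch_statistics.
    if time_seconds == 0:
        return 0.0
    return size_bytes * 8 / (time_seconds * 1e9)


MB = 1024 * 1024
# Average aggregate bandwidth: 1982.5 MB per batch / 48.681 ms average batch time.
print(f"{calc_bandwidth_gbps(1982.5 * MB, 0.048681):.2f} Gbps")  # 341.62 Gbps
# Sustained throughput: 198250.0 MB total / 5.094 s total wall-clock time.
print(f"{calc_bandwidth_gbps(198250.0 * MB, 5.094):.2f} Gbps")   # 326.47 Gbps
```

Note that the maximum and minimum aggregate figures pair the average batch size with the fastest and slowest batch times, so they approximate, rather than exactly reproduce, the true per-batch extremes.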
```diff
@@ -245,16 +326,20 @@ async def main(
     await actor_0.set_other_actor.call(actor_1)
 
     for i in range(warmup_iterations):
-        await actor_0.send.call(is_warmup=True)
+        await actor_0.send.call(is_warmup=True, concurrency=concurrency)
 
+    total_start_time = time.time()
     for i in range(iterations):
-        await actor_0.send.call()
+        await actor_0.send.call(concurrency=concurrency)
+    total_end_time = time.time()
+    total_elapsed_time = total_end_time - total_start_time
 
-    # Have both actors print their statistics
-    print("\n=== ACTOR 0 (Create Buffer) STATISTICS ===")
-    await actor_0.print_statistics.call()
+    # Actor 0: Print batch statistics (concurrency orchestration)
+    await actor_0.print_batch_statistics.call(
+        concurrency=concurrency, total_elapsed_time=total_elapsed_time
+    )
 
-    print("\n=== ACTOR 1 (Create Buffer+Transmit) STATISTICS ===")
+    # Actor 1: Print individual RDMA transfer statistics
     await actor_1.print_statistics.call(calc_bwd=True)
 
     await mesh_0.stop()
```
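A small design detail in this hunk: the total timer starts only after the warmup loop, so the sustained-throughput figure divides `sum(self.size_data)` by the measured-iterations window alone, and warmup batches never enter the batch statistics (`send` skips recording when `is_warmup` is set).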
```diff
@@ -313,5 +398,6 @@ async def main(
             args.operation,
             args.size,
             args.warmup_iterations,
+            args.concurrency,
         )
     )
```
