# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
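
"""RDMA load test for Monarch's RDMABuffer.

Spawns two actors on separate process meshes on this host and repeatedly
moves tensors between them with one-sided RDMA write/read operations, then
reports per-operation timing and (optionally) bandwidth statistics.

Example invocation (the filename is illustrative; the flags match the
argparse definitions below):

    python rdma_load_test.py --device cuda:0 --operation write --size 64
"""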

import argparse
import asyncio
import os
import random
import statistics
import sys
import time

# Parse CLI arguments up front so that environment variables can be set
# before torch is imported below.
args = None
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="RDMA Test with configurable parameters"
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=100,
        help="Number of test iterations (default: 100)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device: cpu or cuda:X where X is 0-7 (default: cpu)",
    )
    parser.add_argument(
        "--operation",
        choices=["write", "read", "ping-pong"],
        default="write",
        help="RDMA operation type: write, read, or ping-pong (default: write)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=64,
        help="Data size per operation in MB (default: 64, must be multiple of 4)",
    )
    parser.add_argument(
        "--expandable-segments",
        type=str,
        choices=["true", "false"],
        default="true",
        help="Enable/disable PyTorch CUDA expandable segments (default: true)",
    )

    args = parser.parse_args()

# Set expandable segments environment variable based on CLI argument
if args and args.expandable_segments == "false":
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False"
else:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

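# NOTE: PYTORCH_CUDA_ALLOC_CONF is read by PyTorch's CUDA allocator at
# initialization, which is why torch is imported only after it is set above.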
# pyre-ignore
import torch
from monarch.actor import Actor, endpoint, this_host
from monarch.tensor_engine import RDMABuffer


class RDMATest(Actor):
    def __init__(
        self, device: str = "cpu", operation: str = "write", size_mb: int = 64
    ) -> None:
        self.other_actor = None
        self.i = 0
        self.device = device
        self.operation = operation
        self.size_mb = size_mb

        # Timing data storage
        self.timing_data = []
        self.size_data = []

    @endpoint
    def set_other_actor(self, other_actor):
        self.other_actor = other_actor

    @endpoint
    async def send(self) -> None:
        # Number of float32 elements in size_mb MB (4 bytes per element),
        # scaled by a random factor of 0.5x, 1.0x, or 1.5x (i.e. +/- 50%).
        shape = int(1024 * 1024 * self.size_mb / 4 * (0.5 * random.randint(1, 3)))

        # Use the device string directly
        tensor = torch.rand(shape, dtype=torch.float32, device=self.device)
        size_elem = tensor.numel() * tensor.element_size()
        tensor_addr = tensor.data_ptr()

        # Critical validation - this should catch the null pointer issue
        assert (
            tensor_addr != 0
        ), f"CRITICAL: Tensor has null pointer! Device: {self.device}, Shape: {shape}"
        assert size_elem > 0, f"CRITICAL: Tensor has zero size! Size: {size_elem}"

        byte_view = tensor.view(torch.uint8).flatten()
        # Validate byte_view too
        byte_view_addr = byte_view.data_ptr()
        assert (
            byte_view_addr != 0
        ), f"CRITICAL: Byte view has null pointer! Original addr: 0x{tensor_addr:x}"
        assert (
            byte_view_addr == tensor_addr
        ), f"CRITICAL: Address mismatch! Tensor: 0x{tensor_addr:x}, ByteView: 0x{byte_view_addr:x}"
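        # The uint8 view shares storage with the float tensor (the asserts
        # above confirm the same data pointer), so registering it with
        # RDMABuffer registers the same underlying memory.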

        execution_start = time.time()
        buffer = RDMABuffer(byte_view)
        execution_end = time.time()
        elapsed = execution_end - execution_start
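        # On the sender, this timing covers RDMABuffer creation/registration
        # only; the transfer itself is timed on the receiving side in recv().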

        # Store timing and size data in this actor (size_elem computed above)
        self.timing_data.append(elapsed)
        self.size_data.append(size_elem)
        buffer_size = buffer.size()
        assert buffer_size == size_elem, f"{buffer_size=} != {size_elem=}"

        # Call recv - timing happens there
        await self.other_actor.recv.call(
            buffer, tensor.shape, tensor.dtype, self.device
        )

        # cleanup
        await buffer.drop()

        self.i += 1

    @endpoint
    async def recv(self, rdma_buffer, shape, dtype, device):
        # Create receiving tensor on the same device
        tensor = torch.rand(shape, dtype=dtype, device=device)
        byte_view = tensor.view(torch.uint8).flatten()

        execution_start = time.time()

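        # Per the RDMABuffer API naming, write_from() pushes the local bytes
        # into the registered buffer, while read_into() pulls the buffer's
        # contents into the local tensor.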
        if self.operation == "write":
            await rdma_buffer.write_from(byte_view, timeout=5)
        elif self.operation == "read":
            await rdma_buffer.read_into(byte_view, timeout=5)
        elif self.operation == "ping-pong":
            if self.i % 2 == 0:
                await rdma_buffer.write_from(byte_view, timeout=5)
            else:
                await rdma_buffer.read_into(byte_view, timeout=5)

        execution_end = time.time()
        elapsed = execution_end - execution_start
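        # This elapsed time covers the RDMA transfer itself; the bandwidth
        # numbers in print_statistics() are derived from these samples.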

        # Store timing and size data in this actor
        size_elem = torch.numel(tensor) * tensor.element_size()
        self.timing_data.append(elapsed)
        self.size_data.append(size_elem)

        # Advance the iteration counter so ping-pong mode actually alternates
        # between write and read on successive calls (it is otherwise never
        # incremented on the receiving actor).
        self.i += 1

    @endpoint
    async def print_statistics(self, calc_bwd: bool = False):
        """Calculate and print timing statistics"""
        if not self.timing_data:
            print("No timing data collected!")
            return

        timings = self.timing_data
        sizes = self.size_data

        # Calculate statistics
        avg_time = statistics.mean(timings)
        min_time = min(timings)
        max_time = max(timings)
        std_time = statistics.stdev(timings) if len(timings) > 1 else 0.0

        avg_size = statistics.mean(sizes)
        total_data = sum(sizes)

        print("TIMING RESULTS:")
        print(f"  Average time per operation: {avg_time * 1000:.3f} ms")
        print(f"  Minimum time per operation: {min_time * 1000:.3f} ms")
        print(f"  Maximum time per operation: {max_time * 1000:.3f} ms")
        print(f"  Standard deviation: {std_time * 1000:.3f} ms")

        if calc_bwd:
            # Calculate bandwidth (Gbps)
            def calc_bandwidth_gbps(size_bytes: int, time_seconds: float) -> float:
                if time_seconds == 0:
                    return 0.0
                bits_transferred = size_bytes * 8
                return bits_transferred / (time_seconds * 1e9)
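
            # Example with made-up numbers: 64 MiB moved in 10 ms is
            # 67,108,864 bytes * 8 / (0.01 s * 1e9) ~= 53.7 Gbps.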

            avg_bandwidth = calc_bandwidth_gbps(avg_size, avg_time)
            max_bandwidth = calc_bandwidth_gbps(avg_size, min_time)
            min_bandwidth = calc_bandwidth_gbps(avg_size, max_time)

            device_type = self.device.upper()

            # Print results
            print("\n" + "=" * 60)
            print(f"RDMA {self.operation.upper()} LOAD TEST RESULTS ({device_type})")
            print("=" * 60)
            print(f"Total iterations completed: {len(timings)}")
            print(f"Average data per operation: {avg_size / (1024*1024):.1f} MB")
            print(f"Total data transferred: {total_data / (1024*1024):.1f} MB")

            print()
            print("BANDWIDTH RESULTS:")
            print(f"  Average bandwidth: {avg_bandwidth:.2f} Gbps")
            print(f"  Maximum bandwidth: {max_bandwidth:.2f} Gbps")
            print(f"  Minimum bandwidth: {min_bandwidth:.2f} Gbps")
            print("=" * 60)


async def main(
    device: str = "cpu",
    iterations: int = 100,
    operation: str = "write",
    size_mb: int = 64,
):
    # Adjust GPU allocation based on the device type
    use_cuda = device.startswith("cuda:")
    gpu_config = {"gpus": 1} if use_cuda else {"cpus": 1}

    mesh_0 = this_host().spawn_procs(per_host=gpu_config)
    actor_0 = mesh_0.spawn("rdma_test", RDMATest, device, operation, size_mb)

    mesh_1 = this_host().spawn_procs(per_host=gpu_config)
    actor_1 = mesh_1.spawn("rdma_test", RDMATest, device, operation, size_mb)
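    # Two separate proc meshes mean the two actors run in different
    # processes, so every transfer crosses a real process boundary.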

    await actor_0.set_other_actor.call(actor_1)

    for _ in range(iterations):
        await actor_0.send.call()

    # Have both actors print their statistics
    print("\n=== ACTOR 0 (Create Buffer) STATISTICS ===")
    await actor_0.print_statistics.call()

    print("\n=== ACTOR 1 (Create Buffer+Transmit) STATISTICS ===")
    await actor_1.print_statistics.call(calc_bwd=True)

    await mesh_0.stop()
    await mesh_1.stop()


if __name__ == "__main__":
    assert args
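    # args was populated by the argparse block near the top of the file,
    # which runs before torch is imported.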

    # Validate size is multiple of 4
    if args.size % 4 != 0:
        print(f"Error: --size must be a multiple of 4. Got: {args.size}")
        sys.exit(1)

    # Parse and validate device string
    device = args.device.lower()
    if device == "cpu":
        pass  # CPU is always valid
    elif device.startswith("cuda:"):
        # Validate CUDA device format
        try:
            device_id = int(device.split(":")[1])
            if device_id < 0 or device_id > 7:
                print(f"Error: CUDA device ID must be 0-7. Got: {device_id}")
                sys.exit(1)
        except (ValueError, IndexError):
            print(
                f"Error: Invalid device format. Use 'cpu' or 'cuda:X' where X is 0-7. Got: {args.device}"
            )
            sys.exit(1)

        # Check if CUDA is available
        if not torch.cuda.is_available():
            print("Warning: CUDA requested but not available. Falling back to CPU.")
            device = "cpu"
        elif device_id >= torch.cuda.device_count():
            print(
                f"Warning: CUDA device {device_id} not available. Available devices: 0-{torch.cuda.device_count()-1}. Falling back to CPU."
            )
            device = "cpu"
    else:
        print(
            f"Error: Invalid device format. Use 'cpu' or 'cuda:X' where X is 0-7. Got: {args.device}"
        )
        sys.exit(1)

    asyncio.run(main(device, args.iterations, args.operation, args.size))