
Commit 5a863c3

dstaay-fb authored and facebook-github-bot committed
Support > 1 GB put/send with chunking (#1367)
Summary: Pull Request resolved: #1367

Adds chunking logic at the Rust layer; unblocks > 1 GB single put/get operations.

To test: rdma_load_test -- --device cuda:0 --operation ping-pong --iterations 100 --size 1512

=== ACTOR 0 (Create Buffer) STATISTICS ===
[0] TIMING RESULTS:
[0] Average time per operation: 1.526 ms
[0] Minimum time per operation: 0.537 ms
[0] Maximum time per operation: 92.305 ms
[0] Standard deviation: 9.170 ms

=== ACTOR 1 (Create Buffer+Transmit) STATISTICS ===
[0] TIMING RESULTS:
[0] Average time per operation: 35.145 ms
[0] Minimum time per operation: 16.979 ms
[0] Maximum time per operation: 155.070 ms
[0] Standard deviation: 18.571 ms
[0]
[0] ============================================================
[0] RDMA PING-PONG LOAD TEST RESULTS (CUDA:0)
[0] ============================================================
[0] Total iterations completed: 100
[0] Average data per operation: 1519.6 MB
[0] Total data transferred: 151956.0 MB
[0]
[0] BANDWIDTH RESULTS:
[0] Average bandwidth: 362.70 Gbps
[0] Maximum bandwidth: 750.74 Gbps
[0] Minimum bandwidth: 82.20 Gbps
[0] ============================================================

Reviewed By: allenwang28

Differential Revision: D83499085

fbshipit-source-id: 0f655e1f678993a1c80b5058f11715c0b8124323
1 parent 9532ba1 commit 5a863c3
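
For orientation, here is a minimal self-contained sketch of the chunking idea described in the summary. It is illustrative only, not the committed code: `chunks` is a hypothetical helper, though `MAX_RDMA_MSG_SIZE` matches the constant added in this commit. A transfer larger than 1 GiB is split into consecutive work requests at offset-shifted addresses, each at most `MAX_RDMA_MSG_SIZE` bytes.

// Illustrative sketch of the chunking scheme (not the committed code).
const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024; // 1 GiB per work request

/// Split a transfer of `total` bytes into (offset, len) chunks of at most 1 GiB each.
fn chunks(total: usize) -> Vec<(usize, usize)> {
    let mut out = Vec::new();
    let mut offset = 0;
    let mut remaining = total;
    while remaining > 0 {
        let len = remaining.min(MAX_RDMA_MSG_SIZE);
        out.push((offset, len));
        offset += len;
        remaining -= len;
    }
    out
}

fn main() {
    // A 2 GB transfer (500M float32 elements * 4 bytes), as in the new test, needs two chunks.
    let total = 500_000_000usize * 4;
    for (offset, len) in chunks(total) {
        println!("post write at local offset {offset}: {len} bytes");
    }
}

The actual change applies the same loop inside `RdmaQueuePair::put` and `RdmaQueuePair::get`, posting one `post_op` per chunk, as shown in the diff below.
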

File tree

2 files changed: +179 -37 lines changed


monarch_rdma/src/rdma_components.rs

Lines changed: 82 additions & 37 deletions
@@ -40,6 +40,9 @@
 //! 6. Poll for completions
 //! 7. Resources are cleaned up when dropped
 
+/// Maximum size for a single RDMA operation in bytes (1 GiB)
+const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024;
+
 use std::ffi::CStr;
 use std::fs;
 use std::io::Error;
@@ -788,20 +791,37 @@ impl RdmaQueuePair {
     }
 
     pub fn put(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
-        let idx = self.send_wqe_idx;
-        self.send_wqe_idx += 1;
-        self.post_op(
-            lhandle.addr,
-            lhandle.lkey,
-            lhandle.size,
-            idx,
-            true,
-            RdmaOperation::Write,
-            rhandle.addr,
-            rhandle.rkey,
-        )
-        .unwrap();
-        self.send_db_idx += 1;
+        let total_size = lhandle.size;
+        if rhandle.size < total_size {
+            return Err(anyhow::anyhow!(
+                "Remote buffer size ({}) is smaller than local buffer size ({})",
+                rhandle.size,
+                total_size
+            ));
+        }
+
+        let mut remaining = total_size;
+        let mut offset = 0;
+        while remaining > 0 {
+            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
+            let idx = self.send_wqe_idx;
+            self.send_wqe_idx += 1;
+            self.post_op(
+                lhandle.addr + offset,
+                lhandle.lkey,
+                chunk_size,
+                idx,
+                true,
+                RdmaOperation::Write,
+                rhandle.addr + offset,
+                rhandle.rkey,
+            )?;
+            self.send_db_idx += 1;
+
+            remaining -= chunk_size;
+            offset += chunk_size;
+        }
+
         Ok(())
     }
 
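As a worked example of the loop above (assuming, purely for illustration, that the benchmark's `--size 1512` corresponds to 1512 MiB), a single `put` of that size becomes two work requests: one full 1 GiB chunk plus a 488 MiB remainder, with `send_db_idx` advancing by two.

// Hypothetical check of the chunk split for a 1512 MiB transfer (illustration only).
const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024;

fn main() {
    let total: usize = 1512 * 1024 * 1024;       // 1512 MiB
    let first = total.min(MAX_RDMA_MSG_SIZE);    // full 1 GiB chunk
    let second = total - first;                  // remainder
    assert_eq!(first, 1024 * 1024 * 1024);
    assert_eq!(second, 488 * 1024 * 1024);
    println!("chunks: {first} + {second} bytes");
}
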
@@ -932,20 +952,38 @@ impl RdmaQueuePair {
     }
 
     pub fn get(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
-        let idx = self.send_wqe_idx;
-        self.send_wqe_idx += 1;
-        self.post_op(
-            lhandle.addr,
-            lhandle.lkey,
-            lhandle.size,
-            idx,
-            true,
-            RdmaOperation::Read,
-            rhandle.addr,
-            rhandle.rkey,
-        )
-        .unwrap();
-        self.send_db_idx += 1;
+        let total_size = lhandle.size;
+        if rhandle.size < total_size {
+            return Err(anyhow::anyhow!(
+                "Remote buffer size ({}) is smaller than local buffer size ({})",
+                rhandle.size,
+                total_size
+            ));
+        }
+
+        let mut remaining = total_size;
+        let mut offset = 0;
+
+        while remaining > 0 {
+            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
+            let idx = self.send_wqe_idx;
+            self.send_wqe_idx += 1;
+            self.post_op(
+                lhandle.addr + offset,
+                lhandle.lkey,
+                chunk_size,
+                idx,
+                true,
+                RdmaOperation::Read,
+                rhandle.addr + offset,
+                rhandle.rkey,
+            )?;
+            self.send_db_idx += 1;
+
+            remaining -= chunk_size;
+            offset += chunk_size;
+        }
+
         Ok(())
     }
 
@@ -1122,7 +1160,7 @@ impl RdmaQueuePair {
     ///
     /// # Arguments
     ///
-    /// * `target` - Which completion queue(s) to poll (Send, Receive, or Both)
+    /// * `target` - Which completion queue(s) to poll (Send, Receive)
     ///
     /// # Returns
     ///
@@ -1168,9 +1206,10 @@ impl RdmaQueuePair {
                 // This should be a send completion - verify it's the one we're waiting for
                 if wc.wr_id() == self.send_cq_idx {
                     self.send_cq_idx += 1;
+                }
+                // finished polling, return the last completion
+                if self.send_cq_idx == self.send_db_idx {
                     return Ok(Some(IbvWc::from(wc)));
-                } else {
-                    // This completion is for a different operation - keep polling
                 }
             }
         }
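
The polling change above follows from chunking: a single `put` or `get` may now post several work requests, so the poller should only report the operation complete once every outstanding chunk's completion has been consumed, i.e. once `send_cq_idx` catches up to `send_db_idx`. Below is a minimal stand-alone model of that bookkeeping (simplified counters only, not the real ibverbs polling loop).

// Simplified model of the send-completion bookkeeping after chunking:
// `send_db_idx` counts posted work requests, `send_cq_idx` counts consumed completions.
struct SendState {
    send_cq_idx: u64, // completions consumed so far
    send_db_idx: u64, // work requests posted (doorbells rung)
}

impl SendState {
    /// Consume one completion with the given wr_id; return true once the
    /// last outstanding chunk has completed and the whole operation is done.
    fn on_completion(&mut self, wr_id: u64) -> bool {
        if wr_id == self.send_cq_idx {
            self.send_cq_idx += 1;
        }
        self.send_cq_idx == self.send_db_idx
    }
}

fn main() {
    // A 2 GB put split into two chunks posts two work requests.
    let mut st = SendState { send_cq_idx: 0, send_db_idx: 2 };
    assert!(!st.on_completion(0)); // first chunk done, keep polling
    assert!(st.on_completion(1));  // second chunk done, operation complete
}
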
@@ -1193,17 +1232,23 @@ impl RdmaQueuePair {
                 if !wc.is_valid() {
                     if let Some((status, vendor_err)) = wc.error() {
                         return Err(anyhow::anyhow!(
-                            "Receive work completion failed with status: {:?}, vendor error: {}",
+                            "Recv work completion failed with status: {:?}, vendor error: {}, wr_id: {}, recv_cq_idx: {}",
                             status,
-                            vendor_err
+                            vendor_err,
+                            wc.wr_id(),
+                            self.recv_cq_idx,
                         ));
                     }
                 }
 
-                // This should be a receive completion
-                self.recv_cq_idx += 1;
-
-                return Ok(Some(IbvWc::from(wc)));
+                // This should be a receive completion - verify it's the one we're waiting for
+                if wc.wr_id() == self.recv_cq_idx {
+                    self.recv_cq_idx += 1;
+                }
+                // finished polling, return the last completion
+                if self.recv_cq_idx == self.recv_db_idx {
+                    return Ok(Some(IbvWc::from(wc)));
+                }
             }
         }
 
python/tests/test_rdma.py

Lines changed: 97 additions & 0 deletions
@@ -260,3 +260,100 @@ def test_gpu_trainer_generator_sync() -> None:
     for _ in range(1):
         trainer.weights_ready.call().get()
         generator.update_weights.call().get()
+
+
+@needs_rdma
+async def test_rdma_concurrent_2gb_writes_in_order():
+    """Test concurrent 2GB RDMA buffer writes with reverse-order awaiting"""
+    proc = this_host().spawn_procs(per_host={"processes": 1})
+    num_elem = 500_000_000  # 500M elements
+
+    class BufferOwnerActor(Actor):
+        def __init__(self):
+            # Create a 2GB buffer (500M float32 elements * 4 bytes = 2GB)
+            self.data = torch.zeros(num_elem, dtype=torch.float32)
+            self.rdma_buffer = None
+
+        @endpoint
+        async def create_buffer(self) -> RDMABuffer:
+            """Create a 2GB RDMABuffer"""
+            byte_tensor = self.data.view(torch.uint8).flatten()
+            self.rdma_buffer = RDMABuffer(byte_tensor)
+            return self.rdma_buffer
+
+        @endpoint
+        async def get_buffer_data(self) -> torch.Tensor:
+            """Return the current buffer data for verification"""
+            return self.data
+
+    class WriterActor(Actor):
+        def __init__(self):
+            # Create a 2GB buffer (500M float32 elements * 4 bytes = 2GB)
+            self.tensor_a = torch.ones(
+                num_elem, dtype=torch.float32
+            )  # Will receive data
+            self.tensor_b = torch.full(
+                (num_elem,), 2.0, dtype=torch.float32
+            )  # Will send data
+
+        @endpoint
+        async def perform_concurrent_writes(self, buffer: RDMABuffer):
+            """Perform concurrent read/write operations and await in reverse order"""
+            # Convert tensors to byte views for RDMA
+            byte_tensor_a = self.tensor_a.view(torch.uint8).flatten()
+            byte_tensor_b = self.tensor_b.view(torch.uint8).flatten()
+
+            # Start both operations concurrently
+            future_a = buffer.read_into(
+                byte_tensor_a, timeout=10
+            )  # Read FROM buffer INTO tensor_a
+            future_b = buffer.write_from(
+                byte_tensor_b, timeout=10
+            )  # Write FROM tensor_b INTO buffer
+
+            # Await in reverse order - sets actual execution order
+            await future_b  # Await write operation first
+            await future_a  # Await read operation second
+
+            return "SUCCESS"
+
+        @endpoint
+        async def get_tensors(self) -> tuple[torch.Tensor, torch.Tensor]:
+            """Return both tensors for verification"""
+            return (self.tensor_a, self.tensor_b)
+
+    # Create actors
+    buffer_owner = proc.spawn("buffer_owner", BufferOwnerActor)
+    writer = proc.spawn("writer", WriterActor)
+
+    # Create the 2GB RDMA buffer
+    buffer = await buffer_owner.create_buffer.call_one()
+    print(f"✓ Created 2GB RDMA buffer (size: {buffer.size() / (1024**3):.2f} GB)")
+
+    # Perform concurrent writes with reverse-order awaiting
+    result = await writer.perform_concurrent_writes.call_one(buffer)
+    assert result == "SUCCESS", f"Concurrent writes failed: {result}"
+
+    # Verify the data flow worked correctly using torch.allclose
+    tensor_a_actual, tensor_b_actual = await writer.get_tensors.call_one()
+    buffer_data_actual = await buffer_owner.get_buffer_data.call_one()
+
+    expected_result = torch.full((num_elem,), 2.0, dtype=torch.float32)
+
+    # Verify using torch.allclose
+    assert torch.allclose(
+        tensor_a_actual, expected_result
+    ), "tensor_a does not match expected 2.0s"
+    assert torch.allclose(
+        tensor_b_actual, expected_result
+    ), "tensor_b does not match expected 2.0s"
+
+    assert torch.allclose(
+        buffer_data_actual, expected_result
+    ), "RDMABuffer does not contain expected 2.0s"
+
+    print("✓ Concurrent 2GB operations completed successfully")
+
+    # Drop the buffer
+    await buffer.drop()
+    print("✓ Buffer dropped successfully")
