Skip to content

Commit 8536b17

Browse files
committed
fix test case
Signed-off-by: Cruz Zhao <[email protected]>
1 parent c977b00 commit 8536b17

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

scripts/test_tensor_api.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def test_04_put_get_into(self):
292292
"""Verify basic put and get into functionality."""
293293
key = "get_into_test"
294294
tensor = torch.randn(1024, 1024, dtype=torch.float32)
295-
buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot
295+
buffer_spacing = 64 * 1024 * 1024
296296
total_buffer_size = buffer_spacing
297297

298298
# Allocate large contiguous buffer
@@ -322,7 +322,7 @@ def test_05_batch_put_get_into(self):
322322
"""Zero copy Batch Get."""
323323
num_tensors = 4
324324
keys, tensors = generate_tensors(num_tensors, 8)
325-
buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot
325+
buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot
326326
batch_size = len(keys)
327327
total_buffer_size = buffer_spacing * batch_size
328328

@@ -378,7 +378,7 @@ def test_06_put_get_into_with_tp(self):
378378
registered_buffers = [] # Keep track of (ptr, size) for cleanup
379379

380380
for rank in range(tp_size):
381-
buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot
381+
buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot
382382
total_buffer_size = buffer_spacing
383383

384384
# Allocate buffer for this rank
@@ -445,7 +445,7 @@ def test_07_batch_put_get_into_with_tp(self):
445445

446446
for rank in range(tp_size):
447447
batch_size = len(keys)
448-
buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot
448+
buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot
449449
total_buffer_size = buffer_spacing * batch_size
450450

451451
# Allocate buffer for this rank
@@ -585,7 +585,8 @@ def test_benchmark_02_tp_batch(self):
585585

586586
def test_benchmark_03_batch_put_get_into(self):
587587
"""Benchmark: Zero copy Batch Get."""
588-
buffer_spacing = 512 * 1024 * 1024 # 1GB per tensor slot
588+
self.store.remove_all()
589+
buffer_spacing = 300 * 1024 * 1024 # 1GB per tensor slot
589590
batch_size = len(self.keys)
590591
total_buffer_size = buffer_spacing * batch_size
591592

@@ -645,7 +646,8 @@ def test_benchmark_04_batch_put_get_into_with_tp(self):
645646
tp_size = 4
646647
split_dim = 0
647648
batch_size = len(self.keys)
648-
buffer_spacing = 512 * 1024 * 1024 # 1GB per tensor slot
649+
self.store.remove_all()
650+
buffer_spacing = 300 * 1024 * 1024 # 1GB per tensor slot
649651

650652
# Allocate and register a separate buffer for each TP rank
651653
rank_buffers = [] # Store metadata for cleanup

0 commit comments

Comments
 (0)