@@ -292,7 +292,7 @@ def test_04_put_get_into(self):
292292 """Verify basic put and get into functionality."""
293293 key = "get_into_test"
294294 tensor = torch .randn (1024 , 1024 , dtype = torch .float32 )
295- buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot
295+ buffer_spacing = 64 * 1024 * 1024
296296 total_buffer_size = buffer_spacing
297297
298298 # Allocate large contiguous buffer
@@ -322,7 +322,7 @@ def test_05_batch_put_get_into(self):
322322 """Zero copy Batch Get."""
323323 num_tensors = 4
324324 keys , tensors = generate_tensors (num_tensors , 8 )
325- buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot
325+ buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot
326326 batch_size = len (keys )
327327 total_buffer_size = buffer_spacing * batch_size
328328
@@ -378,7 +378,7 @@ def test_06_put_get_into_with_tp(self):
         registered_buffers = []  # Keep track of (ptr, size) for cleanup

         for rank in range(tp_size):
-            buffer_spacing = 1024 * 1024 * 1024  # 1GB per tensor slot
+            buffer_spacing = 64 * 1024 * 1024  # 64MB per tensor slot
             total_buffer_size = buffer_spacing

             # Allocate buffer for this rank
@@ -445,7 +445,7 @@ def test_07_batch_put_get_into_with_tp(self):

         for rank in range(tp_size):
             batch_size = len(keys)
-            buffer_spacing = 1024 * 1024 * 1024  # 1GB per tensor slot
+            buffer_spacing = 64 * 1024 * 1024  # 64MB per tensor slot
             total_buffer_size = buffer_spacing * batch_size

             # Allocate buffer for this rank
@@ -585,7 +585,8 @@ def test_benchmark_02_tp_batch(self):

     def test_benchmark_03_batch_put_get_into(self):
         """Benchmark: Zero copy Batch Get."""
-        buffer_spacing = 512 * 1024 * 1024  # 1GB per tensor slot
+        self.store.remove_all()
+        buffer_spacing = 300 * 1024 * 1024  # 300MB per tensor slot
         batch_size = len(self.keys)
         total_buffer_size = buffer_spacing * batch_size

@@ -645,7 +646,8 @@ def test_benchmark_04_batch_put_get_into_with_tp(self):
         tp_size = 4
         split_dim = 0
         batch_size = len(self.keys)
-        buffer_spacing = 512 * 1024 * 1024  # 1GB per tensor slot
+        self.store.remove_all()
+        buffer_spacing = 300 * 1024 * 1024  # 300MB per tensor slot

         # Allocate and register a separate buffer for each TP rank
         rank_buffers = []  # Store metadata for cleanup
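For context, the zero-copy pattern these tests exercise (and the reason `buffer_spacing` must exceed each tensor's payload) looks roughly like the sketch below. The store API used here (`put`, `register_buffer`, `get_into`, `unregister_buffer`) and the ctypes buffer handling are assumptions inferred from the test and method names in this diff, not the project's exact interface.

```python
# Hypothetical sketch of the put/get_into round trip; the store methods
# used below are assumed, not the exact API under test.
import ctypes

import torch


def zero_copy_roundtrip(store, key: str = "get_into_test") -> torch.Tensor:
    tensor = torch.randn(1024, 1024, dtype=torch.float32)
    store.put(key, tensor)  # assumed: write the tensor's bytes to the store

    # One slot, sized well above the 4MB payload (mirrors the updated tests).
    buffer_spacing = 64 * 1024 * 1024
    buf = ctypes.create_string_buffer(buffer_spacing)
    ptr = ctypes.addressof(buf)

    # assumed: pin the buffer with the store so it can write into it directly
    store.register_buffer(ptr, buffer_spacing)
    try:
        # get_into() fills the pre-registered slot in place, so the read
        # path needs no intermediate allocation or copy.
        store.get_into(key, ptr, buffer_spacing)
        out = torch.frombuffer(buf, dtype=torch.float32, count=1024 * 1024)
        # Clone so the result outlives the registered buffer.
        return out.reshape(1024, 1024).clone()
    finally:
        store.unregister_buffer(ptr)
```

In the batch and TP variants, slot i reads into `ptr + i * buffer_spacing`, which is why those tests size the allocation as `buffer_spacing * batch_size`.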