Skip to content

Commit b3732de

Browse files
committed
Fix SharedTensor anti-pattern in multiprocess tests

Fixed tests that were using the anti-pattern:

    SharedTensor(handle=handle).tensor

This pattern creates a dangling tensor reference because the SharedTensor
object is immediately garbage collected, causing __del__ to close the shared
memory, which invalidates the tensor reference.

Changed all multiprocess worker functions to use context managers:
- test_multiprocess_read
- test_multiprocess_write
- test_multiprocess_bidirectional
- test_to_shared_tensor_multiprocess
- test_multiple_receivers_close_independently

Test plan:
- All 65 tests now pass (previously 1 segfault)
- python -m pytest tests/unit_tests/util/test_shared_tensor.py -v
1 parent 856c070 commit b3732de

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

tests/unit_tests/util/test_shared_tensor.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -408,9 +408,8 @@ def test_multiprocess_read(self):
408408
"""Test reading shared tensor from another process"""
409409

410410
def reader_process(handle_dict, result_queue):
411-
shared = SharedTensor(handle=handle_dict)
412-
tensor = shared.tensor
413-
result_queue.put(tensor.sum().item())
411+
with SharedTensor(handle=handle_dict) as shared:
412+
result_queue.put(shared.tensor.sum().item())
414413

415414
# Create shared tensor in main process
416415
shared = SharedTensor.empty((100, 100), torch.float32)
@@ -435,8 +434,8 @@ def test_multiprocess_write(self):
435434
"""Test writing to shared tensor from another process"""
436435

437436
def writer_process(handle_dict, value):
438-
shared = SharedTensor(handle=handle_dict)
439-
shared.tensor.fill_(value)
437+
with SharedTensor(handle=handle_dict) as shared:
438+
shared.tensor.fill_(value)
440439

441440
# Create empty shared tensor
442441
shared = SharedTensor.empty((50, 50), torch.float32)
@@ -458,11 +457,10 @@ def test_multiprocess_bidirectional(self):
458457
"""Test bidirectional communication"""
459458

460459
def worker_process(input_handle, output_handle):
461-
input_tensor = SharedTensor(handle=input_handle).tensor
462-
output_tensor = SharedTensor(handle=output_handle).tensor
463-
464-
# Compute: output = input * 2
465-
output_tensor.copy_(input_tensor * 2)
460+
with SharedTensor(handle=input_handle) as input_shared:
461+
with SharedTensor(handle=output_handle) as output_shared:
462+
# Compute: output = input * 2
463+
output_shared.tensor.copy_(input_shared.tensor * 2)
466464

467465
# Create input and output tensors
468466
input_shared = SharedTensor.empty((100, 100), torch.float32)
@@ -605,8 +603,8 @@ def test_to_shared_tensor_multiprocess(self):
605603
"""Test to_shared_tensor in multiprocess scenario"""
606604

607605
def worker_process(handle, result_queue):
608-
shared = handle.to_shared_tensor()
609-
result_queue.put(shared.tensor.sum().item())
606+
with handle.to_shared_tensor() as shared:
607+
result_queue.put(shared.tensor.sum().item())
610608

611609
original = SharedTensor.empty((50, 50), torch.float32)
612610
original.tensor.fill_(3.0)
@@ -831,10 +829,9 @@ def test_multiple_receivers_close_independently(self):
831829
"""Test that multiple receivers can close independently"""
832830

833831
def receiver_process(handle, value, result_queue):
834-
shared = SharedTensor(handle=handle)
835-
result = shared.tensor[0, 0].item() == value
836-
shared.close() # Each receiver closes its own reference
837-
result_queue.put(result)
832+
with SharedTensor(handle=handle) as shared:
833+
result = shared.tensor[0, 0].item() == value
834+
result_queue.put(result)
838835

839836
# Creator
840837
shared = SharedTensor.empty((10, 10), torch.float32)

0 commit comments

Comments (0)