Skip to content

Commit aba780e

Browse files
authored
Added fix for assertion error during parallel access of mountpoint client. Added integration tests for parallel access of mountpoint client. (#237)
* Added an integration test for parallel saving of checkpoints
* Added an integration test for parallel access of the mountpoint client
* Added a fix to prevent an assertion error in a multi-threaded environment
* Modified the integration test
* Modified the integration test and the lock approach
* Added comments
* Added some more comments
1 parent 9809622 commit aba780e

File tree

2 files changed

+93
-3
lines changed

2 files changed

+93
-3
lines changed

s3torchconnector/src/s3torchconnector/_s3client/_s3client.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import logging
55
import os
6+
import threading
67
from functools import partial
78
from typing import Optional, Any
89

@@ -32,6 +33,9 @@ def _identity(obj: Any) -> Any:
3233
return obj
3334

3435

36+
_client_lock = threading.Lock()
37+
38+
3539
class S3Client:
3640
def __init__(
3741
self,
@@ -51,10 +55,15 @@ def __init__(
5155

5256
@property
def _client(self) -> MountpointS3Client:
    """Return the `MountpointS3Client` belonging to the current process.

    The client is (re)built through `self._client_builder()` whenever no PID
    has been recorded yet or the recorded PID differs from `os.getpid()`;
    `MountpointS3Client` does not survive forking, so it is re-created when
    the PID has changed.

    Raises:
        AssertionError: if `_real_client` is still `None` after the check.
    """

    def needs_rebuild() -> bool:
        # True when no client was built in this process yet.
        return self._client_pid is None or self._client_pid != os.getpid()

    # Cheap first look without the lock, so the common path stays lock-free.
    if needs_rebuild():
        # Only one thread at a time may create the client.
        with _client_lock:
            # Look again under the lock: another thread may have finished
            # building the client while this one was waiting.
            if needs_rebuild():
                # Store the client before the PID, so a thread on the
                # lock-free path never sees the new PID without a client.
                self._real_client = self._client_builder()
                self._client_pid = os.getpid()

    assert self._real_client is not None
    return self._real_client
6069

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
import random
3+
import time
4+
import threading
5+
import pytest
6+
from s3torchconnector._s3client import S3Client
7+
from s3torchconnectorclient._mountpoint_s3_client import MountpointS3Client
8+
9+
10+
class S3ClientWithoutLock(S3Client):
    """`S3Client` variant whose `_client` property takes no lock.

    The PID is recorded *before* the deliberately slow builder returns, so a
    thread arriving in that window skips the rebuild branch and goes straight
    to the assertion (presumably while `_real_client` is still `None`).
    """

    def _client_builder(self):
        # Pause for one second, then delegate to the parent builder; this
        # widens the window in which other threads can observe the race.
        time.sleep(1)
        return super()._client_builder()

    @property
    def _client(self) -> MountpointS3Client:
        pid_is_stale = self._client_pid is None or self._client_pid != os.getpid()
        if pid_is_stale:
            # Order is intentional here: PID first, client second.
            self._client_pid = os.getpid()
            # `MountpointS3Client` does not survive forking, so re-create it
            # if the PID has changed.
            self._real_client = self._client_builder()
        assert self._real_client is not None
        return self._real_client
23+
24+
25+
class S3ClientWithLock(S3Client):
    """`S3Client` that keeps the inherited `_client` property and only slows the builder."""

    def _client_builder(self):
        # Same one-second pause as the lock-free variant, so both tests
        # exercise an equally slow client construction.
        time.sleep(1)
        built_client = super()._client_builder()
        return built_client
29+
30+
31+
def access_client(client, error_event):
    """Thread target: read `client._client` once and record assertion failures.

    Args:
        client: object exposing a `_client` attribute or property.
        error_event: `threading.Event` shared between accessor threads. If it
            is already set the client is not touched; it gets set when reading
            `client._client` raises `AssertionError`.
    """
    try:
        # Another thread already reported a failure; nothing left to do.
        if error_event.is_set():
            return
        # Evaluated only for its side effect of building/fetching the client.
        client._client
        print(f"Successfully accessed by thread {threading.current_thread().name}")
    except AssertionError as e:
        print(f"AssertionError in thread {threading.current_thread().name}: {e}")
        error_event.set()
39+
40+
41+
def test_multiple_thread_accessing_mountpoint_client_in_parallel_without_lock():
    """Expect at least one accessor thread to hit `AssertionError` without the lock."""
    print("Running test without lock...")
    unlocked_client = S3ClientWithoutLock("us-west-2")
    assertion_seen = access_mountpoint_client_in_parallel(unlocked_client)
    if assertion_seen:
        return
    # No thread reported an AssertionError, so the race was not reproduced.
    pytest.fail("Test failed as AssertionError did not happen in one of the threads.")
48+
49+
50+
def test_multiple_thread_accessing_mountpoint_client_in_parallel_with_lock():
    """Expect no accessor thread to hit `AssertionError` with the stock `_client` property."""
    print("Running test with lock...")
    locked_client = S3ClientWithLock("us-west-2")
    assertion_seen = access_mountpoint_client_in_parallel(locked_client)
    if not assertion_seen:
        return
    # At least one thread reported an AssertionError.
    pytest.fail("Test failed as AssertionError happened in one of the threads.")
55+
56+
57+
def access_mountpoint_client_in_parallel(client):
    """Read `client._client` from up to ten threads and report assertion failures.

    Args:
        client: object handed to every `access_client` thread.

    Returns:
        True if the shared error event was set by any thread, False otherwise.
    """
    error_event = threading.Event()
    num_accessor_threads = 10
    launched = []

    # Launch the accessor threads one after another; stop launching as soon
    # as a running thread has already reported a failure.
    for i in range(num_accessor_threads):
        if error_event.is_set():
            break
        worker = threading.Thread(
            target=access_client,
            args=(client, error_event),
            name=f"Accessor-{i + 1}",
        )
        launched.append(worker)
        worker.start()

    # Wait for every launched thread before reading the final state.
    for worker in launched:
        worker.join()

    return error_event.is_set()

0 commit comments

Comments
 (0)