Skip to content

Commit eca967d

Browse files
committed
⚡ guard concurrent adapter loads
Signed-off-by: Joe Runde <[email protected]>
1 parent e02df23 commit eca967d

File tree

2 files changed

+56
-25
lines changed

2 files changed

+56
-25
lines changed

src/vllm_tgis_adapter/grpc/adapters.py

Lines changed: 30 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,7 @@ class AdapterStore:
5050
cache_path: str # Path to local store of adapters to load from
5151
adapters: dict[str, AdapterMetadata]
5252
next_unique_id: int = 1
53+
load_locks: dict[str, asyncio.Lock] = dataclasses.field(default_factory=dict)
5354

5455

5556
async def validate_adapters(
@@ -78,31 +79,35 @@ async def validate_adapters(
7879
if not adapter_id or not adapter_store:
7980
return {}
8081

81-
# If not already cached, we need to validate that files exist and
82-
# grab the type out of the adapter_config.json file
83-
if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
84-
_reject_bad_adapter_id(adapter_id)
85-
local_adapter_path = str(Path(adapter_store.cache_path) / adapter_id)
86-
87-
loop = asyncio.get_running_loop()
88-
if global_thread_pool is None:
89-
global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
90-
91-
# Increment the unique adapter id counter here in async land where we don't
92-
# need to deal with thread-safety
93-
unique_id = adapter_store.next_unique_id
94-
adapter_store.next_unique_id += 1
95-
96-
adapter_metadata = await loop.run_in_executor(
97-
global_thread_pool,
98-
_load_adapter_metadata,
99-
adapter_id,
100-
local_adapter_path,
101-
unique_id,
102-
)
103-
104-
# Add to cache
105-
adapter_store.adapters[adapter_id] = adapter_metadata
82+
# Hold a per-adapter lock so concurrent requests for the same adapter load it only once
83+
async with adapter_store.load_locks.setdefault(adapter_id, asyncio.Lock()):
84+
# If not already cached, we need to validate that files exist and
85+
# grab the type out of the adapter_config.json file
86+
if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
87+
_reject_bad_adapter_id(adapter_id)
88+
local_adapter_path = str(Path(adapter_store.cache_path) / adapter_id)
89+
90+
loop = asyncio.get_running_loop()
91+
if global_thread_pool is None:
92+
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
93+
max_workers=2
94+
)
95+
96+
# Reserve the unique adapter id here in the coroutine, before handing off to the
97+
# thread pool, so the counter does not need any thread-safety
98+
unique_id = adapter_store.next_unique_id
99+
adapter_store.next_unique_id += 1
100+
101+
adapter_metadata = await loop.run_in_executor(
102+
global_thread_pool,
103+
_load_adapter_metadata,
104+
adapter_id,
105+
local_adapter_path,
106+
unique_id,
107+
)
108+
109+
# Add to cache
110+
adapter_store.adapters[adapter_id] = adapter_metadata
106111

107112
# Build the proper vllm request object
108113
if adapter_metadata.adapter_type == "LORA":

tests/test_adapters.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from pathlib import Path
23

34
import pytest
@@ -112,3 +113,28 @@ async def test_store_handles_multiple_adapters():
112113
adapters_1["lora_request"].lora_int_id
113114
< adapters_2["prompt_adapter_request"].prompt_adapter_id
114115
)
116+
117+
118+
@pytest.mark.asyncio
119+
async def test_cache_handles_concurrent_loads():
120+
# Check that the cache does not hammer the filesystem when accessed concurrently
121+
# Specifically, when concurrent requests for the same new adapter arrive
122+
123+
adapter_store = AdapterStore(cache_path=FIXTURES_DIR, adapters={})
124+
# Use a caikit-style adapter that requires conversion, to test the worst case
125+
adapter_name = "bloom_sentiment_1"
126+
request = BatchedGenerationRequest(
127+
adapter_id=adapter_name,
128+
)
129+
130+
# Fire off a bunch of concurrent requests for the same new adapter
131+
tasks = [
132+
asyncio.create_task(validate_adapters(request, adapter_store=adapter_store))
133+
for _ in range(1000)
134+
]
135+
136+
# Await all tasks
137+
await asyncio.gather(*tasks)
138+
139+
# Exactly one unique ID should have been handed out: the counter starts at 1, so it must now be 2
140+
assert adapter_store.next_unique_id == 2

0 commit comments

Comments
 (0)