Skip to content

Commit 76a34f3

Browse files
committed
♻️ move all file access to threadpool
Signed-off-by: Joe Runde <[email protected]>
1 parent 01f11e5 commit 76a34f3

File tree

7 files changed

+162
-28
lines changed

7 files changed

+162
-28
lines changed

src/vllm_tgis_adapter/grpc/adapters.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -88,35 +88,20 @@ async def validate_adapters(
8888
if global_thread_pool is None:
8989
global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)
9090

91-
# 🌶️🌶️🌶️ Check for caikit-style adapters first
92-
if (
93-
Path(local_adapter_path).exists()
94-
and (Path(local_adapter_path) / "decoder.pt").exists()
95-
):
96-
# Create new temporary directory and convert to peft format there
97-
# NB: This requires write access to /tmp
98-
# Intentionally setting delete=False, we need the new adapter
99-
# files to exist for the life of the process
100-
logger.info("Converting caikit-style adapter %s to peft format", adapter_id)
101-
temp_dir = tempfile.TemporaryDirectory(delete=False)
102-
convert_pt_to_peft(local_adapter_path, temp_dir.name)
103-
local_adapter_path = temp_dir.name
104-
105-
adapter_config = await loop.run_in_executor(
91+
# Increment the unique adapter id counter here on the event loop, before handing
92+
# off to the thread pool, so that we don't need to deal with thread-safety
93+
unique_id = adapter_store.next_unique_id
94+
adapter_store.next_unique_id += 1
95+
96+
adapter_metadata = await loop.run_in_executor(
10697
global_thread_pool,
107-
_load_adapter_config_from_file,
98+
_load_adapter_metadata,
10899
adapter_id,
109100
local_adapter_path,
101+
unique_id,
110102
)
111-
adapter_type = adapter_config.get("peft_type", None)
112103

113104
# Add to cache
114-
adapter_metadata = AdapterMetadata(
115-
unique_id=adapter_store.next_unique_id,
116-
adapter_type=adapter_type,
117-
full_path=local_adapter_path,
118-
full_config=adapter_config,
119-
)
120105
adapter_store.adapters[adapter_id] = adapter_metadata
121106

122107
# Build the proper vllm request object
@@ -142,8 +127,8 @@ async def validate_adapters(
142127
TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type) # noqa: RET503
143128

144129

145-
def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> dict:
146-
"""Get adapter from file.
130+
def _load_adapter_metadata(adapter_id: str, adapter_path: str, unique_id: int) -> dict:
131+
"""Get adapter metadata from files.
147132
148133
Performs all the filesystem access required to deduce the type
149134
of the adapter. It's run in a separate thread pool executor so that file
@@ -154,17 +139,35 @@ def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> dict:
154139
adapter_id, "directory does not exist"
155140
)
156141

142+
# 🌶️🌶️🌶️ Check for caikit-style adapters first
143+
if Path(adapter_path).exists() and (Path(adapter_path) / "decoder.pt").exists():
144+
# Create new temporary directory and convert to peft format there
145+
# NB: This requires write access to /tmp
146+
# Intentionally setting delete=False, we need the new adapter
147+
# files to exist for the life of the process
148+
logger.info("Converting caikit-style adapter %s to peft format", adapter_id)
149+
temp_dir = tempfile.TemporaryDirectory(delete=False)
150+
convert_pt_to_peft(adapter_path, temp_dir.name)
151+
adapter_path = temp_dir.name
152+
157153
adapter_config_path = Path(adapter_path) / "adapter_config.json"
158154
if not Path(adapter_config_path).exists():
159155
TGISValidationError.AdapterNotFound.error(
160156
adapter_id, "invalid adapter: no adapter_config.json found"
161157
)
162158

163-
# NB: blocks event loop
164159
with open(adapter_config_path) as adapter_config_file:
165160
adapter_config = json.load(adapter_config_file)
166161

167-
return adapter_config
162+
adapter_type = adapter_config.get("peft_type", None)
163+
adapter_metadata = AdapterMetadata(
164+
unique_id=unique_id,
165+
adapter_type=adapter_type,
166+
full_path=adapter_path,
167+
full_config=adapter_config,
168+
)
169+
170+
return adapter_metadata
168171

169172

170173
def _reject_bad_adapter_id(adapter_id: str) -> None:
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"base_model_name_or_path": "bigscience/bloomz-560m",
3+
"inference_mode": true,
4+
"num_attention_heads": 16,
5+
"num_layers": 24,
6+
"num_transformer_submodules": 1,
7+
"num_virtual_tokens": 8,
8+
"peft_type": "PROMPT_TUNING",
9+
"prompt_tuning_init": "TEXT",
10+
"prompt_tuning_init_text": "Classify if the tweet is a complaint or not:",
11+
"task_type": "CAUSAL_LM",
12+
"token_dim": 1024,
13+
"tokenizer_name_or_path": "bigscience/bloomz-560m"
14+
}
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
The adapter_model.safetensors file here is just a dummy file that lets the tests pass; the tests never actually need to load it.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"alpha_pattern": {},
3+
"auto_mapping": null,
4+
"base_model_name_or_path": "/granite/granite-3b-base-v2/step_75000_ckpt",
5+
"bias": "none",
6+
"fan_in_fan_out": false,
7+
"inference_mode": true,
8+
"init_lora_weights": true,
9+
"layer_replication": null,
10+
"layers_pattern": null,
11+
"layers_to_transform": null,
12+
"loftq_config": {},
13+
"lora_alpha": 16,
14+
"lora_dropout": 0.05,
15+
"megatron_config": null,
16+
"megatron_core": "megatron.core",
17+
"modules_to_save": null,
18+
"peft_type": "LORA",
19+
"r": 8,
20+
"rank_pattern": {},
21+
"revision": null,
22+
"target_modules": [
23+
"c_attn",
24+
"c_fc",
25+
"c_proj"
26+
],
27+
"task_type": "CAUSAL_LM",
28+
"use_dora": false,
29+
"use_rslora": false
30+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
fake weights

tests/test_adapters.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from pathlib import Path
22

33
import pytest
4+
from vllm.lora.request import LoRARequest
5+
from vllm.prompt_adapter.request import PromptAdapterRequest
46

57
from vllm_tgis_adapter.grpc.adapters import AdapterStore, validate_adapters
68
from vllm_tgis_adapter.grpc.pb.generation_pb2 import (
@@ -11,7 +13,8 @@
1113

1214

1315
@pytest.mark.asyncio
14-
async def test_validate_adapters():
16+
async def test_caikit_prompt_adapter():
17+
# Checks that decoder.pt style adapters from caikit_nlp are loaded correctly
1518
adapter_name = "bloom_sentiment_1"
1619
request = BatchedGenerationRequest(
1720
adapter_id=adapter_name,
@@ -20,10 +23,92 @@ async def test_validate_adapters():
2023
adapters = await validate_adapters(
2124
request, AdapterStore(cache_path=FIXTURES_DIR, adapters={})
2225
)
26+
# Ensure we created a prompt adapter request
2327
assert "prompt_adapter_request" in adapters
2428
assert adapters["prompt_adapter_request"].prompt_adapter_name == adapter_name
2529
adapter_path = adapters["prompt_adapter_request"].prompt_adapter_local_path
2630
assert adapter_path is not None
31+
assert isinstance(adapters["prompt_adapter_request"], PromptAdapterRequest)
2732

33+
# make sure the converted adapter is not in the cache directory
34+
assert str(FIXTURES_DIR) not in adapter_path
35+
assert "/tmp" in adapter_path
36+
37+
# Check for the converted artifacts
2838
assert Path.exists(Path(adapter_path) / "adapter_config.json")
2939
assert Path.exists(Path(adapter_path) / "adapter_model.safetensors")
40+
41+
42+
@pytest.mark.asyncio
43+
async def test_prompt_adapter():
44+
adapter_name = "bloomz-560m-prompt-adapter"
45+
request = BatchedGenerationRequest(
46+
adapter_id=adapter_name,
47+
)
48+
49+
adapters = await validate_adapters(
50+
request, AdapterStore(cache_path=FIXTURES_DIR, adapters={})
51+
)
52+
# Ensure we created a prompt adapter request
53+
assert "prompt_adapter_request" in adapters
54+
assert adapters["prompt_adapter_request"].prompt_adapter_name == adapter_name
55+
assert isinstance(adapters["prompt_adapter_request"], PromptAdapterRequest)
56+
57+
58+
@pytest.mark.asyncio
59+
async def test_lora_adapter():
60+
adapter_name = "granite-3b-code-instruct-lora"
61+
request = BatchedGenerationRequest(
62+
adapter_id=adapter_name,
63+
)
64+
65+
adapters = await validate_adapters(
66+
request, AdapterStore(cache_path=FIXTURES_DIR, adapters={})
67+
)
68+
# Ensure we created a LoRA adapter request
69+
assert "lora_request" in adapters
70+
assert adapters["lora_request"].lora_name == adapter_name
71+
assert isinstance(adapters["lora_request"], LoRARequest)
72+
73+
74+
@pytest.mark.asyncio
75+
async def test_adapters_are_cached():
76+
adapter_name = "granite-3b-code-instruct-lora"
77+
request = BatchedGenerationRequest(
78+
adapter_id=adapter_name,
79+
)
80+
81+
adapter_store = AdapterStore(cache_path=FIXTURES_DIR, adapters={})
82+
83+
adapters_1 = await validate_adapters(request, adapter_store=adapter_store)
84+
adapters_2 = await validate_adapters(request, adapter_store=adapter_store)
85+
86+
# Metadata is only fetched and cached once
87+
assert len(adapter_store.adapters) == 1
88+
# Same unique ID is re-used for the second request
89+
assert (
90+
adapters_1["lora_request"].lora_int_id == adapters_2["lora_request"].lora_int_id
91+
)
92+
93+
94+
@pytest.mark.asyncio
95+
async def test_store_handles_multiple_adapters():
96+
adapter_store = AdapterStore(cache_path=FIXTURES_DIR, adapters={})
97+
98+
adapter_name = "granite-3b-code-instruct-lora"
99+
request = BatchedGenerationRequest(
100+
adapter_id=adapter_name,
101+
)
102+
adapters_1 = await validate_adapters(request, adapter_store=adapter_store)
103+
104+
adapter_name = "bloomz-560m-prompt-adapter"
105+
request = BatchedGenerationRequest(
106+
adapter_id=adapter_name,
107+
)
108+
adapters_2 = await validate_adapters(request, adapter_store=adapter_store)
109+
110+
assert len(adapter_store.adapters) == 2
111+
assert (
112+
adapters_1["lora_request"].lora_int_id
113+
< adapters_2["prompt_adapter_request"].prompt_adapter_id
114+
)

0 commit comments

Comments
 (0)