Commit 260ae9c

✨ invoke conversion at load time
Signed-off-by: Joe Runde <[email protected]>
1 parent: a3620be · commit: 260ae9c

2 files changed: +24, -7 lines

src/vllm_tgis_adapter/grpc/adapters.py

Lines changed: 20 additions & 0 deletions
@@ -11,12 +11,16 @@
 import dataclasses
 import json
 import re
+import tempfile
 from pathlib import Path
 from typing import TYPE_CHECKING

 from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest

+from vllm_tgis_adapter.logging import init_logger
+from vllm_tgis_adapter.tgis_utils.convert_pt_to_prompt import convert_pt_to_peft
+
 from .validation import TGISValidationError

 if TYPE_CHECKING:
@@ -30,6 +34,8 @@

 VALID_ADAPTER_ID_PATTERN = re.compile("[/\\w\\-]+")

+logger = init_logger(__name__)
+

 @dataclasses.dataclass
 class AdapterMetadata:
@@ -82,6 +88,20 @@ async def validate_adapters(
     if global_thread_pool is None:
         global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)

+    # 🌶️🌶️🌶️ Check for caikit-style adapters first
+    if (
+        Path(local_adapter_path).exists()
+        and (Path(local_adapter_path) / "decoder.pt").exists()
+    ):
+        # Create new temporary directory and convert to peft format there
+        # NB: This requires write access to /tmp
+        # Intentionally setting delete=False, we need the new adapter
+        # files to exist for the life of the process
+        logger.info("Converting caikit-style adapter %s to peft format", adapter_id)
+        temp_dir = tempfile.TemporaryDirectory(delete=False)
+        convert_pt_to_peft(local_adapter_path, temp_dir.name)
+        local_adapter_path = temp_dir.name
+
     adapter_config = await loop.run_in_executor(
         global_thread_pool,
         _load_adapter_config_from_file,
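
For readers skimming the diff, the new block in validate_adapters boils down to the following standalone sketch. The adapter path and id below are illustrative, not taken from the commit; convert_pt_to_peft is the converter touched in the second file. One assumption worth flagging: the delete= keyword on tempfile.TemporaryDirectory only exists on Python 3.12+.

# Minimal sketch of the load-time conversion, with hypothetical paths.
import tempfile
from pathlib import Path

from vllm_tgis_adapter.tgis_utils.convert_pt_to_prompt import convert_pt_to_peft

local_adapter_path = "/adapters/my-caikit-prompt"  # hypothetical caikit-style adapter dir
adapter_id = "my-caikit-prompt"                    # hypothetical adapter id

if (Path(local_adapter_path) / "decoder.pt").exists():
    # A decoder.pt file marks a caikit-style prompt adapter: convert it into a
    # peft-format copy in a fresh temp dir and point the loader at that copy.
    # delete=False keeps the converted files for the life of the process and
    # requires write access to /tmp (and Python >= 3.12 for the keyword).
    temp_dir = tempfile.TemporaryDirectory(delete=False)
    convert_pt_to_peft(local_adapter_path, temp_dir.name)
    local_adapter_path = temp_dir.name  # downstream config loading sees only peft format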

src/vllm_tgis_adapter/tgis_utils/convert_pt_to_prompt.py

Lines changed: 4 additions & 7 deletions
@@ -2,7 +2,6 @@
 # users the ability to be able to run it independently without
 # having to install vllm as a dependency
 import argparse
-import sys
 from pathlib import Path

 import torch
@@ -30,17 +29,15 @@ def convert_pt_to_peft(input_dir: str, output_dir: str) -> None:
     # read decoder.pt file
     decoder_pt_path = Path(input_dir) / "decoder.pt"
     if not decoder_pt_path.exists():
-        print(f"No decoder.pt model found in path {decoder_pt_path}")  # noqa: T201
-        sys.exit()
+        raise ValueError(f"No decoder.pt model found in path {decoder_pt_path}")

     # error if encoder.pt file exists
     encoder_pt_path = Path(input_dir) / "encoder.pt"
     if encoder_pt_path.exists():
-        print(  # noqa: T201
+        raise ValueError(
             f"encoder.pt model found in path {encoder_pt_path}, \
             encoder-decoder models are not yet supported, sorry!"
         )
-        sys.exit()

     # check output dir
     if output_dir is None:
@@ -58,8 +55,7 @@ def convert_pt_to_peft(input_dir: str, output_dir: str) -> None:

     # error if output_dir is file
     if output_path.is_file():
-        print(f"File found instead of dir {output_path}, exiting...")  # noqa: T201
-        sys.exit()
+        raise ValueError(f"File found instead of dir {output_path}")

     # load tensors from decoder.pt and save to .safetensors
     decoder_tensors = torch.load(decoder_pt_path, weights_only=True)
@@ -73,6 +69,7 @@ def convert_pt_to_peft(input_dir: str, output_dir: str) -> None:
     adapter_config = {
         "num_virtual_tokens": decoder_tensors.shape[0],
         "peft_type": "PROMPT_TUNING",
+        "base_model_name_or_path": "this-is-a/temporary-conversion",
     }

     with open(output_path / "adapter_config.json", "w") as config_file:
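
With the print/sys.exit() paths replaced by ValueError, a failed conversion no longer takes down the whole process and can be handled by the caller. A minimal usage sketch follows, with an illustrative input path; per the comments in this file, the output directory ends up holding adapter_config.json plus a .safetensors copy of the decoder weights.

# Hedged usage sketch for the converter; the input path is hypothetical.
import tempfile

from vllm_tgis_adapter.tgis_utils.convert_pt_to_prompt import convert_pt_to_peft

output_dir = tempfile.mkdtemp(prefix="prompt-conversion-")  # converted copy lives here

try:
    # Reads <input_dir>/decoder.pt and writes a peft-style prompt adapter
    # (adapter_config.json plus a .safetensors weights file) into output_dir.
    convert_pt_to_peft("/adapters/my-caikit-prompt", output_dir)
except ValueError as err:
    # Raised for a missing decoder.pt, an unsupported encoder.pt, or an output
    # path that is a regular file, cases that previously printed and exited.
    print(f"conversion failed: {err}")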
