
Commit 8e4f35c

joerunde authored and njhill committed
Support tuned prompts in peft adapter format
This PR adds support for loading prefixes from saved peft adapters. When a model with a peft adapter is saved with model.save_pretrained(output_path), the resulting adapter_model.safetensors (or adapter_model.bin) file can now be loaded as a prefix. Only peft adapters created with prompt tuning are supported, since we load the prompt_embeddings tensor out of the saved adapter.
1 parent 28c5f5f commit 8e4f35c
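
For reference, a minimal sketch of how such an adapter might be produced with the Hugging Face peft library (the base model name, virtual-token count, and output paths below are illustrative, not part of this PR):

    from transformers import AutoModelForCausalLM
    from peft import PromptTuningConfig, TaskType, get_peft_model

    # Illustrative base model; any causal LM supported by peft prompt tuning works here
    base = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8)
    model = get_peft_model(base, config)

    # ... run prompt tuning ...

    # Default save format writes adapter_model.safetensors
    model.save_pretrained("output_path")
    # safe_serialization=False writes the raw adapter_model.bin instead
    model.save_pretrained("output_path_bin", safe_serialization=False)

Either output file contains a prompt_embeddings entry, which is the tensor the server loads as the prefix.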

File tree

5 files changed: +220 -46 lines changed

Binary file not shown.

integration_tests/test_cases_tinyllama.yaml

Lines changed: 34 additions & 0 deletions
@@ -175,6 +175,40 @@
         stopReason: MAX_TOKENS
         text: ' Once he can go to the park and play with his friends.'

+# Prompt loaded from peft adapter
+- name: Greedy with tuned peft adapter prefix
+  request:
+    prefixId: tinyllama_peft_adapter
+    params:
+      method: GREEDY
+      stopping:
+        maxNewTokens: 13
+    requests:
+      - {"text": ""}
+  response:
+    responses:
+      - generatedTokenCount: 13
+        inputTokenCount: 1
+        stopReason: MAX_TOKENS
+        text: ' Once upon a time, there was a little boy named Tim.'
+
+# Prompt loaded from peft adapter saved in a raw .bin file
+- name: Greedy with tuned peft adapter prefix in raw .bin format
+  request:
+    prefixId: tinyllama_peft_adapter_raw
+    params:
+      method: GREEDY
+      stopping:
+        maxNewTokens: 13
+    requests:
+      - {"text": ""}
+  response:
+    responses:
+      - generatedTokenCount: 13
+        inputTokenCount: 1
+        stopReason: MAX_TOKENS
+        text: ' Once upon a time, there was a little boy named Tim.'
+
 # Prompt prefix with truncation
 - name: Greedy with tuned prompt prefix with truncation
   request:

server/tests/test_prompt_cache.py

Lines changed: 65 additions & 2 deletions
@@ -2,9 +2,13 @@
 it does LRU eviction in a thread safe way correctly.
 """
 import gc
+import os
+from pathlib import Path
+
 import pytest
 from unittest.mock import patch
 import torch
+import safetensors.torch
 from threading import Lock
 from text_generation_server import prompt_cache

@@ -14,6 +18,32 @@
 else:
     DEVICE = None

+TESTS_DIR = os.path.dirname(__file__)
+REPO_ROOT = os.path.dirname(os.path.dirname(TESTS_DIR))
+INTEGRATION_TESTS_DIR = os.path.join(REPO_ROOT, "integration_tests")
+
+
+@pytest.fixture()
+def temp_prompt_store():
+    with patch("text_generation_server.prompt_cache.PREFIX_STORE_PATH", Path(os.path.join(INTEGRATION_TESTS_DIR, "prompt_prefixes"))):
+        yield
+
+
+@pytest.fixture()
+def tiny_starcoder_decoder_prompt(temp_prompt_store):
+    return "tiny_starcoder"
+
+
+@pytest.fixture()
+def tiny_raw_llama_peft_adapter_prompt(temp_prompt_store):
+    return "tinyllama_peft_adapter_raw"
+
+
+@pytest.fixture()
+def tiny_llama_peft_adapter_prompt(temp_prompt_store):
+    return "tinyllama_peft_adapter"
+
+
 @pytest.fixture()
 def temp_prompt_cache_enc_dec_meta():
     """Build an empty prompt cache for an encoder-decoder model with the 'meta'
@@ -285,11 +315,11 @@ def test_get_cache_len(mock_load_tensors, temp_prompt_cache):

 ### Test code paths for encoder decoder model
 # TODO: add more tests here!
-@patch("text_generation_server.prompt_cache.PrefixCache._load_embedding_tensor")
+@patch("text_generation_server.prompt_cache.PrefixCache._load_torch_file")
 def test_prompt_model_device_diff(mock_load, temp_prompt_cache_enc_dec_meta):
     # create prefix tensor on CPU which should be converted to the 'meta' device
     # before the decoder_start_tok_embedding is added to it
-    mock_load.return_value = torch.ones((3,8), device='cpu')
+    mock_load.return_value = torch.ones((4,8), device='cpu')
     temp_prompt_cache_enc_dec_meta.get("bad_prompt")

 ### Test cases for invalid prompts
@@ -360,3 +390,36 @@ def test_prompt_with_nan(mock_is_file, mock_torch_load, temp_prompt_cache):
     with pytest.raises(Exception):
         temp_prompt_cache.get("bad_prompt")
     assert len(temp_prompt_cache) == 0
+
+
+def test_prompt_cache_decoder_only_load(temp_prompt_cache, tiny_starcoder_decoder_prompt):
+    """Simple test that we can load a prompt with a decoder.pt file"""
+    # The cache should load this without raising
+    prompt = temp_prompt_cache.get(tiny_starcoder_decoder_prompt)
+
+    # Assert this is the same tensor that's in decoder.pt
+    decoder_pt_path = os.path.join(prompt_cache.PREFIX_STORE_PATH, tiny_starcoder_decoder_prompt, "decoder.pt")
+    decoder = torch.load(decoder_pt_path)
+    assert decoder.equal(prompt)
+
+
+def test_prompt_cache_peft_decoder_load(temp_prompt_cache, tiny_raw_llama_peft_adapter_prompt):
+    """Simple test that we can load a prompt for a decoder-only model saved with PEFT directly in adapter_model.bin format"""
+    # The cache should load this without raising
+    prompt = temp_prompt_cache.get(tiny_raw_llama_peft_adapter_prompt)
+
+    # Assert this is the same tensor that's in adapter_model.bin
+    adapter_model_path = os.path.join(prompt_cache.PREFIX_STORE_PATH, tiny_raw_llama_peft_adapter_prompt, "adapter_model.bin")
+    adapter_model = torch.load(adapter_model_path, map_location=torch.device('cpu'))
+    assert adapter_model["prompt_embeddings"].equal(prompt)
+
+
+def test_prompt_cache_safetensors_load(temp_prompt_cache, tiny_llama_peft_adapter_prompt):
+    """Simple test that we can load a prompt for a decoder-only model saved with PEFT directly in adapter_model.safetensors format"""
+    # The cache should load this without raising
+    prompt = temp_prompt_cache.get(tiny_llama_peft_adapter_prompt)
+
+    # Assert this is the same tensor that's in adapter_model.safetensors
+    adapter_model_path = os.path.join(prompt_cache.PREFIX_STORE_PATH, tiny_llama_peft_adapter_prompt, "adapter_model.safetensors")
+    adapter_model = safetensors.torch.load_file(adapter_model_path)
+    assert adapter_model["prompt_embeddings"].equal(prompt)
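
These tests point PREFIX_STORE_PATH at integration_tests/prompt_prefixes and compare the cached tensor against the fixture files directly. A rough sketch of how equivalent fixture files could be generated (the shape below is illustrative; the committed fixtures come from an actual prompt-tuned TinyLlama adapter and are the binary files not shown above):

    import torch
    import safetensors.torch

    # Illustrative shape only; any 2-D float tensor keyed by "prompt_embeddings" works
    prompt_embeddings = torch.randn(8, 2048)

    # safetensors variant, read by test_prompt_cache_safetensors_load
    safetensors.torch.save_file(
        {"prompt_embeddings": prompt_embeddings},
        "integration_tests/prompt_prefixes/tinyllama_peft_adapter/adapter_model.safetensors",
    )

    # raw .bin variant (safe_serialization=False), read by test_prompt_cache_peft_decoder_load
    torch.save(
        {"prompt_embeddings": prompt_embeddings},
        "integration_tests/prompt_prefixes/tinyllama_peft_adapter_raw/adapter_model.bin",
    )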

server/text_generation_server/prompt_cache.py

Lines changed: 121 additions & 44 deletions
@@ -5,6 +5,7 @@
 import re
 import threading
 from typing import Dict, List, Union, Tuple, Optional
+from safetensors.torch import load_file as safe_load_file

 import torch

@@ -191,11 +192,61 @@ def get(self, prefix_id: str) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.T
         cache_node = self._get_from_cache(prefix_id)
         if cache_node is None:
             # Release the lock & load the tensors
-            prefix = self._load_embedding_tensors(prefix_id)
+            self._reject_bad_prefix_ids(prefix_id)
+            if self._is_peft_prefix(prefix_id):
+                prefix = self._load_embedding_tensors_peft(prefix_id)
+            else:
+                prefix = self._load_embedding_tensors(prefix_id)
             # Relock & add the newly loaded tensor to the cache
             cache_node = self._add_prefix_id_to_cache(prefix_id, prefix)
         return cache_node.prompt

+    @staticmethod
+    def _reject_bad_prefix_ids(prefix_id: str) -> None:
+        """Raises if the prefix does not exist, has an invalid name, or attempted to
+        access files outside the prefix cache"""
+        if not VALID_PREFIX_ID_PATTERN.fullmatch(prefix_id):
+            raise Exception(f"Invalid prefix id {prefix_id}, must contain only alphanumeric, _ and - and /")
+        prefix_dir_path = PREFIX_STORE_PATH / prefix_id
+        # Check for path traversal
+        if not os.path.normpath(prefix_dir_path).startswith(str(PREFIX_STORE_PATH) + "/"):
+            raise Exception(f"Invalid prefix id {prefix_id}")
+
+    @staticmethod
+    def _is_peft_prefix(prefix_id):
+        """Returns true if the prefix was saved with peft.save_pretrained()
+        (has an adapter_model.bin file)"""
+        prefix_dir_path = PREFIX_STORE_PATH / prefix_id
+        if not os.path.isdir(prefix_dir_path):
+            return False
+        return "adapter_model" in [os.path.splitext(f)[0] for f in os.listdir(prefix_dir_path)]
+
+    def _load_embedding_tensors_peft(self, prefix_id: str) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Load prompt tensors for a peft adapter
+        """
+        if self.is_encoder_decoder:
+            raise Exception("encoder-decoder architectures not supported for peft models")
+
+        # safetensors is the default format, but users may have saved their model with
+        # safe_serialization=False to produce the .bin file instead
+        decoder_data_dict = self._load_torch_file(prefix_id, "adapter_model.safetensors")
+        if decoder_data_dict is None:
+            decoder_data_dict = self._load_torch_file(prefix_id, "adapter_model.bin")
+
+        if decoder_data_dict is None:
+            raise PrefixNotFound(f"Prefix id {prefix_id} not found")
+
+        # These files should contain dicts with a `prompt_embeddings` tensor
+        decoder_data = decoder_data_dict["prompt_embeddings"]
+        decoder_prefix = self._process_prefix_tensor(decoder_data, dtype=self.dtype)
+
+        if self.zero:
+            # Return zero prefix early before sending tensor to gpu
+            return self._zero_prefixes(decoder=decoder_prefix, encoder=None)
+
+        decoder_prefix = decoder_prefix.to(self.device, non_blocking=True)
+        return decoder_prefix
+
     def _load_embedding_tensors(self, prefix_id: str) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """Load prompt tensors corresponding to a single prefix ID to disk. The return
         value of this function should be what is returned when indexing into the cache
@@ -209,63 +260,67 @@ def _load_embedding_tensors(self, prefix_id: str) -> Union[torch.Tensor, Tuple[t
         Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
             Loaded encoder / decoder prompt tensor for the model under consideration.
         """
-        decoder_prefix = self._load_embedding_tensor(prefix_id, "decoder.pt", dtype=self.dtype)
-        # For encoder-decoder we store a tuple of (encoder_prefix, decoder_prefix),
-        # at least one must be non-None
+        decoder_data = self._load_torch_file(prefix_id, "decoder.pt")
+        decoder_prefix = self._process_prefix_tensor(decoder_data, dtype=self.dtype)
+
+        encoder_data = self._load_torch_file(prefix_id, "encoder.pt")
+        encoder_prefix = self._process_prefix_tensor(encoder_data, dtype=self.dtype)
+
+        if decoder_prefix is None and not self.is_encoder_decoder:
+            # Must have a decoder for decoder only model
+            raise PrefixNotFound(f"Prefix id {prefix_id} not found")
+        if decoder_prefix is None and encoder_prefix is None:
+            # And either the decoder or encoder must be provided
+            raise PrefixNotFound(f"Prefix id {prefix_id} not found")
+
+        if self.zero:
+            # Return zero prefixes early before sending tensors to gpu
+            return self._zero_prefixes(encoder=encoder_prefix, decoder=decoder_prefix)
+
         if decoder_prefix is not None:
-            if self.zero is not None:
-                decoder_prefix = self.zero.expand(decoder_prefix.shape)
-            else:
-                decoder_prefix = decoder_prefix.to(self.dtype).to(self.device, non_blocking=True)
+            decoder_prefix = decoder_prefix.to(self.device, non_blocking=True)

+        # For encoder-decoder we store a tuple of (encoder_prefix, decoder_prefix),
         if self.is_encoder_decoder:
-            encoder_prefix = self._load_embedding_tensor(prefix_id, "encoder.pt", dtype=self.dtype)
-            if decoder_prefix is None:
-                if encoder_prefix is None:
-                    raise PrefixNotFound(f"Prefix id {prefix_id} not found")
-            else:
+            if decoder_prefix is not None:
                 # TODO confirm this cat is correct
-                if self.zero is not None:
-                    decoder_prefix = self.zero.expand(decoder_prefix.shape[0] + 1, *decoder_prefix.shape[1:])
-                else:
-                    decoder_prefix = torch.cat((decoder_prefix, self.decoder_start_tok_embedding))
+                decoder_prefix = torch.cat((decoder_prefix, self.decoder_start_tok_embedding))
             if encoder_prefix is not None:
-                if self.zero is not None:
-                    encoder_prefix = self.zero.expand(encoder_prefix.shape)
-                else:
-                    encoder_prefix = encoder_prefix.to(self.device, non_blocking=True)
-            prefix = encoder_prefix, decoder_prefix
-        # For decoder-only we store just the decoder prefix
-        elif decoder_prefix is None:
-            raise PrefixNotFound(f"Prefix id {prefix_id} not found")
+                encoder_prefix = encoder_prefix.to(self.device, non_blocking=True)
+
+            return encoder_prefix, decoder_prefix
+
+        return decoder_prefix
+
+    @staticmethod
+    def _load_torch_file(prefix_id: str, filename: str) -> torch.Tensor | dict:
+        """Loads a file for the given prefix"""
+        prefix_path = PREFIX_STORE_PATH / prefix_id / filename
+        if not prefix_path.is_file():
+            return None
+
+        logger.info(f"Loading new prefix {prefix_id}/{filename}")
+
+        if os.path.splitext(prefix_path)[1] == ".safetensors":
+            return safe_load_file(prefix_path, device='cpu')
         else:
-            prefix = decoder_prefix
-        return prefix
+            return torch.load(prefix_path, weights_only=True, map_location=torch.device('cpu'))

-    def _load_embedding_tensor(self, prefix_id: str, filename: str, dtype: torch.dtype) -> torch.Tensor:
-        """Load an embedding tensor from a single file.
+    def _process_prefix_tensor(self, prefix: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
+        """Convert a prefix tensor to the correct dtype and run some validation checks.

         Args:
-            prefix_id: str
-                Name of the file that we want to load a torch tensor from.
-            filename: str
-                Name of the file to be loaded.
+            prefix: torch.Tensor
+                A prefix tensor loaded from a file.
+            dtype: torch.dtype
+                The desired dtype of the final prefix tensor.

         Returns:
             torch.Tensor
-                Tensor object corresponding to loaded prompt.
+                A Tensor object corresponding to loaded prompt.
         """
-        if not VALID_PREFIX_ID_PATTERN.fullmatch(prefix_id):
-            raise Exception(f"Invalid prefix id {prefix_id}, must contain only alphanumeric, _ and - and /")
-        prefix_path = PREFIX_STORE_PATH / prefix_id / filename
-        # Check for path traversal
-        if not os.path.normpath(prefix_path).startswith(str(PREFIX_STORE_PATH) + "/"):
-            raise Exception(f"Invalid prefix id {prefix_id}")
-        if not prefix_path.is_file():
+        if prefix is None:
             return None
-
-        logger.info(f"Loading new prefix {prefix_id}/{filename}")
-        prefix = torch.load(prefix_path, weights_only=True, map_location=torch.device('cpu'))
         # Verify that it's a tensor of the correct shape
         if not torch.is_tensor(prefix) or len(prefix.shape) != 2:
             raise Exception(f"Invalid prefix embedding tensor")
@@ -290,6 +345,28 @@ def _load_embedding_tensor(self, prefix_id: str, filename: str, dtype: torch.dty
         converted_prefix.requires_grad = False
         return converted_prefix

+    def _zero_prefixes(
+        self,
+        encoder: Optional[torch.Tensor],
+        decoder: Optional[torch.Tensor]
+    ) -> Optional[torch.Tensor] | Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """If the return_zero flag is set, we replace the encoder and decoder prefixes
+        with zero tensors instead"""
+        if encoder is not None:
+            encoder = self.zero.expand(encoder.shape)
+
+        if self.is_encoder_decoder:
+            if decoder is not None:
+                # For encoder-decoder models we need an extra column on the decoder to account for
+                # the decoder_start_tok_embedding
+                decoder = self.zero.expand(decoder.shape[0] + 1, *decoder.shape[1:])
+            return encoder, decoder
+
+        if decoder is not None:
+            decoder = self.zero.expand(decoder.shape)
+
+        return decoder
+
     def _add_prefix_id_to_cache(
         self,
         prefix_id: str,
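
Taken together, the prompt_cache.py changes make prefix loading format-aware: adapter_model.safetensors is tried first, then adapter_model.bin, and otherwise the pre-existing decoder.pt/encoder.pt path is used. A hypothetical standalone helper (not part of this PR) sketching that lookup order for a decoder-only prefix directory:

    from pathlib import Path

    import torch
    from safetensors.torch import load_file as safe_load_file

    def load_prefix_embedding(prefix_dir: Path) -> torch.Tensor:
        """Illustrative helper mirroring the lookup order used by the new cache code."""
        for name in ("adapter_model.safetensors", "adapter_model.bin"):
            path = prefix_dir / name
            if path.is_file():
                if name.endswith(".safetensors"):
                    data = safe_load_file(path, device="cpu")
                else:
                    data = torch.load(path, weights_only=True, map_location="cpu")
                # PEFT prompt-tuning adapters store the virtual-token embeddings under this key
                return data["prompt_embeddings"]
        # Fall back to the pre-existing raw-tensor format
        return torch.load(prefix_dir / "decoder.pt", weights_only=True, map_location="cpu")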

0 commit comments
