5
5
import re
6
6
import threading
7
7
from typing import Dict , List , Union , Tuple , Optional
8
+ from safetensors .torch import load_file as safe_load_file
8
9
9
10
import torch
10
11
@@ -191,11 +192,61 @@ def get(self, prefix_id: str) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.T
191
192
cache_node = self ._get_from_cache (prefix_id )
192
193
if cache_node is None :
193
194
# Release the lock & load the tensors
194
- prefix = self ._load_embedding_tensors (prefix_id )
195
+ self ._reject_bad_prefix_ids (prefix_id )
196
+ if self ._is_peft_prefix (prefix_id ):
197
+ prefix = self ._load_embedding_tensors_peft (prefix_id )
198
+ else :
199
+ prefix = self ._load_embedding_tensors (prefix_id )
195
200
# Relock & add the newly loaded tensor to the cache
196
201
cache_node = self ._add_prefix_id_to_cache (prefix_id , prefix )
197
202
return cache_node .prompt
198
203
204
@staticmethod
def _reject_bad_prefix_ids(prefix_id: str) -> None:
    """Validate a prefix id before any filesystem access is attempted.

    Raises:
        Exception: if the id contains characters outside the allowed set
            (alphanumeric, ``_``, ``-`` and ``/``), or if the resulting path
            would escape the prefix store directory (path traversal).
    """
    if VALID_PREFIX_ID_PATTERN.fullmatch(prefix_id) is None:
        raise Exception(f"Invalid prefix id {prefix_id}, must contain only alphanumeric, _ and - and /")
    candidate_dir = PREFIX_STORE_PATH / prefix_id
    # Normalize the path and confirm it is still rooted inside the store,
    # so ids like "a/../../etc" are rejected
    normalized = os.path.normpath(candidate_dir)
    if not normalized.startswith(str(PREFIX_STORE_PATH) + "/"):
        raise Exception(f"Invalid prefix id {prefix_id}")
214
+
215
@staticmethod
def _is_peft_prefix(prefix_id: str) -> bool:
    """Return True if the prefix was saved with peft.save_pretrained()
    (has an adapter_model.* file in its directory).

    Args:
        prefix_id: str
            Id of the prefix, used as a directory name under PREFIX_STORE_PATH.

    Returns:
        bool
            True when an adapter_model file is present; False when the prefix
            directory does not exist or contains no such file.
    """
    prefix_dir_path = PREFIX_STORE_PATH / prefix_id
    if not os.path.isdir(prefix_dir_path):
        return False
    # any() short-circuits on the first match instead of building a full list
    # of every stem in the directory
    return any(os.path.splitext(f)[0] == "adapter_model" for f in os.listdir(prefix_dir_path))
223
+
224
def _load_embedding_tensors_peft(self, prefix_id: str) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Load prompt tensors for a peft adapter.

    Args:
        prefix_id: str
            Id of the prefix whose directory holds the adapter_model file.

    Returns:
        torch.Tensor
            The decoder prompt tensor (peft prefixes are decoder-only).

    Raises:
        Exception: for encoder-decoder models, or when the adapter file has
            no prompt_embeddings entry.
        PrefixNotFound: when no adapter_model file exists for the prefix.
    """
    if self.is_encoder_decoder:
        raise Exception("encoder-decoder architectures not supported for peft models")

    # safetensors is the default format, but users may have saved their model with
    # safe_serialization=False to produce the .bin file instead
    decoder_data_dict = self._load_torch_file(prefix_id, "adapter_model.safetensors")
    if decoder_data_dict is None:
        decoder_data_dict = self._load_torch_file(prefix_id, "adapter_model.bin")

    if decoder_data_dict is None:
        raise PrefixNotFound(f"Prefix id {prefix_id} not found")

    # These files should contain dicts with a `prompt_embeddings` tensor;
    # fail with a clear message instead of a bare KeyError if it is missing
    decoder_data = decoder_data_dict.get("prompt_embeddings")
    if decoder_data is None:
        raise Exception(f"Prefix id {prefix_id} has no prompt_embeddings tensor")
    decoder_prefix = self._process_prefix_tensor(decoder_data, dtype=self.dtype)

    if self.zero:
        # Return zero prefix early before sending tensor to gpu
        return self._zero_prefixes(decoder=decoder_prefix, encoder=None)

    decoder_prefix = decoder_prefix.to(self.device, non_blocking=True)
    return decoder_prefix
249
+
199
250
def _load_embedding_tensors(self, prefix_id: str) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Load prompt tensors corresponding to a single prefix ID to disk. The return
    value of this function should be what is returned when indexing into the cache.

    Args:
        prefix_id: str
            Id of the prefix whose directory holds decoder.pt / encoder.pt.

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            Loaded encoder / decoder prompt tensor for the model under consideration.

    Raises:
        PrefixNotFound: when the required tensor files are missing.
    """
    # Either file may be absent; _load_torch_file returns None in that case
    # and _process_prefix_tensor passes the None straight through
    decoder_data = self._load_torch_file(prefix_id, "decoder.pt")
    decoder_prefix = self._process_prefix_tensor(decoder_data, dtype=self.dtype)

    encoder_data = self._load_torch_file(prefix_id, "encoder.pt")
    encoder_prefix = self._process_prefix_tensor(encoder_data, dtype=self.dtype)

    if decoder_prefix is None and not self.is_encoder_decoder:
        # Must have a decoder for decoder only model
        raise PrefixNotFound(f"Prefix id {prefix_id} not found")
    if decoder_prefix is None and encoder_prefix is None:
        # And either the decoder or encoder must be provided
        raise PrefixNotFound(f"Prefix id {prefix_id} not found")

    if self.zero:
        # Return zero prefixes early before sending tensors to gpu
        return self._zero_prefixes(encoder=encoder_prefix, decoder=decoder_prefix)

    if decoder_prefix is not None:
        decoder_prefix = decoder_prefix.to(self.device, non_blocking=True)

    # For encoder-decoder we store a tuple of (encoder_prefix, decoder_prefix),
    if self.is_encoder_decoder:
        if decoder_prefix is not None:
            # TODO confirm this cat is correct
            decoder_prefix = torch.cat((decoder_prefix, self.decoder_start_tok_embedding))
        if encoder_prefix is not None:
            encoder_prefix = encoder_prefix.to(self.device, non_blocking=True)

        return encoder_prefix, decoder_prefix

    return decoder_prefix
+
295
@staticmethod
def _load_torch_file(prefix_id: str, filename: str) -> Optional[Union[torch.Tensor, dict]]:
    """Loads a file for the given prefix.

    Args:
        prefix_id: str
            Id of the prefix whose directory holds the file.
        filename: str
            Name of the file within the prefix directory; a .safetensors
            extension selects the safetensors loader, anything else is
            loaded with torch.load.

    Returns:
        Optional[Union[torch.Tensor, dict]]
            The loaded tensor or state dict, or None when the file does
            not exist.
    """
    prefix_path = PREFIX_STORE_PATH / prefix_id / filename
    if not prefix_path.is_file():
        return None

    logger.info(f"Loading new prefix {prefix_id}/{filename}")

    # safetensors files have their own loader; everything else goes through
    # torch.load with weights_only=True so no arbitrary pickled code can run
    if prefix_path.suffix == ".safetensors":
        return safe_load_file(prefix_path, device='cpu')
    return torch.load(prefix_path, weights_only=True, map_location=torch.device('cpu'))
244
308
245
- def _load_embedding_tensor (self , prefix_id : str , filename : str , dtype : torch .dtype ) -> torch .Tensor :
246
- """Load an embedding tensor from a single file .
309
+ def _process_prefix_tensor (self , prefix : Optional [ torch . Tensor ], dtype : torch .dtype ) -> Optional [ torch .Tensor ] :
310
+ """Convert a prefix tensor to the correct dtype and run some validation checks .
247
311
248
312
Args:
249
- prefix_id: str
250
- Name of the file that we want to load a torch tensor from .
251
- filename: str
252
- Name of the file to be loaded .
313
+ prefix: torch.Tensor
314
+ A prefix tensor loaded from a file .
315
+ dtype: torch.dtype
316
+ The desired dtype of the final prefix tensor .
253
317
254
318
Returns:
255
319
torch.Tensor
256
- Tensor object corresponding to loaded prompt.
320
+ A Tensor object corresponding to loaded prompt.
257
321
"""
258
- if not VALID_PREFIX_ID_PATTERN .fullmatch (prefix_id ):
259
- raise Exception (f"Invalid prefix id { prefix_id } , must contain only alphanumeric, _ and - and /" )
260
- prefix_path = PREFIX_STORE_PATH / prefix_id / filename
261
- # Check for path traversal
262
- if not os .path .normpath (prefix_path ).startswith (str (PREFIX_STORE_PATH ) + "/" ):
263
- raise Exception (f"Invalid prefix id { prefix_id } " )
264
- if not prefix_path .is_file ():
322
+ if prefix is None :
265
323
return None
266
-
267
- logger .info (f"Loading new prefix { prefix_id } /{ filename } " )
268
- prefix = torch .load (prefix_path , weights_only = True , map_location = torch .device ('cpu' ))
269
324
# Verify that it's a tensor of the correct shape
270
325
if not torch .is_tensor (prefix ) or len (prefix .shape ) != 2 :
271
326
raise Exception (f"Invalid prefix embedding tensor" )
@@ -290,6 +345,28 @@ def _load_embedding_tensor(self, prefix_id: str, filename: str, dtype: torch.dty
290
345
converted_prefix .requires_grad = False
291
346
return converted_prefix
292
347
348
+ def _zero_prefixes (
349
+ self ,
350
+ encoder : Optional [torch .Tensor ],
351
+ decoder : Optional [torch .Tensor ]
352
+ ) -> Optional [torch .Tensor ] | Tuple [Optional [torch .Tensor ], Optional [torch .Tensor ]]:
353
+ """If the return_zero flag is set, we replace the encoder and decoder prefixes
354
+ with zero tensors instead"""
355
+ if encoder is not None :
356
+ encoder = self .zero .expand (encoder .shape )
357
+
358
+ if self .is_encoder_decoder :
359
+ if decoder is not None :
360
+ # For encoder-decoder models we need an extra column on the decoder to account for
361
+ # the decoder_start_tok_embedding
362
+ decoder = self .zero .expand (decoder .shape [0 ] + 1 , * decoder .shape [1 :])
363
+ return encoder , decoder
364
+
365
+ if decoder is not None :
366
+ decoder = self .zero .expand (decoder .shape )
367
+
368
+ return decoder
369
+
293
370
def _add_prefix_id_to_cache (
294
371
self ,
295
372
prefix_id : str ,
0 commit comments