55import os
66import shutil
77import threading
8+ import uuid
89from dataclasses import asdict , dataclass , field
910from types import FunctionType
10- from typing import Dict , List , Optional , OrderedDict , Union
11+ from typing import Dict , Optional , Union
1112
1213import json
1314import numpy as np
@@ -176,11 +177,15 @@ def get_activated_adapters(self):
176177
177178class OffloadHelper :
178179
179- sub_dir = 'offload_cache'
180- cache_dir = os .path .join (get_cache_dir (), sub_dir )
181- shutil .rmtree (cache_dir , ignore_errors = True )
182- os .makedirs (cache_dir , exist_ok = True )
183- index = {}
180+ def __init__ (self ):
181+ sub_dir = os .path .join ('offload_cache' , str (uuid .uuid4 ().hex ))
182+ self .cache_dir = os .path .join (get_cache_dir (), sub_dir )
183+ shutil .rmtree (self .cache_dir , ignore_errors = True )
184+ os .makedirs (self .cache_dir , exist_ok = True )
185+ self .index = {}
186+
187+ def __del__ (self ):
188+ shutil .rmtree (self .cache_dir , ignore_errors = True )
184189
185190 @staticmethod
186191 def offload_weight (weight , weight_name , offload_folder , index = None ):
@@ -221,26 +226,24 @@ def load_offloaded_weight(weight_file, weight_info):
221226
222227 return weight
223228
224- @staticmethod
225- def offload_disk (module : torch .nn .Module , adapter_name , module_key ):
229+ def offload_disk (self , module : torch .nn .Module , adapter_name , module_key ):
226230 key = adapter_name + ':' + module_key
227231 md5 = hashlib .md5 (key .encode ('utf-8' )).hexdigest ()
228- sub_folder = os .path .join (OffloadHelper .cache_dir , md5 )
232+ sub_folder = os .path .join (self .cache_dir , md5 )
229233 os .makedirs (sub_folder , exist_ok = True )
230234 state_dict = module .state_dict ()
231- OffloadHelper .index [md5 ] = {}
235+ self .index [md5 ] = {}
232236 for key , tensor in state_dict .items ():
233- OffloadHelper .offload_weight (tensor , key , sub_folder , OffloadHelper .index [md5 ])
237+ OffloadHelper .offload_weight (tensor , key , sub_folder , self .index [md5 ])
234238
235- @staticmethod
236- def load_disk (module : torch .nn .Module , adapter_name , module_key ):
239+ def load_disk (self , module : torch .nn .Module , adapter_name , module_key ):
237240 key = adapter_name + ':' + module_key
238241 md5 = hashlib .md5 (key .encode ('utf-8' )).hexdigest ()
239- sub_folder = os .path .join (OffloadHelper .cache_dir , md5 )
242+ sub_folder = os .path .join (self .cache_dir , md5 )
240243 state_dict = {}
241- for key , value in OffloadHelper .index [md5 ].items ():
244+ for key , value in self .index [md5 ].items ():
242245 file = os .path .join (sub_folder , f'{ key } .dat' )
243- state_dict [key ] = OffloadHelper .load_offloaded_weight (file , OffloadHelper .index [md5 ][key ])
246+ state_dict [key ] = OffloadHelper .load_offloaded_weight (file , self .index [md5 ][key ])
244247 if version .parse (torch .__version__ ) >= version .parse ('2.1.0' ):
245248 module .load_state_dict (state_dict , assign = True )
246249 else :
@@ -264,6 +267,8 @@ def load_disk(module: torch.nn.Module, adapter_name, module_key):
264267
265268class SwiftAdapter :
266269
270+ offload_helper = OffloadHelper ()
271+
267272 @staticmethod
268273 def prepare_model (model : torch .nn .Module , config : SwiftConfig , adapter_name : str ) -> SwiftOutput :
269274 raise NotImplementedError
@@ -294,7 +299,7 @@ def offload(module: torch.nn.Module, adapter_name, module_key, offload: str):
294299 module .to ('cpu' )
295300 elif offload == 'meta' :
296301 if str (device ) != 'meta' :
297- OffloadHelper .offload_disk (module , adapter_name = adapter_name , module_key = module_key )
302+ SwiftAdapter . offload_helper .offload_disk (module , adapter_name = adapter_name , module_key = module_key )
298303 module .to ('meta' )
299304 else :
300305 raise NotImplementedError
@@ -309,7 +314,7 @@ def load(module: torch.nn.Module, adapter_name, module_key):
309314 module .to (module .origin_device )
310315 delattr (module , 'origin_device' )
311316 elif str (device ) == 'meta' :
312- OffloadHelper .load_disk (module , adapter_name = adapter_name , module_key = module_key )
317+ SwiftAdapter . offload_helper .load_disk (module , adapter_name = adapter_name , module_key = module_key )
313318 module .to (module .origin_device )
314319 delattr (module , 'origin_device' )
315320
0 commit comments