@@ -57,7 +57,13 @@ def process_tensor(tensor):
5757 "info" : info ,
5858 }
5959 else :
60- return {"type" : "big_int_tensor" , "data" : tensor .clone (), "info" : info }
60+ sparse_tensor = tensor .to_sparse_coo ()
61+ return {
62+ "type" : "sparse_int_tensor" ,
63+ "indices" : sparse_tensor .indices ().clone (),
64+ "values" : sparse_tensor .values ().clone (),
65+ "info" : info ,
66+ }
6167 elif tensor .numel () < 1024 :
6268 return {"type" : "small_tensor" , "data" : tensor .clone (), "info" : info }
6369 else :
@@ -78,15 +84,20 @@ def handle_named_tensors(tensor):
7884 data_type = "small_int_tensor"
7985 data_value = tensor .clone ()
8086 else :
81- data_type = "big_int_tensor"
87+ data_type = "sparse_int_tensor"
88+ sparse_tensor = tensor .to_sparse_coo ()
89+ data_value = {
90+ "indices" : sparse_tensor .indices ().clone (),
91+ "values" : sparse_tensor .values ().clone (),
92+ }
93+
8294 info = tensor_info (tensor )
8395 return {"info" : info , "data" : data_value , "type" : data_type }
8496
8597 processed_weights = {
8698 key : handle_named_tensors (tensor ) for key , tensor in state_dict .items ()
8799 }
88100
89- # dynamic_shapes = extract_dynamic_shapes(example_inputs)
90101 return {
91102 "input_info" : processed_inputs ,
92103 "weight_info" : processed_weights ,
@@ -112,46 +123,59 @@ def format_data(data):
112123 return "None"
113124 elif isinstance (data , torch .Tensor ):
114125 if data .dtype .is_floating_point :
115- return "[{}]".format(", ".join(f"{x:.6f}" for x in data.tolist()))
126+ return "[{}]".format(", ".join(f"{x:.6f}" for x in data.flatten().tolist()))
116127 else:
117- return "[{}]".format(", ".join(f"{x}" for x in data.tolist()))
128+ return "[{}]".format(", ".join(f"{x}" for x in data.flatten().tolist()))
118129 else :
119130 return repr (data )
120131
121132 def process_tensor_info (tensor_info , name_prefix = "example_input" ):
122133 data_list = None
123- if "input_" in tensor_info ["name" ]:
134+ # MODIFICATION: Handle sparse tensor serialization
135+ is_sparse = tensor_info .get ("type" ) == "sparse_int_tensor"
136+ sparse_indices = None
137+ sparse_values = None
138+
139+ if is_sparse :
140+ data_list = None # No dense data for sparse tensors
141+ sparse_indices = tensor_info ["data" ]["indices" ]
142+ sparse_values = tensor_info ["data" ]["values" ]
143+ elif "input_" in tensor_info ["name" ]:
124144 if tensor_info ["type" ] in ["small_tensor" , "small_int_tensor" ]:
125145 data_list = tensor_info ["data" ].flatten ()
126- elif tensor_info ["type" ] == "big_int_tensor" :
127- data_list = f"pt-filename:xxx-key"
128146 else :
129147 pass
130148 else :
131- if tensor_info ["type" ] == "small_int_tensor" :
149+ if tensor_info ["type" ] == "small_int_tensor" :
132150 data_list = tensor_info ["data" ].flatten ()
133- if tensor_info ["type" ] == "big_int_tensor" :
134- raise ValueError (
135- "Unexpected cases: there are weights in big tensor of int type "
136- )
151+
137152 info = tensor_info .get ("info" , {})
138153 dtype = info .get ("dtype" , "torch.float" )
139154 shape = info .get ("shape" , [])
140155 device = info .get ("device" , "cpu" )
141156 mean = info .get ("mean" , 0.0 )
142157 std = info .get ("std" , 1.0 )
143158 uid = f"{name_prefix}_tensor_meta_{tensor_info.get('name', '')}"
144- return [
159+
160+ lines = [
145161 (f"class {uid}:"),
146162 (f"\tname = \"{tensor_info.get('name', '')}\""),
147163 (f"\tshape = {shape}"),
148164 (f'\tdtype = "{dtype}"'),
149165 (f'\tdevice = "{device}"'),
150166 (f"\tmean = {get_limited_precision_float_str(mean)}"),
151167 (f"\tstd = {get_limited_precision_float_str(std)}"),
152- (f"\tdata = {format_data(data_list)}"),
153- (""),
154168 ]
169+ if is_sparse:
170+ lines.append(f"\tis_sparse = True")
171+ lines.append(f"\tindices = {format_data(sparse_indices)}")
172+ lines.append(f"\tvalues = {format_data(sparse_values)}")
173+ else:
174+ lines.append(f"\tdata = {format_data(data_list)}")
175+
176+ lines.append("")
177+ return lines
178+
155179
156180 input_infos = converted ["input_info" ]
157181 if isinstance (input_infos , dict ):
@@ -200,13 +224,22 @@ def convert_meta_classes_to_tensors(file_path):
200224 }
201225 data_value = None
202226 data_type = getattr (torch , attrs .get ("dtype" , "torch.float" ).split ("." )[- 1 ])
203- if attrs .get ("data" ) is not None :
227+
228+ # MODIFICATION: Reconstruct sparse tensors during loading
229+ if attrs .get ("is_sparse" ):
230+ indices_shape = (len (attrs .get ("shape" )), - 1 )
231+ indices = torch .tensor (attrs ["indices" ]).reshape (indices_shape )
232+ values = torch .tensor (attrs ["values" ], dtype = data_type )
233+ shape = attrs .get ("shape" )
234+ data_value = torch .sparse_coo_tensor (indices , values , shape ).to_dense ()
235+ elif attrs .get ("data" ) is not None :
204236 if isinstance (attrs .get ("data" ), str ):
205237 raise ValueError ("Unimplemented" )
206238 else :
207239 data_value = torch .tensor (attrs ["data" ], dtype = data_type ).reshape (
208- attrs .get ("shape" ) , []
240+ attrs .get ("shape" , [])
209241 )
242+
210243 yield {
211244 "info" : {
212245 "shape" : attrs .get ("shape" , []),
@@ -240,4 +273,4 @@ def replay_tensor(info):
240273 std = info ["info" ]["std" ]
241274 if "data" in info and info ["data" ] is not None :
242275 return info ["data" ].to (device )
243- return torch .randn (size = shape ).to (dtype ).to (device ) * std * 0.2 + mean
276+ return torch .randn (size = shape ).to (dtype ).to (device ) * std * 0.2 + mean