@@ -124,7 +124,9 @@ def format_data(data):
124124 return "None"
125125 elif isinstance (data , torch .Tensor ):
126126 if data .dtype .is_floating_point :
127- return "[{}]" .format (", " .join (f"{ x :.6f} " for x in data .flatten ().tolist ()))
127+ return "[{}]" .format (
128+ ", " .join (f"{ x :.6f} " for x in data .flatten ().tolist ())
129+ )
128130 else :
129131 return "[{}]" .format (", " .join (f"{ x } " for x in data .flatten ().tolist ()))
130132 else :
@@ -137,7 +139,7 @@ def process_tensor_info(tensor_info, name_prefix="example_input"):
137139 sparse_values = None
138140
139141 if is_sparse :
140- data_list = None # No dense data for sparse tensors
142+ data_list = None # No dense data for sparse tensors
141143 sparse_indices = tensor_info ["data" ]["indices" ]
142144 sparse_values = tensor_info ["data" ]["values" ]
143145 elif "input_" in tensor_info ["name" ]:
@@ -146,7 +148,7 @@ def process_tensor_info(tensor_info, name_prefix="example_input"):
146148 else :
147149 pass
148150 else :
149- if tensor_info ["type" ] == "small_int_tensor" :
151+ if tensor_info ["type" ] == "small_int_tensor" :
150152 data_list = tensor_info ["data" ].flatten ()
151153
152154 info = tensor_info .get ("info" , {})
@@ -156,7 +158,7 @@ def process_tensor_info(tensor_info, name_prefix="example_input"):
156158 mean = info .get ("mean" , 0.0 )
157159 std = info .get ("std" , 1.0 )
158160 uid = f"{ name_prefix } _tensor_meta_{ tensor_info .get ('name' , '' )} "
159-
161+
160162 lines = [
161163 (f"class { uid } :" ),
162164 (f"\t name = \" { tensor_info .get ('name' , '' )} \" " ),
@@ -172,11 +174,10 @@ def process_tensor_info(tensor_info, name_prefix="example_input"):
172174 lines .append (f"\t values = { format_data (sparse_values )} " )
173175 else :
174176 lines .append (f"\t data = { format_data (data_list )} " )
175-
177+
176178 lines .append ("" )
177179 return lines
178180
179-
180181 input_infos = converted ["input_info" ]
181182 if isinstance (input_infos , dict ):
182183 input_infos = [input_infos ]
@@ -224,7 +225,7 @@ def convert_meta_classes_to_tensors(file_path):
224225 }
225226 data_value = None
226227 data_type = getattr (torch , attrs .get ("dtype" , "torch.float" ).split ("." )[- 1 ])
227-
228+
228229 if attrs .get ("is_sparse" ):
229230 indices_shape = (len (attrs .get ("shape" )), - 1 )
230231 indices = torch .tensor (attrs ["indices" ]).reshape (indices_shape )
0 commit comments