@@ -59,7 +59,12 @@ def process_tensor(tensor):
                 "info": info,
             }
         else:
-            return {"type": "big_int_tensor", "data": tensor.clone(), "info": info}
+            return {
+                "type": "big_int_tensor_by_range",
+                "min_val": tensor.min().item(),
+                "max_val": tensor.max().item(),
+                "info": info,
+            }
     elif tensor.numel() < 1024:
         return {"type": "small_tensor", "data": tensor.clone(), "info": info}
     else:
@@ -73,16 +78,25 @@ def process_tensor(tensor):
 processed_inputs = {"type": "unknown", "value": example_inputs}
 
 def handle_named_tensors(tensor):
-    data_value = None
-    data_type = "random_tensor"
+    info = tensor_info(tensor)
     if tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
         if tensor.numel() < 1024:
-            data_type = "small_int_tensor"
-            data_value = tensor.clone()
+            return {
+                "info": info,
+                "data": tensor.clone(),
+                "type": "small_int_tensor",
+            }
         else:
-            data_type = "big_int_tensor"
-    info = tensor_info(tensor)
-    return {"info": info, "data": data_value, "type": data_type}
+            return {
+                "info": info,
+                "min_val": tensor.min().item(),
+                "max_val": tensor.max().item(),
+                "type": "big_int_tensor_by_range",
+            }
+    if tensor.numel() < 1024:
+        return {"info": info, "data": tensor.clone(), "type": "small_tensor"}
+    else:
+        return {"info": info, "data": None, "type": "random_tensor"}
 
 processed_weights = {
     key: handle_named_tensors(tensor) for key, tensor in state_dict.items()
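A minimal sketch (not part of the diff) of what the reworked handle_named_tensors path produces for a large integer tensor; the tensor_info helper is assumed to return the shape/dtype/device/mean/std dictionary used above, and the token-id tensor is invented for illustration:

import torch

def tensor_info(t):
    # Stand-in for the tensor_info() helper referenced in the diff (assumed fields).
    return {
        "shape": list(t.shape),
        "dtype": str(t.dtype),
        "device": str(t.device),
        "mean": float(t.float().mean()),
        "std": float(t.float().std()),
    }

ids = torch.randint(0, 30000, (4, 2048), dtype=torch.int64)  # e.g. token ids
entry = {
    "info": tensor_info(ids),
    "min_val": ids.min().item(),  # two scalars are stored instead of a full clone
    "max_val": ids.max().item(),
    "type": "big_int_tensor_by_range",
}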
@@ -114,46 +128,46 @@ def format_data(data):
         return "None"
     elif isinstance(data, torch.Tensor):
         if data.dtype.is_floating_point:
-            return "[{}]".format(", ".join(f"{x:.6f}" for x in data.tolist()))
+            return "[{}]".format(
+                ", ".join(f"{x:.6f}" for x in data.flatten().tolist())
+            )
         else:
-            return "[{}]".format(", ".join(f"{x}" for x in data.tolist()))
+            return "[{}]".format(", ".join(f"{x}" for x in data.flatten().tolist()))
     else:
         return repr(data)
 
 def process_tensor_info(tensor_info, name_prefix="example_input"):
-    data_list = None
-    if "input_" in tensor_info["name"]:
-        if tensor_info["type"] in ["small_tensor", "small_int_tensor"]:
-            data_list = tensor_info["data"].flatten()
-        elif tensor_info["type"] == "big_int_tensor":
-            data_list = f"pt-filename:xxx-key"
-        else:
-            pass
-    else:
-        if tensor_info["type"] == "small_int_tensor":
-            data_list = tensor_info["data"].flatten()
-        if tensor_info["type"] == "big_int_tensor":
-            raise ValueError(
-                "Unexpected cases: there are weights in big tensor of int type "
-            )
+    tensor_type = tensor_info.get("type")
     info = tensor_info.get("info", {})
     dtype = info.get("dtype", "torch.float")
     shape = info.get("shape", [])
     device = info.get("device", "cpu")
     mean = info.get("mean", 0.0)
     std = info.get("std", 1.0)
     uid = f"{name_prefix}_tensor_meta_{tensor_info.get('name', '')}"
-    return [
+
+    lines = [
         (f"class {uid}:"),
         (f"\tname = \"{tensor_info.get('name', '')}\""),
         (f"\tshape = {shape}"),
         (f'\tdtype = "{dtype}"'),
         (f'\tdevice = "{device}"'),
         (f"\tmean = {get_limited_precision_float_str(mean)}"),
         (f"\tstd = {get_limited_precision_float_str(std)}"),
-        (f"\tdata = {format_data(data_list)}"),
-        (""),
     ]
+    if tensor_type == "big_int_tensor_by_range":
+        lines.append(f"\tmin_val = {tensor_info['min_val']}")
+        lines.append(f"\tmax_val = {tensor_info['max_val']}")
+    elif "data" in tensor_info:
+        data_list = (
+            tensor_info["data"].flatten()
+            if isinstance(tensor_info["data"], torch.Tensor)
+            else tensor_info["data"]
+        )
+        lines.append(f"\tdata = {format_data(data_list)}")
+
+    lines.append("")
+    return lines
 
 input_infos = converted["input_info"]
 if isinstance(input_infos, dict):
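For reference, the list of lines built by process_tensor_info for a by-range tensor renders to something like the following meta class; the field values are invented for illustration, and the class name follows the f"{name_prefix}_tensor_meta_{name}" pattern above:

class example_input_tensor_meta_input_ids:
	name = "input_ids"
	shape = [4, 2048]
	dtype = "torch.int64"
	device = "cpu"
	mean = 14980.2
	std = 8655.7
	min_val = 0
	max_val = 29999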
@@ -202,7 +216,16 @@ def convert_meta_classes_to_tensors(file_path):
         }
         data_value = None
         data_type = getattr(torch, attrs.get("dtype", "torch.float").split(".")[-1])
-        if attrs.get("data") is not None:
+        shape = attrs.get("shape", [])
+
+        if "min_val" in attrs and "max_val" in attrs:
+            min_val = attrs["min_val"]
+            max_val = attrs["max_val"]
+            # torch.randint's upper bound is exclusive, so add 1
+            data_value = torch.randint(
+                min_val, max_val + 1, size=shape, dtype=data_type
+            )
+        elif attrs.get("data") is not None:
             if isinstance(attrs.get("data"), str):
                 raise ValueError("Unimplemented")
             else:
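A short sketch, under the same assumptions as the hunk above, of how the recorded range is turned back into a tensor; only dtype, shape, and value range survive the round trip, not the original values:

import torch

attrs = {"shape": [4, 2048], "dtype": "torch.int64", "min_val": 0, "max_val": 29999}
data_type = getattr(torch, attrs.get("dtype", "torch.float").split(".")[-1])
shape = attrs.get("shape", [])
# torch.randint's upper bound is exclusive, so add 1 to make max_val reachable.
data_value = torch.randint(attrs["min_val"], attrs["max_val"] + 1, size=shape, dtype=data_type)
assert list(data_value.shape) == attrs["shape"] and data_value.dtype == torch.int64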