99import ast
1010import paddle
1111
12+ kLiteralTensorSize = 64
13+
1214
1315def get_limited_precision_float_str (value ):
1416 if not isinstance (value , float ):
@@ -35,15 +37,15 @@ def process_tensor(tensor):
3537
3638 info = tensor_info (tensor )
3739 if tensor .dtype in [paddle .int8 , paddle .int16 , paddle .int32 , paddle .int64 ]:
38- if tensor .numel () < 1024 :
40+ if tensor .numel () < kLiteralTensorSize :
3941 return {
4042 "type" : "small_int_tensor" ,
4143 "data" : tensor .clone (),
4244 "info" : info ,
4345 }
4446 else :
4547 return {"type" : "big_int_tensor" , "data" : tensor .clone (), "info" : info }
46- elif tensor .numel () < 1024 :
48+ elif tensor .numel () < kLiteralTensorSize :
4749 return {"type" : "small_tensor" , "data" : tensor .clone (), "info" : info }
4850 else :
4951 return {"type" : "random_tensor" , "info" : info }
@@ -141,10 +143,10 @@ def convert_meta_classes_to_tensors(file_path):
141143 "shape" : attrs .get ("shape" , []),
142144 "dtype" : data_type ,
143145 "device" : attrs .get ("device" , "gpu" ),
144- "mean" : attrs .get ("mean" , 0.0 ),
145- "std" : attrs .get ("std" , 1.0 ),
146- "low " : attrs .get ("low " , 0 ),
147- "high " : attrs .get ("high " , 2 ),
146+ "mean" : 0.0 if attrs .get ("mean" , None ) is None else attrs . get ( "mean" ),
147+ "std" : 1.0 if attrs .get ("std" , None ) is None else attrs . get ( "std" ),
148+ "min_val " : attrs .get ("min_val " , 0 ),
149+ "max_val " : attrs .get ("max_val " , 2 ),
148150 },
149151 "data" : data_value ,
150152 "name" : attrs .get ("name" ),
@@ -173,17 +175,18 @@ def replay_tensor(info):
173175 device = info ["info" ]["device" ]
174176 dtype = info ["info" ]["dtype" ]
175177 shape = info ["info" ]["shape" ]
176- min_value = info ["info" ]["low" ] if "low" in info ["info" ] else 0
177- max_value = info ["info" ]["high" ] if "high" in info ["info" ] else 0.5
178+
179+ mean = info ["info" ]["mean" ]
180+ std = info ["info" ]["std" ]
181+ min_val = info ["info" ]["min_val" ]
182+ max_val = info ["info" ]["max_val" ]
178183 if None in shape :
179184 shape = list (map (lambda i : i if i is not None else 1 , shape ))
180185 if "data" in info and info ["data" ] is not None :
181186 return paddle .reshape (info ["data" ], shape ).to (dtype ).to (device )
182187 elif dtype == paddle .int32 or dtype == paddle .int64 :
183188 return paddle .cast (
184- paddle .randint (
185- low = min_value , high = max_value + 1 , shape = shape , dtype = "int64"
186- ),
189+ paddle .randint (low = min_val , high = max_val + 1 , shape = shape , dtype = "int64" ),
187190 dtype ,
188191 ).to (device )
189192 elif dtype == paddle .bool :
@@ -194,7 +197,9 @@ def replay_tensor(info):
194197 else :
195198 std = info ["info" ]["std" ]
196199 return (
197- paddle .uniform (shape , dtype = "float32" , min = min_value , max = max_value )
200+ paddle .clip (
201+ paddle .normal (shape = shape , mean = mean , std = std ), min = min_val , max = max_val
202+ )
198203 .to (dtype )
199204 .to (device )
200205 )
0 commit comments