88import inspect
99import ast
1010import math
11+ import numpy as np
1112import paddle
1213
1314kLiteralTensorSize = 64
@@ -139,6 +140,7 @@ def convert_to_valid_number(data_type, value):
139140
140141
141142def convert_meta_classes_to_tensors (file_path ):
143+ current_device = paddle .device .get_device ()
142144 for name , cls in _get_classes (file_path ):
143145 attrs = {
144146 k : v
@@ -159,7 +161,7 @@ def convert_meta_classes_to_tensors(file_path):
159161 "info" : {
160162 "shape" : attrs .get ("shape" , []),
161163 "dtype" : data_type ,
162- "device" : attrs .get ("device" , "gpu" ),
164+ "device" : attrs .get ("device" , current_device ),
163165 "mean" : convert_to_valid_number (data_type , attrs .get ("mean" , None )),
164166 "std" : convert_to_valid_number (data_type , attrs .get ("std" , None )),
165167 "min_val" : convert_to_valid_number (data_type , attrs .get ("min_val" , 0 )),
@@ -188,7 +190,43 @@ def extract_dynamic_shapes(example_inputs):
188190 pass
189191
190192
191- def replay_tensor (info ):
193+ def init_integer_tensor (dtype , shape , min_val , max_val , use_numpy ):
194+ if use_numpy :
195+ array = np .random .randint (
196+ low = min_val , high = max_val + 1 , size = shape , dtype = dtype
197+ )
198+ return paddle .to_tensor (array )
199+ else :
200+ return paddle .randint (low = min_val , high = max_val + 1 , shape = shape , dtype = dtype )
201+
202+
203+ def init_float_tensor (shape , mean , std , min_val , max_val , use_numpy ):
204+ tensor = None
205+ if use_numpy :
206+ if mean is not None and std is not None :
207+ array = np .random .normal (mean , std , shape )
208+ mask = (array < min_val ) | (array > max_val )
209+ while np .any (mask ):
210+ array [mask ] = np .random .normal (mean , std , mask .sum ())
211+ mask = (array < min_val ) | (array > max_val )
212+ else :
213+ array = np .random .uniform (low = min_val , high = max_val , size = shape )
214+ tensor = paddle .to_tensor (array )
215+ else :
216+ if mean is not None and std is not None :
217+ tensor = paddle .empty (shape = shape , dtype = "float32" )
218+ initializer = paddle .nn .initializer .TruncatedNormal (
219+ mean = mean , std = std , a = min_val , b = max_val
220+ )
221+ initializer (tensor )
222+ else :
223+ tensor = paddle .uniform (
224+ shape = shape , dtype = "float32" , min = min_val , max = max_val
225+ )
226+ return tensor
227+
228+
229+ def replay_tensor (info , use_numpy = True ):
192230 device = info ["info" ]["device" ]
193231 dtype = info ["info" ]["dtype" ]
194232 shape = info ["info" ]["shape" ]
@@ -201,27 +239,14 @@ def replay_tensor(info):
201239 shape = list (map (lambda i : i if i is not None else 1 , shape ))
202240 if "data" in info and info ["data" ] is not None :
203241 return paddle .reshape (info ["data" ], shape ).to (dtype ).to (device )
204- elif dtype == paddle .int32 or dtype == paddle .int64 :
205- return paddle .cast (
206- paddle .randint (low = min_val , high = max_val + 1 , shape = shape , dtype = "int64" ),
207- dtype ,
208- ).to (device )
209- elif dtype == paddle .bool :
210- return paddle .cast (
211- paddle .randint (low = 0 , high = 2 , shape = shape , dtype = "int32" ),
212- paddle .bool ,
213- ).to (device )
242+ elif dtype in [paddle .int32 , paddle .int64 , paddle .bool ]:
243+ init_dtype = "int32" if dtype == paddle .bool else "int64"
244+ min_val , max_val = 0 , 1 if dtype == paddle .bool else min_val , max_val
245+ return (
246+ init_integer_tensor (init_dtype , shape , min_val , max_val , use_numpy )
247+ .to (dtype )
248+ .to (device )
249+ )
214250 else :
215- if mean is not None and std is not None :
216- tensor = paddle .empty (shape = shape , dtype = "float32" )
217- initializer = paddle .nn .initializer .TruncatedNormal (
218- mean = mean , std = std , a = min_val , b = max_val
219- )
220- initializer (tensor )
221- return tensor .to (dtype ).to (device )
222- else :
223- return (
224- paddle .uniform (shape = shape , dtype = "float32" , min = min_val , max = max_val )
225- .to (dtype )
226- .to (device )
227- )
251+ tensor = init_float_tensor (shape , mean , std , min_val , max_val , use_numpy )
252+ return tensor .to (dtype ).to (device )
0 commit comments