@@ -127,21 +127,24 @@ def convert_meta_classes_to_tensors(file_path):
127127 if not k .startswith ("__" ) and not callable (v )
128128 }
129129 data_value = None
130- data_type = getattr (paddle , attrs .get ("dtype" , "paddle.float" ). split ( "." )[ - 1 ] )
130+ data_type = getattr (paddle , attrs .get ("dtype" , "float32" ) )
131131 if attrs .get ("data" ) is not None :
132132 if isinstance (attrs .get ("data" ), str ):
133133 raise ValueError ("Unimplemented" )
134134 else :
135- data_value = paddle .to_tensor (
136- attrs .get ("data" ), dtype = data_type
137- ).reshape (attrs .get ("shape" ), [])
135+ data_value = paddle .reshape (
136+ paddle .to_tensor (attrs .get ("data" ), dtype = data_type ),
137+ attrs .get ("shape" , []),
138+ )
138139 yield {
139140 "info" : {
140141 "shape" : attrs .get ("shape" , []),
141142 "dtype" : data_type ,
142143 "device" : attrs .get ("device" , "gpu" ),
143144 "mean" : attrs .get ("mean" , 0.0 ),
144145 "std" : attrs .get ("std" , 1.0 ),
146+ "low" : attrs .get ("low" , 0 ),
147+ "high" : attrs .get ("high" , 2 ),
145148 },
146149 "data" : data_value ,
147150 "name" : attrs .get ("name" ),
@@ -163,11 +166,27 @@ def replay_tensor(info):
163166 device = info ["info" ]["device" ]
164167 dtype = info ["info" ]["dtype" ]
165168 shape = info ["info" ]["shape" ]
169+ min_value = info ["info" ]["low" ] if "low" in info ["info" ] else 0
170+ max_value = info ["info" ]["high" ] if "high" in info ["info" ] else 0.5
166171 if None in shape :
167172 shape = list (map (lambda i : i if i is not None else 1 , shape ))
168- mean = info ["info" ]["mean" ]
169- std = info ["info" ]["std" ]
170173 if "data" in info and info ["data" ] is not None :
171- return info ["data" ].to (device )
172-
173- return (paddle .randn (shape ).cast (dtype ).to (device ) * std * 1e-3 + 1e-2 ).cast (dtype )
174+ return paddle .reshape (info ["data" ], shape ).to (dtype ).to (device )
175+ elif dtype == paddle .int32 or dtype == paddle .int64 :
176+ return paddle .cast (
177+ paddle .randint (low = min_value , high = max_value , shape = shape , dtype = "int64" ),
178+ dtype ,
179+ ).to (device )
180+ elif dtype == paddle .bool :
181+ return paddle .cast (
182+ paddle .randint (low = 0 , high = 2 , shape = shape , dtype = "int32" ),
183+ paddle .bool ,
184+ ).to (device )
185+ else :
186+ std = info ["info" ]["std" ]
187+ # return paddle.randn(shape).to(dtype).to(device) * std * 1e-3 + 1e-2
188+ return (
189+ paddle .uniform (shape , dtype = "float32" , min = min_value , max = max_value )
190+ .to (dtype )
191+ .to (device )
192+ )
0 commit comments