44import onnxscript as ost
55from onnxscript import opset19 as op # opset19 is the lastest by 202309
66
7+ np .random .seed (0 )
8+
79def make_model_and_data (model , * args , ** kwargs ):
810 name = model ._name
911
1012 # TODO: support multiple outputs
11- output = model (* args , ** kwargs ) # eager mode
13+ output = model (* args ) # eager mode
1214
1315 # Save model
1416 model_proto = model .to_model_proto ()
@@ -24,13 +26,20 @@ def make_model_and_data(model, *args, **kwargs):
2426 onnx .save (model_proto_ , save_path )
2527
2628 # Save inputs and output
29+ inputs = args
30+ if "force_saving_input_as_dtype_float32" in kwargs and kwargs ["force_saving_input_as_dtype_float32" ]:
31+ inputs = []
32+ for input in args :
33+ inputs .append (input .astype (np .float32 ))
2734 if len (args ) == 1 :
2835 input_file = os .path .join ("data" , "input_" + name )
29- np .save (input_file , args [0 ])
36+ np .save (input_file , inputs [0 ])
3037 else :
31- for idx , input in enumerate (args , start = 0 ):
38+ for idx , input in enumerate (inputs , start = 0 ):
3239 input_files = os .path .join ("data" , "input_" + name + "_" + str (index ))
3340 np .save (input_files , input )
41+ if "force_saving_output_as_dtype_float32" in kwargs and kwargs ["force_saving_output_as_dtype_float32" ]:
42+ output = output .astype (np .float32 )
3443 output_files = os .path .join ("data" , "output_" + name )
3544 np .save (output_files , output )
3645
@@ -48,3 +57,14 @@ def gather_shared_indices(x: ost.FLOAT[2, 1, 3, 4]) -> ost.FLOAT[3, 4]:
4857 y1 = op .Gather (y0 , indices , axis = 0 )
4958 return y1
5059make_model_and_data (gather_shared_indices , np .random .rand (2 , 1 , 3 , 4 ).astype (np .float32 ))
60+
61+ '''
62+ [Input] -> Greater(B=61) -> [Output]
63+ \
64+ dtype=np.int64
65+ '''
66+ @ost .script ()
67+ def greater_input_dtype_int64 (x : ost .FLOAT [27 , 9 ]) -> ost .BOOL [27 , 9 ]:
68+ y = op .Greater (x , op .Constant (value = onnx .helper .make_tensor ("" , onnx .TensorProto .INT64 , [], np .array ([61 ], dtype = np .int64 ))))
69+ return y
70+ make_model_and_data (greater_input_dtype_int64 , np .random .randint (0 , 100 , size = [27 , 9 ], dtype = np .int64 ), force_saving_input_as_dtype_float32 = True , force_saving_output_as_dtype_float32 = True )
0 commit comments