77
88import numpy as np
99import onnx
10+ import onnx_ir as ir
1011
1112from onnxscript import tensor
1213
@@ -87,6 +88,41 @@ def value_to_type_proto(val):
8788 raise ValueError (f"Value of type { type (val )} is invalid as an ONNX input/output." )
8889
8990
91+ def value_to_type (val ):
92+ """Return an ir.Value representation of a python-value."""
93+ if isinstance (val , (np .ndarray , tensor .Tensor )):
94+ elem_type = onnx .helper .np_dtype_to_tensor_dtype (val .dtype ) # noqa: TID251
95+ shape = val .shape
96+ return (ir .TensorType (elem_type ), shape )
97+ elif isinstance (val , int ):
98+ elem_type = onnx .TensorProto .INT32
99+ shape = []
100+ return (ir .TensorType (elem_type ), shape )
101+ elif isinstance (val , (float , np .float32 )):
102+ elem_type = onnx .TensorProto .FLOAT
103+ shape = []
104+ return (ir .TensorType (elem_type ), shape )
105+ elif isinstance (val , list ):
106+ if len (val ) > 0 :
107+ type , shape = value_to_type (val [0 ])
108+ return ir .SequenceType (type ), shape
109+ # Edge-case. Cannot determine a suitable ONNX type for an empty list.
110+ # Should be using a typed-value instead.
111+ # Treated as a sequence of tensors of float-type.
112+ return ir .SequenceType (ir .TensorType (onnx .TensorProto .FLOAT )), None
113+ if isinstance (val , numbers .Number ):
114+ nparray = np .array (val )
115+ elem_type = onnx .helper .np_dtype_to_tensor_dtype (nparray .dtype ) # noqa: TID251
116+ return ir .TensorType (elem_type ), []
117+ raise ValueError (f"Value of type { type (val )} is invalid as an ONNX input/output." )
118+
119+
120+ def value_to_ir_value (name : str , val ) -> ir .Value :
121+ """Return an ir.Value representation of a python-value."""
122+ type , shape = value_to_type (val )
123+ return ir .Value (name = name , type = type , shape = shape )
124+
125+
90126def values_to_value_infos (name_values ):
91127 """Create a list of ValueInfoProto from a list of (name, value) pairs,
92128 skipping any None values.
@@ -96,3 +132,10 @@ def values_to_value_infos(name_values):
96132 for (name , val ) in name_values
97133 if val is not None
98134 ]
135+
136+
137+ def values_to_ir_values (name_values ):
138+ """Create a list of ir.Value from a list of (name, value) pairs,
139+ skipping any None values.
140+ """
141+ return [value_to_ir_value (name , val ) for (name , val ) in name_values if val is not None ]
0 commit comments