2121 Union ,
2222)
2323
24- import onnx
2524import onnx_ir as ir
2625from onnxscript .ir import _schemas
2726
2827import onnxscript
2928from onnxscript import irbuilder , onnx_types , sourceinfo , values
3029from onnxscript import type_annotation as ta
31- from onnxscript ._internal import _analysis , ast_utils , autocast , param_manipulation
30+ from onnxscript ._internal import _analysis , ast_utils , autocast
3231
3332if TYPE_CHECKING :
3433 # The type-alias LocalSymValue represents the types of values that local names in a
3534 # script-function may be bound to during translation, (ONNX IR values).
3635 # TODO(rama): Rationalize this and values.SymbolValue
3736
38- LocalSymValue = Union [values .SymbolValue , irbuilder . IRFunction ]
37+ LocalSymValue = Union [values .SymbolValue , ir . Function ]
3938
4039 # The type-alias PyValue is used to represent the types of python values that may be used
4140 # in an ONNX Script function.
@@ -115,28 +114,11 @@ def ignore(cond, msg):
115114}
116115
117116
118- class Variable :
119- """Represents an ONNX variable.
117+ _CASTABLE_FIELD = "pkg.onnxscript.converter.castable"
120118
121- TODO(rama): Consider merging this with IRVar. However, "castable" is specific to this
122- converter.
123- """
124-
125- def __init__ (self , name : str , castable : bool = False ):
126- """Initialize the instance.
127-
128- Args:
129- name: Name of the ONNX variable
130- castable: Whether this variable is castable to a desired target type.
131- Used for ONNX variables representing constants created from python values
132- like 0 or 1 or 0.5 which are treated as polymorphic values castable to other
133- types as needed.
134- """
135- self .name = name
136- self .is_castable = castable
137-
138- def __str__ (self ) -> str :
139- return self .name
119+ def mark_castable (value : ir .Value ):
120+ """Mark an ONNX value as auto-castable."""
121+ value .meta [_CASTABLE_FIELD ] = True
140122
141123
142124@dataclasses .dataclass
@@ -227,6 +209,8 @@ def __init__(
227209 graph = ir .Graph ((), (), nodes = []),
228210 attributes = {},
229211 )
212+ # A mapping from value names to the values for each function
213+ # self._scoped_values: dict[ir.Function, dict[str, ir.Value]] = {}
230214 self ._nextvar : int = 0
231215 self ._used_vars : set [str ] = set ()
232216 self ._locals : list [dict [str , LocalSymValue ]] = [{}]
@@ -325,26 +309,25 @@ def _to_onnx_var(
325309 target : PreferredName = "tmp" ,
326310 * ,
327311 info : sourceinfo .SourceInfo ,
328- ) -> Variable :
312+ ) -> ir . Value :
329313 """Convert a value to an ONNX variable."""
330314 if isinstance (val , values .AttrRef ):
331315 # promote attribute to value
332316 result = self ._generate_unique_name (target )
333317 attr = _to_onnx_ref_attr (val , info )
334- self .emit ([], "Constant" , [result ], attrs = [attr ])
318+ result_val = self .emit ([result ], "Constant" , [], attrs = [attr ])[ 0 ]
335319 if ta .base_type_is_bool (val .typeinfo ):
336320 # ONNX attributes use an int-encoding for bools, but ONNX tensor types
337321 # distinguish between int and bool. So we cast the int tensor to a bool tensor,
338322 # to promote a (python) bool attribute to a ONNX bool tensor.
339323 result_as_bool = self ._generate_unique_name (result + "_as_bool" )
340- self .emit (
341- [result ],
342- "Cast" ,
324+ return self .emit (
343325 [result_as_bool ],
326+ "Cast" ,
327+ [result ],
344328 attrs = [ir .AttrInt64 ("to" , ir .DataType .BOOL )],
345- )
346- return Variable (result_as_bool , castable = True )
347- return Variable (result , castable = True )
329+ )[0 ]
330+ return result_val
348331
349332 if isinstance (val , values .Dynamic ):
350333 return Variable (val .value )
@@ -364,16 +347,17 @@ def emit(
364347 * ,
365348 attrs : Sequence [ir .Attr ] = (),
366349 domain : str = "" ,
367- ):
350+ ) -> Sequence [ ir . Value ] :
368351 """Emit an ONNX operator with the given inputs, outputs, and attributes."""
369352 node = ir .Node (
370353 domain = domain ,
371354 op_type = op_type ,
372- inputs = [self ._lookup (inp , self ._source_of (inputs [ 0 ] )) for inp in inputs ],
355+ inputs = [self ._lookup (inp , self ._source_of (inp )) for inp in inputs ],
373356 attributes = attrs ,
374- outputs = [self ._lookup (out , self ._source_of (outputs [ 0 ] )) for out in outputs ],
357+ outputs = [self ._lookup (out , self ._source_of (out )) for out in outputs ],
375358 )
376359 self ._current_fn .append (node )
360+ return node .outputs
377361
378362 def _emit_const (
379363 self ,
0 commit comments