Skip to content

Commit 2196f99

Browse files
committed
return values
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 0af0707 commit 2196f99

File tree

2 files changed

+20
-35
lines changed

2 files changed

+20
-35
lines changed

onnxscript/_converter.py

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,20 @@
2121
Union,
2222
)
2323

24-
import onnx
2524
import onnx_ir as ir
2625
from onnxscript.ir import _schemas
2726

2827
import onnxscript
2928
from onnxscript import irbuilder, onnx_types, sourceinfo, values
3029
from onnxscript import type_annotation as ta
31-
from onnxscript._internal import _analysis, ast_utils, autocast, param_manipulation
30+
from onnxscript._internal import _analysis, ast_utils, autocast
3231

3332
if TYPE_CHECKING:
3433
# The type-alias LocalSymValue represents the types of values that local names in a
3534
# script-function may be bound to during translation, (ONNX IR values).
3635
# TODO(rama): Rationalize this and values.SymbolValue
3736

38-
LocalSymValue = Union[values.SymbolValue, irbuilder.IRFunction]
37+
LocalSymValue = Union[values.SymbolValue, ir.Function]
3938

4039
# The type-alias PyValue is used to represent the types of python values that may be used
4140
# in an ONNX Script function.
@@ -115,28 +114,11 @@ def ignore(cond, msg):
115114
}
116115

117116

118-
class Variable:
119-
"""Represents an ONNX variable.
117+
_CASTABLE_FIELD = "pkg.onnxscript.converter.castable"
120118

121-
TODO(rama): Consider merging this with IRVar. However, "castable" is specific to this
122-
converter.
123-
"""
124-
125-
def __init__(self, name: str, castable: bool = False):
126-
"""Initialize the instance.
127-
128-
Args:
129-
name: Name of the ONNX variable
130-
castable: Whether this variable is castable to a desired target type.
131-
Used for ONNX variables representing constants created from python values
132-
like 0 or 1 or 0.5 which are treated as polymorphic values castable to other
133-
types as needed.
134-
"""
135-
self.name = name
136-
self.is_castable = castable
137-
138-
def __str__(self) -> str:
139-
return self.name
119+
def mark_castable(value: ir.Value):
120+
"""Mark an ONNX value as auto-castable."""
121+
value.meta[_CASTABLE_FIELD] = True
140122

141123

142124
@dataclasses.dataclass
@@ -227,6 +209,8 @@ def __init__(
227209
graph=ir.Graph((), (), nodes=[]),
228210
attributes={},
229211
)
212+
# A mapping from value names to the values for each function
213+
# self._scoped_values: dict[ir.Function, dict[str, ir.Value]] = {}
230214
self._nextvar: int = 0
231215
self._used_vars: set[str] = set()
232216
self._locals: list[dict[str, LocalSymValue]] = [{}]
@@ -325,26 +309,25 @@ def _to_onnx_var(
325309
target: PreferredName = "tmp",
326310
*,
327311
info: sourceinfo.SourceInfo,
328-
) -> Variable:
312+
) -> ir.Value:
329313
"""Convert a value to an ONNX variable."""
330314
if isinstance(val, values.AttrRef):
331315
# promote attribute to value
332316
result = self._generate_unique_name(target)
333317
attr = _to_onnx_ref_attr(val, info)
334-
self.emit([], "Constant", [result], attrs=[attr])
318+
result_val = self.emit([result], "Constant", [], attrs=[attr])[0]
335319
if ta.base_type_is_bool(val.typeinfo):
336320
# ONNX attributes use an int-encoding for bools, but ONNX tensor types
337321
# distinguish between int and bool. So we cast the int tensor to a bool tensor,
338322
# to promote a (python) bool attribute to a ONNX bool tensor.
339323
result_as_bool = self._generate_unique_name(result + "_as_bool")
340-
self.emit(
341-
[result],
342-
"Cast",
324+
return self.emit(
343325
[result_as_bool],
326+
"Cast",
327+
[result],
344328
attrs=[ir.AttrInt64("to", ir.DataType.BOOL)],
345-
)
346-
return Variable(result_as_bool, castable=True)
347-
return Variable(result, castable=True)
329+
)[0]
330+
return result_val
348331

349332
if isinstance(val, values.Dynamic):
350333
return Variable(val.value)
@@ -364,16 +347,17 @@ def emit(
364347
*,
365348
attrs: Sequence[ir.Attr] = (),
366349
domain: str = "",
367-
):
350+
) -> Sequence[ir.Value]:
368351
"""Emit an ONNX operator with the given inputs, outputs, and attributes."""
369352
node = ir.Node(
370353
domain=domain,
371354
op_type=op_type,
372-
inputs=[self._lookup(inp, self._source_of(inputs[0])) for inp in inputs],
355+
inputs=[self._lookup(inp, self._source_of(inp)) for inp in inputs],
373356
attributes=attrs,
374-
outputs=[self._lookup(out, self._source_of(outputs[0])) for out in outputs],
357+
outputs=[self._lookup(out, self._source_of(out)) for out in outputs],
375358
)
376359
self._current_fn.append(node)
360+
return node.outputs
377361

378362
def _emit_const(
379363
self,

onnxscript/values.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,7 @@ def __init__(
746746
raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.")
747747
self.typeinfo = typeinfo
748748

749+
749750
class DynamicKind(IntFlag):
750751
Unknown = 0
751752
Input = 1

0 commit comments

Comments
 (0)