Skip to content

Commit 519ef5a

Browse files
Cleanup SymbolValue (#2752)
More cleanup (in the migration of converter to onnx IR): * Eliminate some redundant info/logic (class Dynamic) * Cleanup AttrRef and SymbolValue. --------- Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: Justin Chu <[email protected]>
1 parent 43b1b74 commit 519ef5a

File tree

3 files changed

+74
-106
lines changed

3 files changed

+74
-106
lines changed

onnxscript/_internal/converter.py

Lines changed: 60 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -339,22 +339,20 @@ def tensor_name_generator() -> str:
339339
def _to_onnx_attr_ref(
340340
self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo]
341341
) -> ir.Attr:
342-
pytype = val.typeinfo
343-
attrtype = ta.pytype_to_attrtype(pytype)
342+
attrtype = val.value.type
344343
attrname = None
345-
if attrtype is onnx.AttributeProto.FLOAT:
344+
if attrtype is ir.AttributeType.FLOAT: # onnx.AttributeProto.FLOAT:
346345
attrname = "value_float"
347-
elif attrtype is onnx.AttributeProto.INT:
346+
elif attrtype is ir.AttributeType.INT:
348347
attrname = "value_int"
349-
elif attrtype is onnx.AttributeProto.STRING:
348+
elif attrtype is ir.AttributeType.STRING:
350349
attrname = "value_string"
351-
elif attrtype is onnx.AttributeProto.INTS:
350+
elif attrtype is ir.AttributeType.INTS:
352351
attrname = "value_ints"
353352
else:
354-
msg = f"Unsupported attribute type {pytype!r}."
353+
msg = f"Unsupported attribute type {attrtype!r}."
355354
fail(info.msg(msg) if info else msg)
356-
attr_type = ir.AttributeType(ta.pytype_to_attrtype(pytype))
357-
return ir.Attr(attrname, attr_type, None, val.value)
355+
return ir.Attr(attrname, attrtype, value=None, ref_attr_name=val.value.name)
358356

359357
def _to_onnx_var(
360358
self,
@@ -369,7 +367,7 @@ def _to_onnx_var(
369367
result = self.emit(
370368
[result_name], values.Op(self.default_opset, "Constant"), [], [attr]
371369
)
372-
if ta.base_type_is_bool(val.typeinfo):
370+
if val.as_bool:
373371
# ONNX attributes use an int-encoding for bools, but ONNX tensor types
374372
# distinguish between int and bool. So we cast the int tensor to a bool tensor,
375373
# to promote a (python) bool attribute to a ONNX bool tensor.
@@ -384,8 +382,9 @@ def _to_onnx_var(
384382
)
385383
self._castable.add(result_name)
386384
return result
387-
if isinstance(val, values.Dynamic):
388-
return val.value
385+
if isinstance(val, values.SymbolValue):
386+
if isinstance(val.value, ir.Value):
387+
return val.value
389388
# Assume value is a python-value convertible to a tensor
390389
# TODO: check if value is convertible to a TensorProto, so that we can
391390
# produce a better error _message otherwise
@@ -534,29 +533,44 @@ def _translate_attr(
534533

535534
if isinstance(expr, ast.Name):
536535
val = self._lookup(expr.id, self._source_of(expr))
537-
if isinstance(val, values.AttrRef):
538-
attr_type = ir.AttributeType(ta.pytype_to_attrtype(val.typeinfo))
539-
attr_ref = ir.Attr(attr_name, attr_type, None, val.value)
540-
if attr_meta is not None and (attr_ref.type != attr_meta.type):
541-
self.fail(
542-
expr,
543-
f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'",
536+
if isinstance(val, values.SymbolValue):
537+
val = val.value
538+
if isinstance(val, ir.Attr):
539+
# A reference to an attribute parameter:
540+
attr = val
541+
attr_ref = ir.Attr(
542+
attr_name, attr.type, value=None, ref_attr_name=attr.name
544543
)
545-
return attr_ref
546-
if isinstance(val, irbuilder.IRFunction):
547-
# Check that outer-scope variables referenced by function have same value
548-
# at function-definition site and use-as-attribute site, to avoid errors.
549-
for pyvar, previous in val.outer_scope_variables:
550-
current = self._lookup(pyvar, self._source_of(expr))
551-
if current.value != previous.value:
544+
if attr_meta is not None and (attr.type != attr_meta.type):
552545
self.fail(
553546
expr,
554-
f"Outer scope variable '{pyvar}' referenced by function "
555-
f"'{expr.id!r}' modified.",
547+
f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'",
556548
)
557-
558-
# Create GraphProto attribute
559-
val = val.to_graph_proto()
549+
return attr_ref
550+
if isinstance(val, irbuilder.IRFunction):
551+
# A reference to a nested-function: convert to GraphProto and use it.
552+
irfunction = val
553+
# Check that outer-scope variables referenced by function have same value
554+
# at function-definition site and use-as-attribute site, to avoid errors.
555+
for pyvar, previous in irfunction.outer_scope_variables:
556+
current = self._lookup(pyvar, self._source_of(expr))
557+
if current.value != previous.value:
558+
self.fail(
559+
expr,
560+
f"Outer scope variable '{pyvar}' referenced by function "
561+
f"'{expr.id!r}' modified.",
562+
)
563+
# Create GraphProto attribute
564+
val = irfunction.to_graph_proto()
565+
if isinstance(val, ir.Value):
566+
self.fail(expr, f"Cannot use ir.Value '{expr.id}' as an attribute.")
567+
else:
568+
# Treat as a constant python-value, to be converted below.
569+
pass
570+
else:
571+
# This must be a reference to an outer-scope python-value, typically a constant.
572+
# The value will be converted to an ONNX attribute value below.
573+
pass
560574
else:
561575
val = self._eval_constant_expr(expr)
562576

@@ -1045,7 +1059,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None:
10451059
typeinfo = None
10461060
if typeinfo is not None:
10471061
set_type_info(t, typeinfo)
1048-
var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo)
1062+
var = values.SymbolValue(t, info)
10491063
self._bind(lhs, var)
10501064
elif isinstance(lhs, ast.Tuple):
10511065
# Assignments of the form "x, y, z = op.SomeOp(...)"
@@ -1068,9 +1082,7 @@ def generate_onnx_name(x: ast.AST):
10681082
for x, output in zip(lhs.elts, outputs):
10691083
self._bind(
10701084
x.id,
1071-
values.Dynamic(
1072-
output, values.DynamicKind.Intermediate, self._source_of(x)
1073-
),
1085+
values.SymbolValue(output, self._source_of(x)),
10741086
)
10751087
else:
10761088
self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'")
@@ -1117,10 +1129,11 @@ def ret(exp, i, suffix):
11171129
preferred_name = f"return_val{suffix}"
11181130
return_var = self._translate_expr(exp, preferred_name) # TODO(rama)
11191131
val = self._lookup(return_var.name, self._source_of(exp), False)
1120-
if val and val.kind == values.DynamicKind.Input:
1121-
# In ONNX, a graph-input cannot be an output of the graph.
1122-
# We need to insert a copy.
1123-
return_var = self._emit_copy(return_var, preferred_name)
1132+
if isinstance(val, values.SymbolValue) and isinstance(val.value, ir.Value):
1133+
if val.value.is_graph_input():
1134+
# In ONNX, a graph-input cannot be an output of the graph.
1135+
# We need to insert a copy.
1136+
return_var = self._emit_copy(return_var, preferred_name)
11241137
for prev_output in self._current_fn.outputs:
11251138
if prev_output.name == return_var.name:
11261139
# ONNX does not allow duplicate output names.
@@ -1190,7 +1203,7 @@ def rename(x):
11901203
for x, y in zip(live_defs, if_outputs):
11911204
self._bind(
11921205
x,
1193-
values.Dynamic(y, values.DynamicKind.Intermediate, self._source_of(stmt)),
1206+
values.SymbolValue(y, self._source_of(stmt)),
11941207
)
11951208

11961209
def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
@@ -1257,7 +1270,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12571270
self._current_fn.append_parameter(onnx_loop_var)
12581271
self._bind(
12591272
python_loop_var_name,
1260-
values.Dynamic(onnx_loop_var, values.DynamicKind.Loop, self._source_of(loop_stmt)),
1273+
values.SymbolValue(onnx_loop_var, self._source_of(loop_stmt)),
12611274
)
12621275

12631276
self._current_fn.append_parameter(
@@ -1278,9 +1291,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12781291
)
12791292
self._bind(
12801293
pv,
1281-
values.Dynamic(
1294+
values.SymbolValue(
12821295
ir.Value(name=onnx_var_name),
1283-
values.DynamicKind.Loop,
12841296
self._source_of(loop_stmt),
12851297
),
12861298
)
@@ -1376,7 +1388,7 @@ def rename(x):
13761388
if isinstance(loop_outputs, ir.Value):
13771389
loop_outputs = [loop_outputs]
13781390
for x, loop_output in zip(outputs, loop_outputs):
1379-
self._bind(x, values.Dynamic(loop_output, values.DynamicKind.Output, info))
1391+
self._bind(x, values.SymbolValue(loop_output, info))
13801392

13811393
def _translate_block(
13821394
self,
@@ -1431,7 +1443,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
14311443
function_ir.outer_scope_variables = [
14321444
(var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars
14331445
]
1434-
self._bind(fn.name, function_ir)
1446+
self._bind(fn.name, values.SymbolValue(function_ir, self._source_of(fn)))
14351447
# TODO: Does not yet handle nested functions within nested functions.
14361448
self._current_fn.add_nested_function(function_ir)
14371449

@@ -1459,16 +1471,15 @@ def _translate_function_signature_common(
14591471
attribute_type = ta.pytype_to_attrtype(typeinfo)
14601472
attr = ir.Attr(x.arg, ir.AttributeType(attribute_type), default_value, None)
14611473
self._current_fn.append_parameter(attr)
1462-
self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x)))
1474+
as_bool = ta.base_type_is_bool(typeinfo)
1475+
self._bind(x.arg, values.AttrRef(attr, as_bool, self._source_of(x)))
14631476
else:
14641477
onnx_parameter = make_value(x.arg, typeinfo, self._source_of(x))
14651478
self._current_fn.append_parameter(onnx_parameter)
14661479
self._used_vars.add(x.arg)
14671480
self._bind(
14681481
x.arg,
1469-
values.Dynamic(
1470-
onnx_parameter, values.DynamicKind.Input, self._source_of(x)
1471-
),
1482+
values.SymbolValue(onnx_parameter, self._source_of(x)),
14721483
)
14731484
if fn.returns:
14741485
type_annotation = self._eval_constant_expr(fn.returns)

onnxscript/_internal/values.py

Lines changed: 14 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111
import logging
1212
import types
1313
import typing
14-
from enum import IntFlag
1514
from typing import ( # type: ignore[attr-defined]
16-
TYPE_CHECKING,
1715
Any,
1816
Callable,
1917
ClassVar,
@@ -22,7 +20,6 @@
2220
Protocol,
2321
Sequence,
2422
TypeVar,
25-
_GenericAlias,
2623
)
2724

2825
import onnx
@@ -36,9 +33,6 @@
3633
from onnxscript.ir import _schemas
3734
from onnxscript.onnx_types import ONNXType
3835

39-
if TYPE_CHECKING:
40-
from onnxscript._internal.type_annotation import TypeAnnotationValue
41-
4236
_R = TypeVar("_R")
4337
_P = ParamSpec("_P")
4438

@@ -853,61 +847,28 @@ def ThresholdedRelu(X, alpha: float):
853847
* To represent constant-values, translated into ONNX constants.
854848
"""
855849

856-
def __init__(self, info: sourceinfo.SourceInfo) -> None:
850+
def __init__(self, value: Any, info: sourceinfo.SourceInfo) -> None:
851+
"""
852+
Initializes SymbolValue.
853+
854+
Arguments:
855+
value: The value bound to a python variable in a script.
856+
info: source-location information for error-messages/debugging
857+
"""
857858
if not isinstance(info, sourceinfo.SourceInfo):
858859
raise TypeError(f"info must be of type sourceinfo.SourceInfo not {type(info)!r}.")
860+
self.value = value
859861
self.info = info
860862

861863

862864
class AttrRef(SymbolValue):
863-
def __init__(
864-
self, attr_name: str, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo
865-
) -> None:
865+
def __init__(self, attr: ir.Attr, as_bool: bool, info: sourceinfo.SourceInfo) -> None:
866866
"""Initializes AttrRef.
867867
868868
Arguments:
869-
attr_name: name of the attribute-parameter
870-
typeinfo: type annotation of the attribute.
871-
op's attributes in ONNX are usually single type or list of single type.
869+
attr: An ir.Attr representing the attribute-parameter
870+
as_bool: Whether the attribute is to be interpreted as a bool type (represented as int in ONNX)
872871
info: for debugging use.
873872
"""
874-
super().__init__(info)
875-
self.value = attr_name
876-
self.typeinfo = typeinfo
877-
if not isinstance(typeinfo, (type, _GenericAlias)):
878-
# typing._GenericAlias for List[int] and List[str], etc.
879-
raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.")
880-
self.typeinfo = typeinfo
881-
882-
883-
class DynamicKind(IntFlag):
884-
Unknown = 0
885-
Input = 1
886-
Output = 2
887-
Intermediate = 4
888-
Loop = 8
889-
890-
891-
class Dynamic(SymbolValue):
892-
def __init__(
893-
self,
894-
onnx_var: ir.Value,
895-
kind: DynamicKind,
896-
info: sourceinfo.SourceInfo,
897-
typeinfo: TypeAnnotationValue | None = None,
898-
) -> None:
899-
"""Represents an ir.Value with some extra information.
900-
901-
Arguments:
902-
onnx_var: the name of the ONNX variable used to represent this value
903-
kind: the DynamicKind of this variable
904-
info: source-location information for error-messages/debugging
905-
typeinfo: type-information for the value
906-
"""
907-
super().__init__(info)
908-
assert isinstance(kind, DynamicKind)
909-
if not isinstance(onnx_var, ir.Value):
910-
raise TypeError(f"onnx_var must be of type ir.Value not {type(onnx_var)!r}.")
911-
self.value = onnx_var
912-
self.kind = kind
913-
self.typeinfo = typeinfo
873+
super().__init__(attr, info)
874+
self.as_bool = as_bool

onnxscript/values.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
from onnxscript._internal.values import (
99
AttrRef,
10-
Dynamic,
11-
DynamicKind,
1210
OnnxClosure,
1311
OnnxFunction,
1412
Op,
@@ -21,8 +19,6 @@
2119

2220
__all__ = [
2321
"AttrRef",
24-
"Dynamic",
25-
"DynamicKind",
2622
"OnnxClosure",
2723
"OnnxFunction",
2824
"Op",

0 commit comments

Comments
 (0)