Skip to content

Commit 89c802a

Browse files
committed
Cleanup attr ref representation
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent aec7b2e commit 89c802a

File tree

2 files changed

+14
-27
lines changed

2 files changed

+14
-27
lines changed

onnxscript/_internal/converter.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -339,22 +339,20 @@ def tensor_name_generator() -> str:
339339
def _to_onnx_attr_ref(
340340
self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo]
341341
) -> ir.Attr:
342-
pytype = val.typeinfo
343-
attrtype = ta.pytype_to_attrtype(pytype)
342+
attrtype = val.value.type
344343
attrname = None
345-
if attrtype is onnx.AttributeProto.FLOAT:
344+
if attrtype is ir.AttributeType.FLOAT: # onnx.AttributeProto.FLOAT:
346345
attrname = "value_float"
347-
elif attrtype is onnx.AttributeProto.INT:
346+
elif attrtype is ir.AttributeType.INT:
348347
attrname = "value_int"
349-
elif attrtype is onnx.AttributeProto.STRING:
348+
elif attrtype is ir.AttributeType.STRING:
350349
attrname = "value_string"
351-
elif attrtype is onnx.AttributeProto.INTS:
350+
elif attrtype is ir.AttributeType.INTS:
352351
attrname = "value_ints"
353352
else:
354-
msg = f"Unsupported attribute type {pytype!r}."
353+
msg = f"Unsupported attribute type {attrtype!r}."
355354
fail(info.msg(msg) if info else msg)
356-
attr_type = ir.AttributeType(ta.pytype_to_attrtype(pytype))
357-
return ir.Attr(attrname, attr_type, value=None, ref_attr_name=val.value.name)
355+
return ir.Attr(attrname, attrtype, value=None, ref_attr_name=val.value.name)
358356

359357
def _to_onnx_var(
360358
self,
@@ -369,7 +367,7 @@ def _to_onnx_var(
369367
result = self.emit(
370368
[result_name], values.Op(self.default_opset, "Constant"), [], [attr]
371369
)
372-
if ta.base_type_is_bool(val.typeinfo):
370+
if val.as_bool:
373371
# ONNX attributes use an int-encoding for bools, but ONNX tensor types
374372
# distinguish between int and bool. So we cast the int tensor to a bool tensor,
375373
# to promote a (python) bool attribute to a ONNX bool tensor.
@@ -1474,7 +1472,8 @@ def _translate_function_signature_common(
14741472
attribute_type = ta.pytype_to_attrtype(typeinfo)
14751473
attr = ir.Attr(x.arg, ir.AttributeType(attribute_type), default_value, None)
14761474
self._current_fn.append_parameter(attr)
1477-
self._bind(x.arg, values.AttrRef(attr, typeinfo, self._source_of(x)))
1475+
as_bool = ta.base_type_is_bool(typeinfo)
1476+
self._bind(x.arg, values.AttrRef(attr, as_bool, self._source_of(x)))
14781477
else:
14791478
onnx_parameter = make_value(x.arg, typeinfo, self._source_of(x))
14801479
self._current_fn.append_parameter(onnx_parameter)

onnxscript/_internal/values.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import typing
1414
from enum import IntFlag
1515
from typing import ( # type: ignore[attr-defined]
16-
TYPE_CHECKING,
1716
Any,
1817
Callable,
1918
ClassVar,
@@ -22,7 +21,6 @@
2221
Protocol,
2322
Sequence,
2423
TypeVar,
25-
_GenericAlias,
2624
)
2725

2826
import onnx
@@ -36,9 +34,6 @@
3634
from onnxscript.ir import _schemas
3735
from onnxscript.onnx_types import ONNXType
3836

39-
if TYPE_CHECKING:
40-
from onnxscript._internal.type_annotation import TypeAnnotationValue
41-
4237
_R = TypeVar("_R")
4338
_P = ParamSpec("_P")
4439

@@ -868,23 +863,16 @@ def __init__(self, value: Any, info: sourceinfo.SourceInfo) -> None:
868863

869864

870865
class AttrRef(SymbolValue):
871-
def __init__(
872-
self, attr: ir.Attr, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo
873-
) -> None:
866+
def __init__(self, attr: ir.Attr, as_bool: bool, info: sourceinfo.SourceInfo) -> None:
874867
"""Initializes AttrRef.
875868
876869
Arguments:
877870
attr: An ir.Attr representing the attribute-parameter
878-
typeinfo: type annotation of the attribute.
879-
op's attributes in ONNX are usually single type or list of single type.
871+
as_bool: Whether the attribute is to be interpreted as a bool type (represented as int in ONNX)
880872
info: for debugging use.
881873
"""
882874
super().__init__(attr, info)
883-
self.typeinfo = typeinfo
884-
if not isinstance(typeinfo, (type, _GenericAlias)):
885-
# typing._GenericAlias for List[int] and List[str], etc.
886-
raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.")
887-
self.typeinfo = typeinfo
875+
self.as_bool = as_bool
888876

889877

890878
class DynamicKind(IntFlag):
@@ -901,7 +889,7 @@ def __init__(
901889
ir_value: ir.Value,
902890
kind: DynamicKind,
903891
info: sourceinfo.SourceInfo,
904-
typeinfo: TypeAnnotationValue | None = None,
892+
typeinfo: type_annotation.TypeAnnotationValue | None = None,
905893
) -> None:
906894
"""Represents an ir.Value with some extra information.
907895

0 commit comments

Comments
 (0)