Skip to content

Commit e0281a3

Browse files
committed
Add support for type annotation
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent a46dcc2 commit e0281a3

File tree

3 files changed

+59
-19
lines changed

3 files changed

+59
-19
lines changed

onnxscript/converter.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,16 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue:
464464
)
465465
) from e
466466

467+
def _get_type_annotation(self, annotation: ast.Expr) -> Optional[ta.TypeAnnotationValue]:
468+
typeinfo = self._eval_constant_expr(annotation)
469+
if not ta.is_valid_type(typeinfo):
470+
self.warn(
471+
annotation,
472+
"Unsupported type annotation.",
473+
)
474+
typeinfo = None
475+
return typeinfo
476+
467477
def _translate_attr(
468478
self,
469479
attr_name: str,
@@ -985,9 +995,11 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None:
985995
lhs = lhs.id
986996
t = self._translate_expr(rhs, lhs)
987997
if isinstance(stmt, ast.AnnAssign):
988-
typeinfo = self._eval_constant_expr(stmt.annotation)
998+
typeinfo = self._get_type_annotation(stmt.annotation)
989999
else:
9901000
typeinfo = None
1001+
if typeinfo is not None:
1002+
irbuilder.set_type_info(t, typeinfo)
9911003
var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo)
9921004
self._bind(lhs, var)
9931005
elif isinstance(lhs, ast.Tuple):
@@ -1400,13 +1412,7 @@ def _translate_function_signature_common(
14001412
else:
14011413
default_value = None
14021414
if x.annotation:
1403-
typeinfo = self._eval_constant_expr(x.annotation)
1404-
if not ta.is_valid_type(typeinfo):
1405-
self.warn(
1406-
x.annotation,
1407-
f"Unsupported type annotation for argument {x.arg}.",
1408-
)
1409-
typeinfo = None
1415+
typeinfo = self._get_type_annotation(x.annotation)
14101416
else:
14111417
# The code can only be exported as a function.
14121418
typeinfo = None

onnxscript/converter_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,37 @@ def model(x: FLOAT[10]) -> FLOAT[10]:
740740
model_false = make_model(False)
741741
onnxscript.testing.assert_isomorphic(model_false, sub_model.to_model_proto())
742742

743+
def test_type_annotation(self):
744+
"""Test that type annotations are processed correctly."""
745+
746+
@script()
747+
def model(x: FLOAT[10]) -> FLOAT[10]:
748+
temp: FLOAT[10] = op.Add(x, x)
749+
y = op.Mul(temp, temp)
750+
return y
751+
752+
model_proto = model.to_model_proto()
753+
input_type = model_proto.graph.input[0].type.tensor_type
754+
output_type = model_proto.graph.output[0].type.tensor_type
755+
temp_value_info = None
756+
for value_info in model_proto.graph.value_info:
757+
if value_info.name == "temp":
758+
temp_value_info = value_info
759+
break
760+
self.assertIsNotNone(temp_value_info, "ValueInfo for 'temp' not found in graph.")
761+
temp_type = temp_value_info.type.tensor_type
762+
self.assertEqual(temp_type.elem_type, onnx.TensorProto.FLOAT)
763+
self.assertEqual(len(temp_type.shape.dim), 1)
764+
self.assertEqual(temp_type.shape.dim[0].dim_value, 10)
765+
766+
self.assertEqual(input_type.elem_type, onnx.TensorProto.FLOAT)
767+
self.assertEqual(len(input_type.shape.dim), 1)
768+
self.assertEqual(input_type.shape.dim[0].dim_value, 10)
769+
770+
self.assertEqual(output_type.elem_type, onnx.TensorProto.FLOAT)
771+
self.assertEqual(len(output_type.shape.dim), 1)
772+
self.assertEqual(output_type.shape.dim[0].dim_value, 10)
773+
743774

744775
if __name__ == "__main__":
745776
unittest.main(verbosity=2)

onnxscript/irbuilder.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -121,21 +121,24 @@ def to_function_proto(self) -> onnx.FunctionProto:
121121
# IRBuilder: abstracts out details of the IR in the python-to-IR converter
122122

123123

124+
def set_type_info(value: ir.Value, typeinfo: TypeAnnotationValue) -> None:
125+
"""Sets the type information on an IR value."""
126+
try:
127+
type_and_shape = ir.from_proto(typeinfo.to_type_proto())
128+
value.type = type_and_shape.type
129+
value.shape = type_and_shape.shape
130+
except AttributeError:
131+
pass
132+
value.meta["typeinfo"] = typeinfo
133+
134+
124135
def _make_value(
125136
varname: str, typeinfo: TypeAnnotationValue, sourceinfo: SourceInfo
126137
) -> ir.Value:
127-
if typeinfo is None:
128-
value = ir.Value(name=varname)
129-
else:
130-
try:
131-
type_and_shape = ir.from_proto(typeinfo.to_type_proto())
132-
value = ir.Value(
133-
name=varname, type=type_and_shape.type, shape=type_and_shape.shape
134-
)
135-
except AttributeError:
136-
value = ir.Value(name=varname)
138+
value = ir.Value(name=varname)
137139
value.meta.setdefault("sourceinfo", sourceinfo)
138-
value.meta.setdefault("typeinfo", typeinfo)
140+
if typeinfo is not None:
141+
set_type_info(value, typeinfo)
139142
return value
140143

141144

0 commit comments

Comments
 (0)