Skip to content

Commit 86e10dd

Browse files
committed
More cleanup
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent fa95a93 commit 86e10dd

File tree

1 file changed

+53
-86
lines changed

1 file changed

+53
-86
lines changed

onnxscript/converter.py

Lines changed: 53 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def _enter_scope(self, name: str, parent_node: ast.AST):
266266
The block is translated into a nested-scope in ONNX.
267267
"""
268268
self._outer.insert(0, self._current_fn)
269-
self._current_fn = self.new_function(name)
269+
self._current_fn = irbuilder.IRFunction(name)
270270
self._locals.insert(0, {})
271271
logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node))
272272

@@ -379,22 +379,18 @@ def _to_onnx_var(
379379
def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> ir.Value:
380380
return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info)
381381

382-
def new_function(self, name: str, domain: str = "", register: bool = False) -> irbuilder.IRFunction:
383-
if register and (domain, name) in self.ir_builder.functions:
384-
raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.")
385-
function = irbuilder.IRFunction(name, domain)
386-
if register:
387-
self.ir_builder.functions[domain, name] = function
388-
return function
389-
390-
def add_stmt(
382+
def emit(
391383
self,
392-
results: Sequence[str],
393-
callee: values.Op,
384+
outputs: Sequence[str],
385+
callee: values.Op | str,
394386
inputs: Sequence[Optional[ir.Value]],
395-
attrs: Sequence[ir.Attr],
396-
) -> Sequence[ir.Value]:
397-
output_values = [ir.Value(name=o) for o in results]
387+
attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None,
388+
) -> Sequence[ir.Value] | ir.Value:
389+
if not isinstance(callee, values.Op):
390+
callee = values.Op(self.default_opset, callee)
391+
if attrs is None:
392+
attrs = []
393+
output_values = [ir.Value(name=o) for o in outputs]
398394
node = ir.Node(
399395
domain=callee.opset.domain,
400396
version=callee.opset.version,
@@ -407,50 +403,7 @@ def add_stmt(
407403
raise TypeError(f"Unexpected type {type(callee)} for callee.")
408404
node.meta.setdefault("callee", callee)
409405
self._current_fn.append_node(node)
410-
return output_values
411406

412-
def add_attr_parameter(
413-
self,
414-
varname: str,
415-
attribute_type: onnx.AttributeProto.AttributeType,
416-
default_value: int | float | str | None,
417-
) -> None:
418-
attr = ir.Attr(varname, ir.AttributeType(attribute_type), default_value, None)
419-
self._current_fn.append_parameter(attr)
420-
421-
def add_input(
422-
self,
423-
varname: str,
424-
typeinfo: ta.TypeAnnotationValue,
425-
source_info: sourceinfo.SourceInfo,
426-
) -> None:
427-
self._current_fn.append_parameter(make_value(varname, typeinfo, source_info))
428-
429-
def add_output(
430-
self,
431-
varname: str,
432-
typeinfo: ta.TypeAnnotationValue,
433-
source_info: sourceinfo.SourceInfo,
434-
) -> None:
435-
self._current_fn.append_output(make_value(varname, typeinfo, source_info))
436-
437-
def emit(
438-
self,
439-
outputs: Sequence[str],
440-
callee: values.Op | str,
441-
inputs: Sequence[Optional[ir.Value]],
442-
attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None,
443-
) -> Sequence[ir.Value] | ir.Value:
444-
if not isinstance(callee, values.Op):
445-
callee = values.Op(self.default_opset, callee)
446-
if attrs is None:
447-
attrs = []
448-
output_values = self.add_stmt(
449-
outputs,
450-
callee,
451-
inputs,
452-
attrs,
453-
)
454407
return output_values if len(output_values) > 1 else output_values[0]
455408

456409
def emit1(self, *args, **kwargs) -> ir.Value:
@@ -1162,7 +1115,9 @@ def ret(exp, i, suffix):
11621115
t = None
11631116
else:
11641117
t = self.returntype[i]
1165-
self.add_output(return_var.name, t, self._source_of(stmt))
1118+
self._current_fn.append_output(
1119+
make_value(return_var.name, t, self._source_of(stmt))
1120+
)
11661121
return return_var
11671122

11681123
val = stmt.value
@@ -1283,10 +1238,12 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12831238
# build loop_body
12841239
self._enter_scope("loop_body", loop_stmt)
12851240
o_loop_var = self.generate_unique_name(p_loop_var)
1286-
self.add_input(
1287-
o_loop_var,
1288-
onnx_types.INT64,
1289-
self._source_of(loop_stmt),
1241+
self._current_fn.append_parameter(
1242+
make_value(
1243+
o_loop_var,
1244+
onnx_types.INT64,
1245+
self._source_of(loop_stmt),
1246+
)
12901247
)
12911248
self._bind(
12921249
p_loop_var,
@@ -1295,18 +1252,22 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12951252
),
12961253
)
12971254

1298-
self.add_input(
1299-
i_cond_var.name,
1300-
onnx_types.BOOL,
1301-
self._source_of(loop_stmt),
1255+
self._current_fn.append_parameter(
1256+
make_value(
1257+
i_cond_var.name,
1258+
onnx_types.BOOL,
1259+
self._source_of(loop_stmt),
1260+
)
13021261
)
13031262

13041263
for pv in loop_state_vars:
13051264
ov = self.generate_unique_name(pv)
13061265
# TODO: retrieve the annotation for variable pv is any is specified.
13071266
# typeinfo = self._eval_constant_expr(pv.annotation)
13081267
typeinfo = None
1309-
self.add_input(ov, typeinfo, self._source_of(loop_stmt))
1268+
self._current_fn.append_parameter(
1269+
make_value(ov, typeinfo, self._source_of(loop_stmt))
1270+
)
13101271
self._bind(
13111272
pv,
13121273
values.Dynamic(
@@ -1363,10 +1324,12 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13631324
[],
13641325
)
13651326

1366-
self.add_output(
1367-
o_cond_out,
1368-
onnx_types.BOOL,
1369-
self._source_of(loop_stmt),
1327+
self._current_fn.append_output(
1328+
make_value(
1329+
o_cond_out,
1330+
onnx_types.BOOL,
1331+
self._source_of(loop_stmt),
1332+
)
13701333
)
13711334
for pv in loop_state_vars:
13721335
ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt))
@@ -1379,7 +1342,9 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13791342
ov = self._emit_copy(ov, pv)
13801343
# TODO: retrieve variable type for the annotation if any.
13811344
typeinfo = None
1382-
self.add_output(ov.name, typeinfo, self._source_of(loop_stmt))
1345+
self._current_fn.append_output(
1346+
make_value(ov.name, typeinfo, self._source_of(loop_stmt))
1347+
)
13831348
body = self._exit_scope()
13841349
inputs = [o_loop_bound, o_loop_condition] + [
13851350
self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars
@@ -1424,10 +1389,12 @@ def _translate_block(
14241389
# To return an outer-scope variable, an ONNX Graph has to
14251390
# use an explicit copy via Identity.
14261391
output = self._emit_copy(output, pvar)
1427-
self.add_output(
1428-
output.name,
1429-
pv_val.typeinfo,
1430-
source,
1392+
self._current_fn.append_output(
1393+
make_value(
1394+
output.name,
1395+
pv_val.typeinfo,
1396+
source,
1397+
)
14311398
)
14321399
else:
14331400
pv_val = None
@@ -1446,7 +1413,7 @@ def _translate_block(
14461413

14471414
# TODO: retrieve the annotation if any.
14481415
typeinfo = None
1449-
self.add_output(ovar.name, typeinfo, source)
1416+
self._current_fn.append_output(make_value(ovar.name, typeinfo, source))
14501417
graph = self._exit_scope()
14511418
return graph.graph
14521419

@@ -1484,14 +1451,14 @@ def _translate_function_signature_common(
14841451
# The code can only be exported as a function.
14851452
typeinfo = None
14861453
if typeinfo and ta.is_attr_type(typeinfo):
1487-
self.add_attr_parameter(
1488-
x.arg,
1489-
ta.pytype_to_attrtype(typeinfo),
1490-
default_value,
1491-
)
1454+
attribute_type = ta.pytype_to_attrtype(typeinfo)
1455+
attr = ir.Attr(x.arg, ir.AttributeType(attribute_type), default_value, None)
1456+
self._current_fn.append_parameter(attr)
14921457
self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x)))
14931458
else:
1494-
self.add_input(x.arg, typeinfo, self._source_of(x))
1459+
self._current_fn.append_parameter(
1460+
make_value(x.arg, typeinfo, self._source_of(x))
1461+
)
14951462
self._used_vars.add(x.arg)
14961463
self._bind(
14971464
x.arg,
@@ -1533,7 +1500,7 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
15331500
if opset:
15341501
self._set_default_opset(opset, stmt)
15351502
domain = self.this_module.domain
1536-
self._current_fn = self.new_function(stmt.name, domain, True)
1503+
self._current_fn = irbuilder.IRFunction(stmt.name, domain)
15371504
self._analyzer = analysis.AstAnalyzer(stmt, self._message, self.globals)
15381505
fn_ir = self._translate_function_def_common(stmt)
15391506
self.this_module.add_function_def(fn_ir)
@@ -1544,5 +1511,5 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
15441511
def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
15451512
"""Translate a (top-level) function signature."""
15461513
domain = self.this_module.domain
1547-
self._current_fn = self.new_function(fn.name, domain, True)
1514+
self._current_fn = irbuilder.IRFunction(fn.name, domain)
15481515
return self._translate_function_signature_common(fn)

0 commit comments

Comments
 (0)