Skip to content

Commit 761451a

Browse files
committed
wip
Signed-off-by: Justin Chu <[email protected]>
1 parent 3aff171 commit 761451a

File tree

1 file changed

+39
-44
lines changed

1 file changed

+39
-44
lines changed

onnxscript/_converter.py

Lines changed: 39 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,12 @@ def __init__(
177177

178178
# A stack of functions in the outer scope
179179
self._outer: list[ir.Function] = []
180-
self._current_fn: ir.Function | None = None
180+
self._current_fn: ir.Function = ir.Function(
181+
domain=self._this_module.domain,
182+
name="",
183+
graph=ir.Graph((), (), nodes=[]),
184+
attributes={},
185+
)
181186
self._nextvar: int = 0
182187
self._used_vars: set[str] = set()
183188
self._locals: list[dict[str, LocalSymValue]] = [{}]
@@ -225,13 +230,18 @@ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:
225230
def _init_function_translation(self) -> None:
226231
"""Initialize self for translating a new (top-level) function."""
227232
self._outer = []
228-
self._current_fn = None
233+
# TODO(justinchuby): Update this
234+
self._current_fn = ir.Function(
235+
domain=self._this_module.domain,
236+
name="",
237+
graph=ir.Graph((), (), nodes=[]),
238+
attributes={},
239+
)
229240
self._nextvar = 0
230241
self._used_vars = set()
231242
self._locals: List[Dict[str, LocalSymValue]] = [{}]
232243

233244
def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo:
234-
assert self._current_fn is not None
235245
return sourceinfo.SourceInfo(node, self._source, self._current_fn.name)
236246

237247
def _message(self, node: ast.AST, error_msg: str) -> str:
@@ -255,7 +265,6 @@ def _enter_scope(self, name: str, parent_node: ast.AST):
255265
"""Enter a control-flow block (a loop body or if-then-else branch).
256266
The block is translated into a nested-scope in ONNX.
257267
"""
258-
assert self._current_fn is not None
259268
self._outer.append(self._current_fn)
260269
assert self._this_module is not None
261270
self._current_fn = ir.Function(
@@ -334,7 +343,9 @@ def _to_onnx_var(
334343
# distinguish between int and bool. So we cast the int tensor to a bool tensor,
335344
# to promote a (python) bool attribute to a ONNX bool tensor.
336345
result_as_bool = self.generate_unique_name(result + "_as_bool")
337-
self.emit("Cast", [result], [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)])
346+
self.emit(
347+
"Cast", [result], [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)]
348+
)
338349
return Variable(result_as_bool, castable=True)
339350
return Variable(result, castable=True)
340351

@@ -364,7 +375,6 @@ def emit(
364375
attributes=attrs,
365376
outputs=[self._lookup(out, self._source_of(outputs[0])) for out in outputs],
366377
)
367-
assert self._current_fn is not None
368378
self._current_fn.append(node)
369379

370380
def _emit_const(
@@ -454,12 +464,12 @@ def _translate_attr(
454464
f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'",
455465
)
456466
return attr_ref
457-
if isinstance(val, ir.Function):
458-
# if isinstance(val, irbuilder.IRFunction):
467+
if isinstance(val, ir.Graph):
468+
# if isinstance(val, irbuilder.IRFunction):
459469
# Check that outer-scope variables referenced by function have same value
460470
# at function-definition site and use-as-attribute site, to avoid errors.
461471

462-
# TODO(justinchuby): Capture outer_scope_variables
472+
# TODO(justinchuby): Capture outer_scope_variables?
463473
# And implement the following
464474
# for pyvar, previous in val.outer_scope_variables:
465475
# current = self._lookup(pyvar, self._source_of(expr))
@@ -470,9 +480,8 @@ def _translate_attr(
470480
# f"'{expr.id!r}' modified.",
471481
# )
472482

473-
# Create GraphProto attribute
474-
# TODO: Fix this
475-
val = val.to_graph_proto()
483+
# Create Graph attribute
484+
pass
476485
else:
477486
val = self._eval_constant_expr(expr)
478487

@@ -482,25 +491,15 @@ def _translate_attr(
482491
# The caller is responsible for omitting such attribute-values from the list of attributes
483492
# in a NodeProto.
484493
if val is None:
485-
if attr_meta and attr_meta.required:
486-
self.fail(expr, f"Attribute '{attr_name}' is required.")
487494
return None
488-
attr_type = int(attr_meta.type) if attr_meta else None
489-
attr = self._make_onnx_attr(attr_name, val, attrtype=attr_type)
490-
if attr_meta and (attr.type != attr_meta.type):
491-
self.fail(
492-
expr,
493-
f"Attribute type '{attr.type}' does not match expected type '{attr_meta.type}'",
494-
)
495+
attr = ir.convenience.convert_attribute(
496+
attr_name, val, attr_type=attr_meta.type if attr_meta else None
497+
)
495498
return attr
496499

497-
def _translate_docstring(self, node: ast.Expr) -> None:
498-
if hasattr(node.value, "value"):
499-
# python 3.8+
500-
return self.ir_builder.add_docstring(self._current_fn, node.value.value)
501-
raise TypeError(
502-
f"Unexpected type {type(node)!r} for node. Unsupoorted version of python."
503-
)
500+
def _translate_docstring(self, node: ast.FunctionDef) -> None:
501+
if docstring := ast.get_docstring(node):
502+
self._current_fn.doc_string = docstring
504503

505504
def _translate_expr(
506505
self, node: ast.AST, target: Optional[PreferredName] = None
@@ -672,9 +671,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
672671
# Add to sliced_indices, unless it is "::", which is a no-op.
673672
if not (elt.lower is None and elt.upper is None and elt.step is None):
674673
sliced_indices.append((axis, elt))
675-
elif _is_constant_expr(elt) and isinstance(
676-
self._eval_constant_expr(elt), int
677-
):
674+
elif _is_constant_expr(elt) and isinstance(self._eval_constant_expr(elt), int):
678675
scalar_indices.append((axis, elt))
679676
else:
680677
non_scalar_indices.append((axis, elt))
@@ -788,9 +785,7 @@ def _translate_call_expr(self, node: ast.Call):
788785
kwargs: dict[str, ast.expr] = {x.arg: x.value for x in node.keywords}
789786
# First separate inputs from attributes. This is needed because in Python
790787
# it is possible to pass onnx inputs as kwargs
791-
inputs, attrs = _separate_inputs_and_attrs(
792-
op_signature, args, kwargs
793-
)
788+
inputs, attrs = _separate_inputs_and_attrs(op_signature, args, kwargs)
794789
onnx_inputs = [self._translate_opt_expr(x) for x in inputs]
795790
attrs = [
796791
self._translate_attr(x, y, op_signature.params_map[x])
@@ -944,8 +939,6 @@ def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None:
944939
if isinstance(node, (ast.For, ast.While)):
945940
return self._translate_loop_stmt(node)
946941
if ast_utils.is_doc_string(node):
947-
if index_of_stmt == 0:
948-
return self._translate_docstring(node)
949942
return None
950943
if isinstance(node, ast.FunctionDef):
951944
return self._translate_nested_function_def(node)
@@ -1401,12 +1394,16 @@ def _translate_function_signature_common(
14011394

14021395
return self._current_fn
14031396

1404-
def _translate_function_def_common(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
1397+
def _translate_function_def_common(self, node: ast.FunctionDef) -> ir.Function:
14051398
"""Translate a function definition, including the signature and its body."""
1406-
logger.debug("Converter:_translate_function_def_common:%s", fn.name)
1407-
_ = self._translate_function_signature_common(fn)
1408-
for i, s in enumerate(fn.body):
1399+
logger.debug("Converter:_translate_function_def_common:%s", node.name)
1400+
_ = self._translate_function_signature_common(node)
1401+
for i, s in enumerate(node.body):
14091402
self._translate_stmt(s, index_of_stmt=i)
1403+
1404+
# Update docstring if available
1405+
if docstring := ast.get_docstring(node):
1406+
self._current_fn.doc_string = docstring
14101407
return self._current_fn
14111408

14121409
def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
@@ -1453,7 +1450,6 @@ def _is_constant_expr(node: ast.AST) -> bool:
14531450
return False
14541451

14551452

1456-
14571453
def _separate_inputs_and_attrs(
14581454
signature: _schemas.OpSignature,
14591455
args: Sequence[ast.expr],
@@ -1535,9 +1531,8 @@ def _separate_inputs_and_attrs(
15351531
named_attrs[param.name] = attribute
15361532
return tuple(reversed(inputs_reversed)), named_attrs
15371533

1538-
def _to_onnx_ref_attr(
1539-
val: values.AttrRef, info: sourceinfo.SourceInfo | None
1540-
) -> ir.Attr:
1534+
1535+
def _to_onnx_ref_attr(val: values.AttrRef, info: sourceinfo.SourceInfo | None) -> ir.Attr:
15411536
"""Convert an attribute reference to an ONNX ref attribute."""
15421537
pytype = val.typeinfo
15431538
attrtype = _schemas.get_attr_type(pytype)

0 commit comments

Comments
 (0)