@@ -177,7 +177,12 @@ def __init__(
177177
178178 # A stack of functions in the outer scope
179179 self ._outer : list [ir .Function ] = []
180- self ._current_fn : ir .Function | None = None
180+ self ._current_fn : ir .Function = ir .Function (
181+ domain = self ._this_module .domain ,
182+ name = "" ,
183+ graph = ir .Graph ((), (), nodes = []),
184+ attributes = {},
185+ )
181186 self ._nextvar : int = 0
182187 self ._used_vars : set [str ] = set ()
183188 self ._locals : list [dict [str , LocalSymValue ]] = [{}]
@@ -225,13 +230,18 @@ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:
225230 def _init_function_translation (self ) -> None :
226231 """Initialize self for translating a new (top-level) function."""
227232 self ._outer = []
228- self ._current_fn = None
233+ # TODO(justinchuby): Update this
234+ self ._current_fn = ir .Function (
235+ domain = self ._this_module .domain ,
236+ name = "" ,
237+ graph = ir .Graph ((), (), nodes = []),
238+ attributes = {},
239+ )
229240 self ._nextvar = 0
230241 self ._used_vars = set ()
231242 self ._locals : List [Dict [str , LocalSymValue ]] = [{}]
232243
233244 def _source_of (self , node : ast .AST ) -> sourceinfo .SourceInfo :
234- assert self ._current_fn is not None
235245 return sourceinfo .SourceInfo (node , self ._source , self ._current_fn .name )
236246
237247 def _message (self , node : ast .AST , error_msg : str ) -> str :
@@ -255,7 +265,6 @@ def _enter_scope(self, name: str, parent_node: ast.AST):
255265 """Enter a control-flow block (a loop body or if-then-else branch).
256266 The block is translated into a nested-scope in ONNX.
257267 """
258- assert self ._current_fn is not None
259268 self ._outer .append (self ._current_fn )
260269 assert self ._this_module is not None
261270 self ._current_fn = ir .Function (
@@ -334,7 +343,9 @@ def _to_onnx_var(
334343 # distinguish between int and bool. So we cast the int tensor to a bool tensor,
335344 # to promote a (python) bool attribute to a ONNX bool tensor.
336345 result_as_bool = self .generate_unique_name (result + "_as_bool" )
337- self .emit ("Cast" , [result ], [result_as_bool ], [ir .AttrInt64 ("to" , ir .DataType .BOOL )])
346+ self .emit (
347+ "Cast" , [result ], [result_as_bool ], [ir .AttrInt64 ("to" , ir .DataType .BOOL )]
348+ )
338349 return Variable (result_as_bool , castable = True )
339350 return Variable (result , castable = True )
340351
@@ -364,7 +375,6 @@ def emit(
364375 attributes = attrs ,
365376 outputs = [self ._lookup (out , self ._source_of (outputs [0 ])) for out in outputs ],
366377 )
367- assert self ._current_fn is not None
368378 self ._current_fn .append (node )
369379
370380 def _emit_const (
@@ -454,12 +464,12 @@ def _translate_attr(
454464 f"Attribute type '{ attr_ref .type } ' does not match expected type '{ attr_meta .type } '" ,
455465 )
456466 return attr_ref
457- if isinstance (val , ir .Function ):
458- # if isinstance(val, irbuilder.IRFunction):
467+ if isinstance (val , ir .Graph ):
468+ # if isinstance(val, irbuilder.IRFunction):
459469 # Check that outer-scope variables referenced by function have same value
460470 # at function-definition site and use-as-attribute site, to avoid errors.
461471
462- # TODO(justinchuby): Capture outer_scope_variables
472+ # TODO(justinchuby): Capture outer_scope_variables?
463473 # And implement the following
464474 # for pyvar, previous in val.outer_scope_variables:
465475 # current = self._lookup(pyvar, self._source_of(expr))
@@ -470,9 +480,8 @@ def _translate_attr(
470480 # f"'{expr.id!r}' modified.",
471481 # )
472482
473- # Create GraphProto attribute
474- # TODO: Fix this
475- val = val .to_graph_proto ()
483+ # Create Graph attribute
484+ pass
476485 else :
477486 val = self ._eval_constant_expr (expr )
478487
@@ -482,25 +491,15 @@ def _translate_attr(
482491 # The caller is responsible for omitting such attribute-values from the list of attributes
483492 # in a NodeProto.
484493 if val is None :
485- if attr_meta and attr_meta .required :
486- self .fail (expr , f"Attribute '{ attr_name } ' is required." )
487494 return None
488- attr_type = int (attr_meta .type ) if attr_meta else None
489- attr = self ._make_onnx_attr (attr_name , val , attrtype = attr_type )
490- if attr_meta and (attr .type != attr_meta .type ):
491- self .fail (
492- expr ,
493- f"Attribute type '{ attr .type } ' does not match expected type '{ attr_meta .type } '" ,
494- )
495+ attr = ir .convenience .convert_attribute (
496+ attr_name , val , attr_type = attr_meta .type if attr_meta else None
497+ )
495498 return attr
496499
497- def _translate_docstring (self , node : ast .Expr ) -> None :
498- if hasattr (node .value , "value" ):
499- # python 3.8+
500- return self .ir_builder .add_docstring (self ._current_fn , node .value .value )
501- raise TypeError (
502- f"Unexpected type { type (node )!r} for node. Unsupoorted version of python."
503- )
500+ def _translate_docstring (self , node : ast .FunctionDef ) -> None :
501+ if docstring := ast .get_docstring (node ):
502+ self ._current_fn .doc_string = docstring
504503
505504 def _translate_expr (
506505 self , node : ast .AST , target : Optional [PreferredName ] = None
@@ -672,9 +671,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
672671 # Add to sliced_indices, unless it is "::", which is a no-op.
673672 if not (elt .lower is None and elt .upper is None and elt .step is None ):
674673 sliced_indices .append ((axis , elt ))
675- elif _is_constant_expr (elt ) and isinstance (
676- self ._eval_constant_expr (elt ), int
677- ):
674+ elif _is_constant_expr (elt ) and isinstance (self ._eval_constant_expr (elt ), int ):
678675 scalar_indices .append ((axis , elt ))
679676 else :
680677 non_scalar_indices .append ((axis , elt ))
@@ -788,9 +785,7 @@ def _translate_call_expr(self, node: ast.Call):
788785 kwargs : dict [str , ast .expr ] = {x .arg : x .value for x in node .keywords }
789786 # First separate inputs from attributes. This is needed because in Python
790787 # it is possible to pass onnx inputs as kwargs
791- inputs , attrs = _separate_inputs_and_attrs (
792- op_signature , args , kwargs
793- )
788+ inputs , attrs = _separate_inputs_and_attrs (op_signature , args , kwargs )
794789 onnx_inputs = [self ._translate_opt_expr (x ) for x in inputs ]
795790 attrs = [
796791 self ._translate_attr (x , y , op_signature .params_map [x ])
@@ -944,8 +939,6 @@ def _translate_stmt(self, node: ast.stmt, index_of_stmt=None) -> None:
944939 if isinstance (node , (ast .For , ast .While )):
945940 return self ._translate_loop_stmt (node )
946941 if ast_utils .is_doc_string (node ):
947- if index_of_stmt == 0 :
948- return self ._translate_docstring (node )
949942 return None
950943 if isinstance (node , ast .FunctionDef ):
951944 return self ._translate_nested_function_def (node )
@@ -1401,12 +1394,16 @@ def _translate_function_signature_common(
14011394
14021395 return self ._current_fn
14031396
1404- def _translate_function_def_common (self , fn : ast .FunctionDef ) -> irbuilder . IRFunction :
1397+ def _translate_function_def_common (self , node : ast .FunctionDef ) -> ir . Function :
14051398 """Translate a function definition, including the signature and its body."""
1406- logger .debug ("Converter:_translate_function_def_common:%s" , fn .name )
1407- _ = self ._translate_function_signature_common (fn )
1408- for i , s in enumerate (fn .body ):
1399+ logger .debug ("Converter:_translate_function_def_common:%s" , node .name )
1400+ _ = self ._translate_function_signature_common (node )
1401+ for i , s in enumerate (node .body ):
14091402 self ._translate_stmt (s , index_of_stmt = i )
1403+
1404+ # Update docstring if available
1405+ if docstring := ast .get_docstring (node ):
1406+ self ._current_fn .doc_string = docstring
14101407 return self ._current_fn
14111408
14121409 def translate_function_def (self , stmt : ast .FunctionDef ) -> irbuilder .IRFunction :
@@ -1453,7 +1450,6 @@ def _is_constant_expr(node: ast.AST) -> bool:
14531450 return False
14541451
14551452
1456-
14571453def _separate_inputs_and_attrs (
14581454 signature : _schemas .OpSignature ,
14591455 args : Sequence [ast .expr ],
@@ -1535,9 +1531,8 @@ def _separate_inputs_and_attrs(
15351531 named_attrs [param .name ] = attribute
15361532 return tuple (reversed (inputs_reversed )), named_attrs
15371533
1538- def _to_onnx_ref_attr (
1539- val : values .AttrRef , info : sourceinfo .SourceInfo | None
1540- ) -> ir .Attr :
1534+
1535+ def _to_onnx_ref_attr (val : values .AttrRef , info : sourceinfo .SourceInfo | None ) -> ir .Attr :
15411536 """Convert an attribute reference to an ONNX ref attribute."""
15421537 pytype = val .typeinfo
15431538 attrtype = _schemas .get_attr_type (pytype )
0 commit comments