@@ -339,22 +339,20 @@ def tensor_name_generator() -> str:
339339 def _to_onnx_attr_ref (
340340 self , val : values .AttrRef , info : Optional [sourceinfo .SourceInfo ]
341341 ) -> ir .Attr :
342- pytype = val .typeinfo
343- attrtype = ta .pytype_to_attrtype (pytype )
342+ attrtype = val .value .type
344343 attrname = None
345- if attrtype is onnx .AttributeProto .FLOAT :
344+ if attrtype is ir . AttributeType . FLOAT : # onnx.AttributeProto.FLOAT:
346345 attrname = "value_float"
347- elif attrtype is onnx . AttributeProto .INT :
346+ elif attrtype is ir . AttributeType .INT :
348347 attrname = "value_int"
349- elif attrtype is onnx . AttributeProto .STRING :
348+ elif attrtype is ir . AttributeType .STRING :
350349 attrname = "value_string"
351- elif attrtype is onnx . AttributeProto .INTS :
350+ elif attrtype is ir . AttributeType .INTS :
352351 attrname = "value_ints"
353352 else :
354- msg = f"Unsupported attribute type { pytype !r} ."
353+ msg = f"Unsupported attribute type { attrtype !r} ."
355354 fail (info .msg (msg ) if info else msg )
356- attr_type = ir .AttributeType (ta .pytype_to_attrtype (pytype ))
357- return ir .Attr (attrname , attr_type , None , val .value )
355+ return ir .Attr (attrname , attrtype , value = None , ref_attr_name = val .value .name )
358356
359357 def _to_onnx_var (
360358 self ,
@@ -369,7 +367,7 @@ def _to_onnx_var(
369367 result = self .emit (
370368 [result_name ], values .Op (self .default_opset , "Constant" ), [], [attr ]
371369 )
372- if ta . base_type_is_bool ( val .typeinfo ) :
370+ if val .as_bool :
373371 # ONNX attributes use an int-encoding for bools, but ONNX tensor types
374372 # distinguish between int and bool. So we cast the int tensor to a bool tensor,
375373 # to promote a (python) bool attribute to a ONNX bool tensor.
@@ -384,8 +382,9 @@ def _to_onnx_var(
384382 )
385383 self ._castable .add (result_name )
386384 return result
387- if isinstance (val , values .Dynamic ):
388- return val .value
385+ if isinstance (val , values .SymbolValue ):
386+ if isinstance (val .value , ir .Value ):
387+ return val .value
389388 # Assume value is a python-value convertible to a tensor
390389 # TODO: check if value is convertible to a TensorProto, so that we can
391390 # produce a better error _message otherwise
@@ -534,29 +533,44 @@ def _translate_attr(
534533
535534 if isinstance (expr , ast .Name ):
536535 val = self ._lookup (expr .id , self ._source_of (expr ))
537- if isinstance (val , values .AttrRef ):
538- attr_type = ir . AttributeType ( ta . pytype_to_attrtype ( val .typeinfo ))
539- attr_ref = ir . Attr ( attr_name , attr_type , None , val . value )
540- if attr_meta is not None and ( attr_ref . type != attr_meta . type ) :
541- self . fail (
542- expr ,
543- f"Attribute type ' { attr_ref .type } ' does not match expected type ' { attr_meta . type } '" ,
536+ if isinstance (val , values .SymbolValue ):
537+ val = val .value
538+ if isinstance ( val , ir . Attr ):
539+ # A reference to an attribute parameter :
540+ attr = val
541+ attr_ref = ir . Attr (
542+ attr_name , attr .type , value = None , ref_attr_name = attr . name
544543 )
545- return attr_ref
546- if isinstance (val , irbuilder .IRFunction ):
547- # Check that outer-scope variables referenced by function have same value
548- # at function-definition site and use-as-attribute site, to avoid errors.
549- for pyvar , previous in val .outer_scope_variables :
550- current = self ._lookup (pyvar , self ._source_of (expr ))
551- if current .value != previous .value :
544+ if attr_meta is not None and (attr .type != attr_meta .type ):
552545 self .fail (
553546 expr ,
554- f"Outer scope variable '{ pyvar } ' referenced by function "
555- f"'{ expr .id !r} ' modified." ,
547+ f"Attribute type '{ attr_ref .type } ' does not match expected type '{ attr_meta .type } '" ,
556548 )
557-
558- # Create GraphProto attribute
559- val = val .to_graph_proto ()
549+ return attr_ref
550+ if isinstance (val , irbuilder .IRFunction ):
551+ # A reference to a nested-function: convert to GraphProto and use it.
552+ irfunction = val
553+ # Check that outer-scope variables referenced by function have same value
554+ # at function-definition site and use-as-attribute site, to avoid errors.
555+ for pyvar , previous in irfunction .outer_scope_variables :
556+ current = self ._lookup (pyvar , self ._source_of (expr ))
557+ if current .value != previous .value :
558+ self .fail (
559+ expr ,
560+ f"Outer scope variable '{ pyvar } ' referenced by function "
561+ f"'{ expr .id !r} ' modified." ,
562+ )
563+ # Create GraphProto attribute
564+ val = irfunction .to_graph_proto ()
565+ if isinstance (val , ir .Value ):
566+ self .fail (expr , f"Cannot use ir.Value '{ expr .id } ' as an attribute." )
567+ else :
568+ # Treat as a constant python-value, to be converted below.
569+ pass
570+ else :
571+ # This must be a reference to an outer-scope python-value, typically a constant.
572+ # The value will be converted to an ONNX attribute value below.
573+ pass
560574 else :
561575 val = self ._eval_constant_expr (expr )
562576
@@ -1045,7 +1059,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None:
10451059 typeinfo = None
10461060 if typeinfo is not None :
10471061 set_type_info (t , typeinfo )
1048- var = values .Dynamic (t , values . DynamicKind . Intermediate , info , typeinfo )
1062+ var = values .SymbolValue (t , info )
10491063 self ._bind (lhs , var )
10501064 elif isinstance (lhs , ast .Tuple ):
10511065 # Assignments of the form "x, y, z = op.SomeOp(...)"
@@ -1068,9 +1082,7 @@ def generate_onnx_name(x: ast.AST):
10681082 for x , output in zip (lhs .elts , outputs ):
10691083 self ._bind (
10701084 x .id ,
1071- values .Dynamic (
1072- output , values .DynamicKind .Intermediate , self ._source_of (x )
1073- ),
1085+ values .SymbolValue (output , self ._source_of (x )),
10741086 )
10751087 else :
10761088 self .fail (lhs , f"Unsupported construct in LHS of assignment: '{ type (lhs )!r} '" )
@@ -1117,10 +1129,11 @@ def ret(exp, i, suffix):
11171129 preferred_name = f"return_val{ suffix } "
11181130 return_var = self ._translate_expr (exp , preferred_name ) # TODO(rama)
11191131 val = self ._lookup (return_var .name , self ._source_of (exp ), False )
1120- if val and val .kind == values .DynamicKind .Input :
1121- # In ONNX, a graph-input cannot be an output of the graph.
1122- # We need to insert a copy.
1123- return_var = self ._emit_copy (return_var , preferred_name )
1132+ if isinstance (val , values .SymbolValue ) and isinstance (val .value , ir .Value ):
1133+ if val .value .is_graph_input ():
1134+ # In ONNX, a graph-input cannot be an output of the graph.
1135+ # We need to insert a copy.
1136+ return_var = self ._emit_copy (return_var , preferred_name )
11241137 for prev_output in self ._current_fn .outputs :
11251138 if prev_output .name == return_var .name :
11261139 # ONNX does not allow duplicate output names.
@@ -1190,7 +1203,7 @@ def rename(x):
11901203 for x , y in zip (live_defs , if_outputs ):
11911204 self ._bind (
11921205 x ,
1193- values .Dynamic ( y , values . DynamicKind . Intermediate , self ._source_of (stmt )),
1206+ values .SymbolValue ( y , self ._source_of (stmt )),
11941207 )
11951208
11961209 def _translate_loop_stmt (self , loop_stmt : Union [ast .For , ast .While ]) -> None :
@@ -1257,7 +1270,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12571270 self ._current_fn .append_parameter (onnx_loop_var )
12581271 self ._bind (
12591272 python_loop_var_name ,
1260- values .Dynamic (onnx_loop_var , values . DynamicKind . Loop , self ._source_of (loop_stmt )),
1273+ values .SymbolValue (onnx_loop_var , self ._source_of (loop_stmt )),
12611274 )
12621275
12631276 self ._current_fn .append_parameter (
@@ -1278,9 +1291,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12781291 )
12791292 self ._bind (
12801293 pv ,
1281- values .Dynamic (
1294+ values .SymbolValue (
12821295 ir .Value (name = onnx_var_name ),
1283- values .DynamicKind .Loop ,
12841296 self ._source_of (loop_stmt ),
12851297 ),
12861298 )
@@ -1376,7 +1388,7 @@ def rename(x):
13761388 if isinstance (loop_outputs , ir .Value ):
13771389 loop_outputs = [loop_outputs ]
13781390 for x , loop_output in zip (outputs , loop_outputs ):
1379- self ._bind (x , values .Dynamic (loop_output , values . DynamicKind . Output , info ))
1391+ self ._bind (x , values .SymbolValue (loop_output , info ))
13801392
13811393 def _translate_block (
13821394 self ,
@@ -1431,7 +1443,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
14311443 function_ir .outer_scope_variables = [
14321444 (var , self ._lookup (var , self ._source_of (fn ))) for var in outer_scope_vars
14331445 ]
1434- self ._bind (fn .name , function_ir )
1446+ self ._bind (fn .name , values . SymbolValue ( function_ir , self . _source_of ( fn )) )
14351447 # TODO: Does not yet handle nested functions within nested functions.
14361448 self ._current_fn .add_nested_function (function_ir )
14371449
@@ -1459,16 +1471,15 @@ def _translate_function_signature_common(
14591471 attribute_type = ta .pytype_to_attrtype (typeinfo )
14601472 attr = ir .Attr (x .arg , ir .AttributeType (attribute_type ), default_value , None )
14611473 self ._current_fn .append_parameter (attr )
1462- self ._bind (x .arg , values .AttrRef (x .arg , typeinfo , self ._source_of (x )))
1474+ as_bool = ta .base_type_is_bool (typeinfo )
1475+ self ._bind (x .arg , values .AttrRef (attr , as_bool , self ._source_of (x )))
14631476 else :
14641477 onnx_parameter = make_value (x .arg , typeinfo , self ._source_of (x ))
14651478 self ._current_fn .append_parameter (onnx_parameter )
14661479 self ._used_vars .add (x .arg )
14671480 self ._bind (
14681481 x .arg ,
1469- values .Dynamic (
1470- onnx_parameter , values .DynamicKind .Input , self ._source_of (x )
1471- ),
1482+ values .SymbolValue (onnx_parameter , self ._source_of (x )),
14721483 )
14731484 if fn .returns :
14741485 type_annotation = self ._eval_constant_expr (fn .returns )
0 commit comments