Skip to content

Commit c84ad91

Browse files
committed
update emit call
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent c53b8f3 commit c84ad91

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

onnxscript/_converter.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -331,14 +331,17 @@ def _to_onnx_var(
331331
# promote attribute to value
332332
result = self._generate_unique_name(target)
333333
attr = _to_onnx_ref_attr(val, info)
334-
self.emit([], "Constant", [result], [attr])
334+
self.emit([], "Constant", [result], attrs=[attr])
335335
if ta.base_type_is_bool(val.typeinfo):
336336
# ONNX attributes use an int-encoding for bools, but ONNX tensor types
337337
# distinguish between int and bool. So we cast the int tensor to a bool tensor,
338338
# to promote a (python) bool attribute to a ONNX bool tensor.
339339
result_as_bool = self._generate_unique_name(result + "_as_bool")
340340
self.emit(
341-
[result], "Cast", [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)]
341+
[result],
342+
"Cast",
343+
[result_as_bool],
344+
attrs=[ir.AttrInt64("to", ir.DataType.BOOL)],
342345
)
343346
return Variable(result_as_bool, castable=True)
344347
return Variable(result, castable=True)
@@ -358,6 +361,7 @@ def emit(
358361
outputs: Sequence[str],
359362
op_type: str,
360363
inputs: Sequence[str],
364+
*,
361365
attrs: Sequence[ir.Attr] = (),
362366
domain: str = "",
363367
):
@@ -396,7 +400,7 @@ def _emit_const(
396400
except Exception as e:
397401
fail(info.msg(str(e)))
398402

399-
self.emit([], "Constant", [var_name], [ir.AttrTensor("value", tensor)])
403+
self.emit([], "Constant", [var_name], attrs=[ir.AttrTensor("value", tensor)])
400404
return Variable(var_name, True)
401405

402406
def _emit_copy(self, original_var: str, suggested_name: str) -> str:
@@ -522,7 +526,7 @@ def _translate_expr(
522526
target = "tmp" if target is None else target
523527
assert isinstance(target, str)
524528
result = self._generate_unique_name(target)
525-
self.emit([result], callee, args, attrs)
529+
self.emit([result], callee, args, attrs=attrs)
526530
return Variable(result)
527531

528532
def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]:
@@ -620,9 +624,8 @@ def translate_slice_component(
620624
reshaped = self._generate_unique_name(f"{name}_reshaped")
621625
self.emit(
622626
[reshaped],
623-
values.Op(self._default_opset, "Reshape"),
627+
"Reshape",
624628
[name, one_1d().name],
625-
[],
626629
)
627630
return reshaped, None
628631

@@ -704,16 +707,16 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
704707
if len(starts) > 1:
705708
axis_0_attr = self._make_onnx_attr("axis", 0)
706709
start_name = self._generate_unique_name(f"{var_name}_start")
707-
self.emit([start_name], "Concat", starts, [axis_0_attr])
710+
self.emit([start_name], "Concat", starts, attrs=[axis_0_attr])
708711

709712
end_name = self._generate_unique_name(f"{var_name}_end")
710-
self.emit([end_name], "Concat", ends, [axis_0_attr])
713+
self.emit([end_name], "Concat", ends, attrs=[axis_0_attr])
711714

712715
axes_name = self._generate_unique_name(f"{var_name}_axis")
713-
self.emit([axes_name], "Concat", axes, [axis_0_attr])
716+
self.emit([axes_name], "Concat", axes, attrs=[axis_0_attr])
714717

715718
steps_name = self._generate_unique_name(f"{var_name}_step")
716-
self.emit([steps_name], "Concat", steps, [axis_0_attr])
719+
self.emit([steps_name], "Concat", steps, attrs=[axis_0_attr])
717720
else:
718721
start_name = starts[0]
719722
end_name = ends[0]
@@ -759,7 +762,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
759762
gathered = self._generate_unique_name(f"{var_name}_axis_{axis}")
760763
else: # store result of Gather in final target
761764
gathered = target
762-
self.emit([gathered], "Gather", [str(result), index_value], [axis_attr])
765+
self.emit([gathered], "Gather", [str(result), index_value], attrs=[axis_attr])
763766
result = gathered
764767

765768
return Variable(result)
@@ -971,7 +974,7 @@ def generate_onnx_name(x: ast.AST):
971974
return onnx_name
972975

973976
outputs = [generate_onnx_name(x) for x in lhs.elts]
974-
self.emit(outputs, callee, inputs, attrs)
977+
self.emit(outputs, callee, inputs, attrs=attrs)
975978
else:
976979
self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'")
977980

@@ -1085,7 +1088,7 @@ def rename(x):
10851088
[test],
10861089
"If",
10871090
renamed,
1088-
[then_attr, else_attr],
1091+
attrs=[then_attr, else_attr],
10891092
)
10901093

10911094
def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
@@ -1218,9 +1221,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12181221

12191222
self.emit(
12201223
[o_cond_out],
1221-
values.Op(self._default_opset, operator_name),
1224+
operator_name,
12221225
[condition_name or o_cond_var],
1223-
[],
12241226
)
12251227

12261228
self.ir_builder.add_output(
@@ -1262,8 +1264,8 @@ def rename(x):
12621264
onnx_outputs,
12631265
"Loop",
12641266
inputs,
1265-
attrs,
1266-
sub_functions=sub_functions,
1267+
attrs=attrs,
1268+
# sub_functions=sub_functions,
12671269
)
12681270

12691271
def _translate_block(

0 commit comments

Comments
 (0)