Skip to content

Commit be610e9

Browse files
committed
continue
Signed-off-by: Justin Chu <[email protected]>
1 parent 852cc42 commit be610e9

File tree

1 file changed

+38
-38
lines changed

1 file changed

+38
-38
lines changed

onnxscript/_converter.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None:
229229
or opset.version != self._default_opset.version
230230
):
231231
self.fail(
232-
node, f"Two distincts opset were used ({opset} != {self._default_opset})."
232+
node, f"Two distinct opset were used ({opset} != {self._default_opset})."
233233
)
234234
else:
235235
self._default_opset = opset
@@ -251,19 +251,19 @@ def _find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:
251251
return res
252252
return None
253253

254-
def _init_function_translation(self) -> None:
255-
"""Initialize self for translating a new (top-level) function."""
256-
self._outer = []
257-
# TODO(justinchuby): Update this
258-
self._current_fn = ir.Function(
259-
domain=self._opset.domain,
260-
name="",
261-
graph=ir.Graph((), (), nodes=[]),
262-
attributes={},
263-
)
264-
self._nextvar = 0
265-
self._used_vars = set()
266-
self._locals: List[Dict[str, LocalSymValue]] = [{}]
254+
# def _init_function_translation(self) -> None:
255+
# """Initialize self for translating a new (top-level) function."""
256+
# self._outer = []
257+
# # TODO(justinchuby): Update this
258+
# self._current_fn = ir.Function(
259+
# domain=self._opset.domain,
260+
# name="",
261+
# graph=ir.Graph((), (), nodes=[]),
262+
# attributes={},
263+
# )
264+
# self._nextvar = 0
265+
# self._used_vars = set()
266+
# self._locals: List[Dict[str, LocalSymValue]] = [{}]
267267

268268
def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo:
269269
return sourceinfo.SourceInfo(node, self._source, self._current_fn.name)
@@ -328,7 +328,7 @@ def _lookup(
328328
raise ValueError(info.msg(f"Unbound name: {name}."))
329329
return None
330330

331-
def generate_unique_name(self, candidate: str = "tmp") -> str:
331+
def _generate_unique_name(self, candidate: str = "tmp") -> str:
332332
# TODO(justinchuby): Can we reduce the O complexity of this function?
333333
r = candidate
334334
while r in self._used_vars:
@@ -347,14 +347,14 @@ def _to_onnx_var(
347347
"""Convert a value to an ONNX variable."""
348348
if isinstance(val, values.AttrRef):
349349
# promote attribute to value
350-
result = self.generate_unique_name(target)
350+
result = self._generate_unique_name(target)
351351
attr = _to_onnx_ref_attr(val, info)
352352
self.emit("Constant", [], [result], [attr])
353353
if ta.base_type_is_bool(val.typeinfo):
354354
# ONNX attributes use an int-encoding for bools, but ONNX tensor types
355355
# distinguish between int and bool. So we cast the int tensor to a bool tensor,
356356
# to promote a (python) bool attribute to a ONNX bool tensor.
357-
result_as_bool = self.generate_unique_name(result + "_as_bool")
357+
result_as_bool = self._generate_unique_name(result + "_as_bool")
358358
self.emit(
359359
"Cast", [result], [result_as_bool], [ir.AttrInt64("to", ir.DataType.BOOL)]
360360
)
@@ -406,7 +406,7 @@ def _emit_const(
406406
suggested_name = f"int64_{pyvalue[0]}_1d"
407407
else:
408408
suggested_name = "const"
409-
var_name = self.generate_unique_name(suggested_name)
409+
var_name = self._generate_unique_name(suggested_name)
410410

411411
# Create a tensor from the python value
412412
try:
@@ -419,7 +419,7 @@ def _emit_const(
419419

420420
def _emit_copy(self, original_var: str, suggested_name: str) -> str:
421421
"""Emits a copy statement, using the ONNX Identity operator."""
422-
new_var = self.generate_unique_name(suggested_name)
422+
new_var = self._generate_unique_name(suggested_name)
423423
self.emit("Identity", [original_var], [new_var])
424424
return new_var
425425

@@ -539,7 +539,7 @@ def _translate_expr(
539539
callee, args, attrs = r
540540
target = "tmp" if target is None else target
541541
assert isinstance(target, str)
542-
result = self.generate_unique_name(target)
542+
result = self._generate_unique_name(target)
543543
self.emit([result], callee, args, attrs)
544544
return Variable(result)
545545

@@ -594,7 +594,7 @@ def _translate_subscript_expr(
594594
var_name = var.name
595595
if target is None:
596596
target = f"{var_name}_subscripted"
597-
target = self.generate_unique_name(target)
597+
target = self._generate_unique_name(target)
598598
indices = ast_utils.normalize_subscript_expr(node)
599599
info = self._source_of(node.slice)
600600

@@ -635,7 +635,7 @@ def translate_slice_component(
635635
raise RuntimeError(f"Slice component type must be int, not {type(cst)}")
636636
else:
637637
name = self._translate_expr(node_arg).name
638-
reshaped = self.generate_unique_name(f"{name}_reshaped")
638+
reshaped = self._generate_unique_name(f"{name}_reshaped")
639639
self.emit(
640640
[reshaped],
641641
values.Op(self.default_opset, "Reshape"),
@@ -721,16 +721,16 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
721721

722722
if len(starts) > 1:
723723
axis_0_attr = self._make_onnx_attr("axis", 0)
724-
start_name = self.generate_unique_name(f"{var_name}_start")
724+
start_name = self._generate_unique_name(f"{var_name}_start")
725725
self.emit([start_name], "Concat", starts, [axis_0_attr])
726726

727-
end_name = self.generate_unique_name(f"{var_name}_end")
727+
end_name = self._generate_unique_name(f"{var_name}_end")
728728
self.emit([end_name], "Concat", ends, [axis_0_attr])
729729

730-
axes_name = self.generate_unique_name(f"{var_name}_axis")
730+
axes_name = self._generate_unique_name(f"{var_name}_axis")
731731
self.emit([axes_name], "Concat", axes, [axis_0_attr])
732732

733-
steps_name = self.generate_unique_name(f"{var_name}_step")
733+
steps_name = self._generate_unique_name(f"{var_name}_step")
734734
self.emit([steps_name], "Concat", steps, [axis_0_attr])
735735
else:
736736
start_name = starts[0]
@@ -739,7 +739,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
739739
steps_name = steps[0]
740740

741741
if squeezed_axes:
742-
sliced_name = self.generate_unique_name(f"{var_name}_sliced")
742+
sliced_name = self._generate_unique_name(f"{var_name}_sliced")
743743
self.emit(
744744
[sliced_name],
745745
"Slice",
@@ -748,14 +748,14 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
748748
squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info)
749749

750750
if non_scalar_indices: # use temporary to store result of squeeze
751-
result = self.generate_unique_name(f"{var_name}_squeezed")
751+
result = self._generate_unique_name(f"{var_name}_squeezed")
752752
else: # store squeezed result in final target
753753
result = target
754754

755755
self.emit([result], "Squeeze", [sliced_name, squeezed_axes])
756756
else:
757757
if non_scalar_indices: # use temporary to store result of Slice
758-
result = self.generate_unique_name(f"{var_name}_sliced")
758+
result = self._generate_unique_name(f"{var_name}_sliced")
759759
else: # store result of Slice in final target
760760
result = target
761761
slice_inputs = [var_name, start_name, end_name, axes_name, steps_name]
@@ -774,7 +774,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
774774
# use Gather to perform indexing
775775
# Assign gathered value to either temporary or final target
776776
if axis != last_axis: # use temporary to store result of Gather
777-
gathered = self.generate_unique_name(f"{var_name}_axis_{axis}")
777+
gathered = self._generate_unique_name(f"{var_name}_axis_{axis}")
778778
else: # store result of Gather in final target
779779
gathered = target
780780
self.emit([gathered], "Gather", [str(result), index_value], [axis_attr])
@@ -876,7 +876,7 @@ def _translate_compare_expr(self, node):
876876
op = values.Op(self.default_opset, opname if opname != "NotEqual" else "Equal")
877877
left, right = self._cast_like_binary_expression(op, left, right)
878878
if opname == "NotEqual":
879-
tmp = self.generate_unique_name()
879+
tmp = self._generate_unique_name()
880880
self.emit([tmp], op, [left, right])
881881
not_op = values.Op(self.default_opset, "Not")
882882
return not_op, [tmp], []
@@ -979,7 +979,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None:
979979
def generate_onnx_name(x: ast.AST):
980980
if not isinstance(x, ast.Name):
981981
self.fail(x, f"LHS must be a Name for unpacking, found: '{type(x)!r}'")
982-
onnx_name = self.generate_unique_name(x.id)
982+
onnx_name = self._generate_unique_name(x.id)
983983
self._bind(
984984
x.id,
985985
values.Dynamic(
@@ -1078,7 +1078,7 @@ def _translate_if_stmt(self, stmt: ast.If) -> None:
10781078
elseAttr = self._make_onnx_attr("else_branch", elseGraph)
10791079

10801080
def rename(x):
1081-
r = self.generate_unique_name(x)
1081+
r = self._generate_unique_name(x)
10821082
self._bind(
10831083
x,
10841084
values.Dynamic(r, values.DynamicKind.Intermediate, self._source_of(stmt)),
@@ -1122,7 +1122,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11221122
self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.")
11231123
assert not iter.keywords, "Unsupported loop bound."
11241124
o_loop_bound = self._translate_expr(iter.args[0], "loop_bound").name
1125-
o_cond_var = self.generate_unique_name("cond_in")
1125+
o_cond_var = self._generate_unique_name("cond_in")
11261126
i_cond_var = o_cond_var
11271127
cond_while = None
11281128
o_loop_condition = "" # No condition for a for loop.
@@ -1156,7 +1156,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11561156

11571157
# build loop_body
11581158
self._enter_scope("loop_body", loop_stmt)
1159-
o_loop_var = self.generate_unique_name(p_loop_var)
1159+
o_loop_var = self._generate_unique_name(p_loop_var)
11601160
self.ir_builder.add_input(
11611161
self._current_fn,
11621162
o_loop_var,
@@ -1176,7 +1176,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
11761176
)
11771177

11781178
for pv in loop_state_vars:
1179-
ov = self.generate_unique_name(pv)
1179+
ov = self._generate_unique_name(pv)
11801180
# TODO: retrieve the annotation for variable pv is any is specified.
11811181
# typeinfo = self._eval_constant_expr(pv.annotation)
11821182
typeinfo = None
@@ -1217,7 +1217,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12171217
continue
12181218
self._translate_stmt(s)
12191219

1220-
o_cond_out = self.generate_unique_name("cond_out")
1220+
o_cond_out = self._generate_unique_name("cond_out")
12211221

12221222
if cond_while is not None:
12231223
# Loop while
@@ -1267,7 +1267,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12671267
info = self._source_of(loop_stmt)
12681268

12691269
def rename(x):
1270-
r = self.generate_unique_name(x)
1270+
r = self._generate_unique_name(x)
12711271
self._bind(x, values.Dynamic(r, values.DynamicKind.Output, info))
12721272
return r
12731273

0 commit comments

Comments
 (0)