Skip to content

Commit 31269f0

Browse files
committed
Cleanup some value creation
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent f5db9d8 commit 31269f0

File tree

1 file changed

+30
-32
lines changed

1 file changed

+30
-32
lines changed

onnxscript/_internal/converter.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12021202
if isinstance(loop_stmt, ast.For):
12031203
if not isinstance(loop_stmt.target, ast.Name):
12041204
self.fail(loop_stmt, "For loop target must be a single variable.")
1205-
p_loop_var = loop_stmt.target.id
1205+
python_loop_var_name = loop_stmt.target.id
12061206
# iter
12071207
iter = loop_stmt.iter
12081208
assert isinstance(iter, ast.Call), "Loop bound not a call."
@@ -1216,8 +1216,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12161216
self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.")
12171217
assert not iter.keywords, "Unsupported loop bound."
12181218
o_loop_bound = self._translate_expr(iter.args[0], "loop_bound")
1219-
o_cond_var = ir.Value(name=self.generate_unique_name("cond_in")) # TODO(Rama)
1220-
i_cond_var = o_cond_var
1219+
onnx_cond_var = ir.Value(name=self.generate_unique_name("cond_in")) # TODO(Rama)
1220+
i_cond_var = onnx_cond_var
12211221
cond_while = None
12221222
o_loop_condition = None # No condition for a for loop.
12231223
elif isinstance(loop_stmt, ast.While):
@@ -1228,11 +1228,11 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12281228
"Unexpected condition type {type(loop_stmt)!r} for a while loop, "
12291229
"it should be 'while <condition_name>:'.",
12301230
)
1231-
p_loop_var = "infinite_loop"
1231+
python_loop_var_name = "infinite_loop"
12321232
o_loop_bound = None
12331233
i_cond_var = ir.Value(name=test.id) # TODO(Rama)
12341234
cond_while = ir.Value(name=test.id) # TODO(Rama)
1235-
o_cond_var = None
1235+
onnx_cond_var = None
12361236
o_loop_condition = self._translate_name_expr(test)
12371237
# we need to go through all the instructions to see
12381238
# which instruction defines the condition test.id
@@ -1252,19 +1252,16 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12521252

12531253
# build loop_body
12541254
self._enter_scope("loop_body", loop_stmt)
1255-
o_loop_var = self.generate_unique_name(p_loop_var)
1256-
self._current_fn.append_parameter(
1257-
make_value(
1258-
o_loop_var,
1259-
onnx_types.INT64,
1260-
self._source_of(loop_stmt),
1261-
)
1255+
onnx_loop_var_name = self.generate_unique_name(python_loop_var_name)
1256+
onnx_loop_var = make_value(
1257+
onnx_loop_var_name,
1258+
onnx_types.INT64,
1259+
self._source_of(loop_stmt),
12621260
)
1261+
self._current_fn.append_parameter(onnx_loop_var)
12631262
self._bind(
1264-
p_loop_var,
1265-
values.Dynamic(
1266-
ir.Value(name=o_loop_var), values.DynamicKind.Loop, self._source_of(loop_stmt)
1267-
),
1263+
python_loop_var_name,
1264+
values.Dynamic(onnx_loop_var, values.DynamicKind.Loop, self._source_of(loop_stmt)),
12681265
)
12691266

12701267
self._current_fn.append_parameter(
@@ -1276,17 +1273,19 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
12761273
)
12771274

12781275
for pv in loop_state_vars:
1279-
ov = self.generate_unique_name(pv)
1276+
onnx_var_name = self.generate_unique_name(pv)
12801277
# TODO: retrieve the annotation for variable pv is any is specified.
12811278
# typeinfo = self._eval_constant_expr(pv.annotation)
12821279
typeinfo = None
12831280
self._current_fn.append_parameter(
1284-
make_value(ov, typeinfo, self._source_of(loop_stmt))
1281+
make_value(onnx_var_name, typeinfo, self._source_of(loop_stmt))
12851282
)
12861283
self._bind(
12871284
pv,
12881285
values.Dynamic(
1289-
ir.Value(name=ov), values.DynamicKind.Loop, self._source_of(loop_stmt)
1286+
ir.Value(name=onnx_var_name),
1287+
values.DynamicKind.Loop,
1288+
self._source_of(loop_stmt),
12901289
),
12911290
)
12921291

@@ -1319,7 +1318,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13191318
continue
13201319
self._translate_stmt(s)
13211320

1322-
o_cond_out = self.generate_unique_name("cond_out")
1321+
onnx_cond_out_name = self.generate_unique_name("cond_out")
13231322

13241323
if cond_while is not None:
13251324
# Loop while
@@ -1330,35 +1329,35 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
13301329
f"Unable to find condition variable {cond_while.name} in known "
13311330
f"variables {list(current_scope)!r}.",
13321331
)
1333-
o_cond_var = current_scope[cond_while.name].value
1332+
onnx_cond_var = current_scope[cond_while.name].value
13341333

13351334
self.emit(
1336-
[o_cond_out],
1335+
[onnx_cond_out_name],
13371336
values.Op(self.default_opset, operator_name),
1338-
[condition_name or o_cond_var],
1337+
[condition_name or onnx_cond_var],
13391338
[],
13401339
)
13411340

13421341
self._current_fn.outputs.append(
13431342
make_value(
1344-
o_cond_out,
1343+
onnx_cond_out_name,
13451344
onnx_types.BOOL,
13461345
self._source_of(loop_stmt),
13471346
)
13481347
)
13491348
for pv in loop_state_vars:
1350-
ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt))
1351-
if ov.name not in self._current_fn.assigned_names:
1349+
onnx_var = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt))
1350+
if onnx_var.name not in self._current_fn.assigned_names:
13521351
# When converting the loop-body into a graph, we need to handle
13531352
# identity assignments of the form "x = y" inside the loop body
13541353
# specially if y represents a value computed outside the loop body.
13551354
# In this case, we create a copy of y, treating the statement as
13561355
# shorthand for "x = op.Identity(y)".
1357-
ov = self._emit_copy(ov, pv)
1356+
onnx_var = self._emit_copy(onnx_var, pv)
13581357
# TODO: retrieve variable type for the annotation if any.
13591358
typeinfo = None
13601359
self._current_fn.outputs.append(
1361-
make_value(ov.name, typeinfo, self._source_of(loop_stmt))
1360+
make_value(onnx_var.name, typeinfo, self._source_of(loop_stmt))
13621361
)
13631362
body = self._exit_scope()
13641363
inputs = [o_loop_bound, o_loop_condition] + [
@@ -1471,14 +1470,13 @@ def _translate_function_signature_common(
14711470
self._current_fn.append_parameter(attr)
14721471
self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x)))
14731472
else:
1474-
self._current_fn.append_parameter(
1475-
make_value(x.arg, typeinfo, self._source_of(x))
1476-
)
1473+
onnx_parameter = make_value(x.arg, typeinfo, self._source_of(x))
1474+
self._current_fn.append_parameter(onnx_parameter)
14771475
self._used_vars.add(x.arg)
14781476
self._bind(
14791477
x.arg,
14801478
values.Dynamic(
1481-
ir.Value(name=x.arg), values.DynamicKind.Input, self._source_of(x)
1479+
onnx_parameter, values.DynamicKind.Input, self._source_of(x)
14821480
),
14831481
)
14841482
if fn.returns:

0 commit comments

Comments
 (0)