Skip to content

Commit 0cd1f20

Browse files
committed
_ValueEnvironment
Signed-off-by: Justin Chu <[email protected]>
1 parent 11a735e commit 0cd1f20

File tree

1 file changed

+44
-33
lines changed

1 file changed

+44
-33
lines changed

onnxscript/_converter.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Sequence,
2020
Tuple,
2121
Union,
22+
_GenericAlias
2223
)
2324

2425
import onnx_ir as ir
@@ -222,6 +223,49 @@ class ASTMeta:
222223
live_in: set[str] | None = None
223224

224225

226+
class _ValueEnvironment:
227+
def __init__(self, converter: Converter):
228+
self._sym_value_to_onnx_values: dict[SymbolValue, ir.Value] = {}
229+
self._converter = converter
230+
231+
def get_or_create_value(
232+
self, val: SymbolValue, info: sourceinfo.SourceInfo
233+
) -> ir.Value:
234+
"""Get or create an ONNX Value for a SymbolValue."""
235+
if val in self._sym_value_to_onnx_values:
236+
return self._sym_value_to_onnx_values[val]
237+
if isinstance(val, AttrRef):
238+
# promote attribute to value
239+
result_name = self._converter._generate_unique_name("v")
240+
attr = _to_onnx_ref_attr(val, info)
241+
result = self._converter.emit([result_name], "Constant", [], attrs=[attr])[0]
242+
if ta.base_type_is_bool(val.typeinfo):
243+
# ONNX attributes use an int-encoding for bools, but ONNX tensor types
244+
# distinguish between int and bool. So we cast the int tensor to a bool tensor,
245+
# to promote a (python) bool attribute to a ONNX bool tensor.
246+
result_as_bool_name = self._converter._generate_unique_name(f"{result_name}_as_bool")
247+
result = self._converter.emit(
248+
[result_as_bool_name],
249+
"Cast",
250+
[result_name],
251+
attrs=[ir.AttrInt64("to", ir.DataType.BOOL)],
252+
)[0]
253+
254+
self._sym_value_to_onnx_values[val] = result
255+
return result
256+
257+
if isinstance(val, Dynamic):
258+
# A value in ONNX
259+
result = ir.Value(name=val.value)
260+
self._sym_value_to_onnx_values[val] = result
261+
return result
262+
263+
# Assume value is a python-value convertible to a tensor
264+
result = self._converter._emit_const(val, None, info)
265+
self._sym_value_to_onnx_values[val] = result
266+
return result
267+
268+
225269
class Converter:
226270
"""Main class to translate python code into ONNX operators.
227271
@@ -392,39 +436,6 @@ def _generate_unique_name(self, candidate: str = "tmp") -> str:
392436
self._used_vars.add(r)
393437
return r
394438

395-
def _to_onnx_var(
396-
self,
397-
val: values.SymbolValue | PyValue,
398-
target: PreferredName = "tmp",
399-
*,
400-
info: sourceinfo.SourceInfo,
401-
) -> ir.Value:
402-
"""Convert a Python or symbolic value to an ONNX Value."""
403-
if isinstance(val, values.AttrRef):
404-
# promote attribute to value
405-
result = self._generate_unique_name(target)
406-
attr = _to_onnx_ref_attr(val, info)
407-
result_val = self.emit([result], "Constant", [], attrs=[attr])[0]
408-
if ta.base_type_is_bool(val.typeinfo):
409-
# ONNX attributes use an int-encoding for bools, but ONNX tensor types
410-
# distinguish between int and bool. So we cast the int tensor to a bool tensor,
411-
# to promote a (python) bool attribute to a ONNX bool tensor.
412-
result_as_bool = self._generate_unique_name(result + "_as_bool")
413-
return self.emit(
414-
[result_as_bool],
415-
"Cast",
416-
[result],
417-
attrs=[ir.AttrInt64("to", ir.DataType.BOOL)],
418-
)[0]
419-
return result_val
420-
421-
if isinstance(val, values.Dynamic):
422-
# A value in ONNX
423-
return ir.Value(name=val.value)
424-
425-
# Assume value is a python-value convertible to a tensor
426-
return self._emit_const(val, target, info)
427-
428439
def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Variable:
429440
"""Convert a python variable to an ONNX variable."""
430441
return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info)

0 commit comments

Comments
 (0)