|
19 | 19 | Sequence, |
20 | 20 | Tuple, |
21 | 21 | Union, |
| 22 | + _GenericAlias |
22 | 23 | ) |
23 | 24 |
|
24 | 25 | import onnx_ir as ir |
@@ -222,6 +223,49 @@ class ASTMeta: |
222 | 223 | live_in: set[str] | None = None |
223 | 224 |
|
224 | 225 |
|
| 226 | +class _ValueEnvironment: |
| 227 | + def __init__(self, converter: Converter): |
| 228 | + self._sym_value_to_onnx_values: dict[SymbolValue, ir.Value] = {} |
| 229 | + self._converter = converter |
| 230 | + |
| 231 | + def get_or_create_value( |
| 232 | + self, val: SymbolValue, info: sourceinfo.SourceInfo |
| 233 | + ) -> ir.Value: |
| 234 | + """Get or create an ONNX Value for a SymbolValue.""" |
| 235 | + if val in self._sym_value_to_onnx_values: |
| 236 | + return self._sym_value_to_onnx_values[val] |
| 237 | + if isinstance(val, AttrRef): |
| 238 | + # promote attribute to value |
| 239 | + result_name = self._converter._generate_unique_name("v") |
| 240 | + attr = _to_onnx_ref_attr(val, info) |
| 241 | + result = self._converter.emit([result_name], "Constant", [], attrs=[attr])[0] |
| 242 | + if ta.base_type_is_bool(val.typeinfo): |
| 243 | + # ONNX attributes use an int-encoding for bools, but ONNX tensor types |
| 244 | + # distinguish between int and bool. So we cast the int tensor to a bool tensor, |
| 245 | + # to promote a (python) bool attribute to a ONNX bool tensor. |
| 246 | + result_as_bool_name = self._converter._generate_unique_name(f"{result_name}_as_bool") |
| 247 | + result = self._converter.emit( |
| 248 | + [result_as_bool_name], |
| 249 | + "Cast", |
| 250 | + [result_name], |
| 251 | + attrs=[ir.AttrInt64("to", ir.DataType.BOOL)], |
| 252 | + )[0] |
| 253 | + |
| 254 | + self._sym_value_to_onnx_values[val] = result |
| 255 | + return result |
| 256 | + |
| 257 | + if isinstance(val, Dynamic): |
| 258 | + # A value in ONNX |
| 259 | + result = ir.Value(name=val.value) |
| 260 | + self._sym_value_to_onnx_values[val] = result |
| 261 | + return result |
| 262 | + |
| 263 | + # Assume value is a python-value convertible to a tensor |
| 264 | + result = self._converter._emit_const(val, None, info) |
| 265 | + self._sym_value_to_onnx_values[val] = result |
| 266 | + return result |
| 267 | + |
| 268 | + |
225 | 269 | class Converter: |
226 | 270 | """Main class to translate python code into ONNX operators. |
227 | 271 |
|
@@ -392,39 +436,6 @@ def _generate_unique_name(self, candidate: str = "tmp") -> str: |
392 | 436 | self._used_vars.add(r) |
393 | 437 | return r |
394 | 438 |
|
395 | | - def _to_onnx_var( |
396 | | - self, |
397 | | - val: values.SymbolValue | PyValue, |
398 | | - target: PreferredName = "tmp", |
399 | | - *, |
400 | | - info: sourceinfo.SourceInfo, |
401 | | - ) -> ir.Value: |
402 | | - """Convert a Python or symbolic value to an ONNX Value.""" |
403 | | - if isinstance(val, values.AttrRef): |
404 | | - # promote attribute to value |
405 | | - result = self._generate_unique_name(target) |
406 | | - attr = _to_onnx_ref_attr(val, info) |
407 | | - result_val = self.emit([result], "Constant", [], attrs=[attr])[0] |
408 | | - if ta.base_type_is_bool(val.typeinfo): |
409 | | - # ONNX attributes use an int-encoding for bools, but ONNX tensor types |
410 | | - # distinguish between int and bool. So we cast the int tensor to a bool tensor, |
411 | | - # to promote a (python) bool attribute to a ONNX bool tensor. |
412 | | - result_as_bool = self._generate_unique_name(result + "_as_bool") |
413 | | - return self.emit( |
414 | | - [result_as_bool], |
415 | | - "Cast", |
416 | | - [result], |
417 | | - attrs=[ir.AttrInt64("to", ir.DataType.BOOL)], |
418 | | - )[0] |
419 | | - return result_val |
420 | | - |
421 | | - if isinstance(val, values.Dynamic): |
422 | | - # A value in ONNX |
423 | | - return ir.Value(name=val.value) |
424 | | - |
425 | | - # Assume value is a python-value convertible to a tensor |
426 | | - return self._emit_const(val, target, info) |
427 | | - |
428 | 439 | def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Variable: |
429 | 440 | """Convert a python variable to an ONNX variable.""" |
430 | 441 | return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info) |
|
0 commit comments