Skip to content

Commit 8de75c9

Browse files
committed
Fix logic for converting np array to text
In onnx2script, nan, inf etc. were converted to plain text, which causes evaluation to fail because they don't exist in the script. I updated the logic to replace them with np. values. Signed-off-by: Justin Chu <[email protected]>
1 parent 75c1a4d commit 8de75c9

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

onnxscript/backend/onnx_export.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import Any, Optional, Sequence
66

7-
import numpy
7+
import numpy as np
88
import onnx
99
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto
1010

@@ -384,17 +384,17 @@ def _translate_attributes(self, node):
384384
if isinstance(value, str):
385385
attributes.append((at.name, f"{value!r}"))
386386
continue
387-
if isinstance(value, numpy.ndarray):
387+
if isinstance(value, np.ndarray):
388388
onnx_dtype = at.t.data_type
389389
if len(value.shape) == 0:
390390
text = (
391391
f'make_tensor("value", {onnx_dtype}, dims=[], '
392-
f"vals=[{value.tolist()!r}])"
392+
f"vals=[{repr(value.tolist()).replace('nan', 'np.nan').replace('inf', 'np.inf')}])"
393393
)
394394
else:
395395
text = (
396396
f'make_tensor("value", {onnx_dtype}, dims={list(value.shape)!r}, '
397-
f"vals={value.ravel().tolist()!r})"
397+
f"vals={repr(value.ravel().tolist()).replace('nan', 'np.nan').replace('inf', 'np.inf')})"
398398
)
399399
attributes.append((at.name, text))
400400
continue
@@ -738,7 +738,7 @@ def generate_rand(name: str, value: TensorProto) -> str:
738738
raise NotImplementedError(
739739
f"Unable to generate random initializer for data type {value.data_type}."
740740
)
741-
return f"{__}{name} = numpy.random.rand({shape}).astype(numpy.float32)"
741+
return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
742742

743743
random_initializer_values = "\n".join(
744744
generate_rand(key, value) for key, value in self.skipped_initializers.items()
@@ -793,7 +793,7 @@ def add(line: str) -> None:
793793
result.append(line)
794794

795795
# Generic imports.
796-
add("import numpy")
796+
add("import numpy as np")
797797
add("from onnx import TensorProto")
798798
add("from onnx.helper import make_tensor")
799799
add("from onnxscript import script, external_tensor")
@@ -873,11 +873,11 @@ def export2python(
873873
.. runpython::
874874
:showcode:
875875
:process:
876-
import numpy
876+
import numpy as np
877877
from sklearn.cluster import KMeans
878878
from mlprodict.onnx_conv import to_onnx
879879
from mlprodict.onnx_tools.onnx_export import export2python
880-
X = numpy.arange(20).reshape(10, 2).astype(numpy.float32)
880+
X = np.arange(20).reshape(10, 2).astype(np.float32)
881881
tr = KMeans(n_clusters=2)
882882
tr.fit(X)
883883
onx = to_onnx(tr, X, target_opset=14)

0 commit comments

Comments
 (0)