Skip to content

Commit 74e3855

Browse files
Merge pull request #1113 from onnx/tom/ImproveRunPretrainedGetInputs
Rescaled get_beach and added get_zeros_int32 and get_zeros_int64
2 parents b4ac342 + 8b72bfb commit 74e3855

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

tests/run_pretrained_models.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_beach(shape):
5757
img = img.resize(resize_to, PIL.Image.ANTIALIAS)
5858
img_np = np.array(img).astype(np.float32)
5959
img_np = np.stack([img_np] * shape[0], axis=0).reshape(shape)
60-
return img_np
60+
return img_np / 255
6161

6262

6363
def get_random(shape):
@@ -84,6 +84,18 @@ def get_zeros(shape):
8484
"""Get zeros."""
8585
return np.zeros(shape).astype(np.float32)
8686

87+
def get_zeros_int32(shape):
88+
"""Get zeros."""
89+
return np.zeros(shape).astype(np.int32)
90+
91+
def get_zeros_int64(shape):
92+
"""Get zeros."""
93+
return np.zeros(shape).astype(np.int64)
94+
95+
def get_wav(shape):
96+
"""Get sound data."""
97+
return np.sin(np.linspace(-np.pi, np.pi, shape[0]), dtype=np.float32)
98+
8799

88100
_INPUT_FUNC_MAPPING = {
89101
"get_beach": get_beach,
@@ -92,6 +104,9 @@ def get_zeros(shape):
92104
"get_ramp": get_ramp,
93105
"get_ones": get_ones,
94106
"get_zeros": get_zeros,
107+
"get_wav": get_wav,
108+
"get_zeros_int32": get_zeros_int32,
109+
"get_zeros_int64": get_zeros_int64,
95110
}
96111

97112
OpsetConstraint = namedtuple("OpsetConstraint", "domain, min_version, max_version, excluded_version")
@@ -137,7 +152,10 @@ def __init__(self, url, local, input_func, input_names, output_names,
137152
def make_input(self, v):
138153
"""Allows each input to specify its own function while defaulting to the input_get function"""
139154
if isinstance(v, dict):
140-
return _INPUT_FUNC_MAPPING[v["input_get"]](v["shape"])
155+
if "input_get" in v:
156+
return _INPUT_FUNC_MAPPING[v["input_get"]](v["shape"])
157+
if "value" in v:
158+
return np.array(v["value"])
141159
return self.input_func(v)
142160

143161
def download_model(self):

0 commit comments

Comments
 (0)