Skip to content

Commit f8b1dfe

Browse files
Added new input_get functions to run_pretrained_models.py
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 8eb3331 commit f8b1dfe

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

tests/run_pretrained_models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,21 @@ def get_zeros_int64(shape):
9292
"""Get zeros."""
9393
return np.zeros(shape).astype(np.int64)
9494

95+
def get_ones_int32(shape):
96+
"""Get ones."""
97+
return np.ones(shape).astype(np.int32)
98+
99+
def get_small_rand_int32(shape):
100+
"""Get random ints in range [1, 99]"""
101+
return np.random.randint(low=1, high=100, size=shape, dtype=np.int32)
102+
103+
def get_zeros_then_ones(shape):
104+
"""Fill half the tensor with zeros and the rest with ones"""
105+
cnt = np.prod(shape)
106+
zeros_cnt = cnt // 2
107+
ones_cnt = cnt - zeros_cnt
108+
return np.concatenate((np.zeros(zeros_cnt, dtype=np.int32), np.ones(ones_cnt, dtype=np.int32))).reshape(shape)
109+
95110
def get_wav(shape):
96111
"""Get sound data."""
97112
return np.sin(np.linspace(-np.pi, np.pi, shape[0]), dtype=np.float32)
@@ -107,8 +122,12 @@ def get_wav(shape):
107122
"get_wav": get_wav,
108123
"get_zeros_int32": get_zeros_int32,
109124
"get_zeros_int64": get_zeros_int64,
125+
"get_ones_int32": get_ones_int32,
126+
"get_small_rand_int32": get_small_rand_int32,
127+
"get_zeros_then_ones": get_zeros_then_ones
110128
}
111129

130+
112131
OpsetConstraint = namedtuple("OpsetConstraint", "domain, min_version, max_version, excluded_version")
113132

114133

0 commit comments

Comments
 (0)