@@ -92,6 +92,21 @@ def get_zeros_int64(shape):
92
92
"""Get zeros."""
93
93
return np .zeros (shape ).astype (np .int64 )
94
94
95
+ def get_ones_int32 (shape ):
96
+ """Get ones."""
97
+ return np .ones (shape ).astype (np .int32 )
98
+
99
+ def get_small_rand_int32 (shape ):
100
+ """Get random ints in range [1, 99]"""
101
+ return np .random .randint (low = 1 , high = 100 , size = shape , dtype = np .int32 )
102
+
103
+ def get_zeros_then_ones (shape ):
104
+ """Fill half the tensor with zeros and the rest with ones"""
105
+ cnt = np .prod (shape )
106
+ zeros_cnt = cnt // 2
107
+ ones_cnt = cnt - zeros_cnt
108
+ return np .concatenate ((np .zeros (zeros_cnt , dtype = np .int32 ), np .ones (ones_cnt , dtype = np .int32 ))).reshape (shape )
109
+
95
110
def get_wav (shape ):
96
111
"""Get sound data."""
97
112
return np .sin (np .linspace (- np .pi , np .pi , shape [0 ]), dtype = np .float32 )
@@ -107,8 +122,12 @@ def get_wav(shape):
107
122
"get_wav" : get_wav ,
108
123
"get_zeros_int32" : get_zeros_int32 ,
109
124
"get_zeros_int64" : get_zeros_int64 ,
125
+ "get_ones_int32" : get_ones_int32 ,
126
+ "get_small_rand_int32" : get_small_rand_int32 ,
127
+ "get_zeros_then_ones" : get_zeros_then_ones
110
128
}
111
129
130
+
112
131
OpsetConstraint = namedtuple ("OpsetConstraint" , "domain, min_version, max_version, excluded_version" )
113
132
114
133
0 commit comments