@@ -57,7 +57,7 @@ def get_beach(shape):
57
57
img = img .resize (resize_to , PIL .Image .ANTIALIAS )
58
58
img_np = np .array (img ).astype (np .float32 )
59
59
img_np = np .stack ([img_np ] * shape [0 ], axis = 0 ).reshape (shape )
60
- return img_np
60
+ return img_np / 255
61
61
62
62
63
63
def get_random (shape ):
@@ -84,6 +84,18 @@ def get_zeros(shape):
84
84
"""Get zeros."""
85
85
return np .zeros (shape ).astype (np .float32 )
86
86
87
+ def get_zeros_int32 (shape ):
88
+ """Get zeros."""
89
+ return np .zeros (shape ).astype (np .int32 )
90
+
91
+ def get_zeros_int64 (shape ):
92
+ """Get zeros."""
93
+ return np .zeros (shape ).astype (np .int64 )
94
+
95
+ def get_wav (shape ):
96
+ """Get sound data."""
97
+ return np .sin (np .linspace (- np .pi , np .pi , shape [0 ]), dtype = np .float32 )
98
+
87
99
88
100
_INPUT_FUNC_MAPPING = {
89
101
"get_beach" : get_beach ,
@@ -92,6 +104,9 @@ def get_zeros(shape):
92
104
"get_ramp" : get_ramp ,
93
105
"get_ones" : get_ones ,
94
106
"get_zeros" : get_zeros ,
107
+ "get_wav" : get_wav ,
108
+ "get_zeros_int32" : get_zeros_int32 ,
109
+ "get_zeros_int64" : get_zeros_int64 ,
95
110
}
96
111
97
112
OpsetConstraint = namedtuple ("OpsetConstraint" , "domain, min_version, max_version, excluded_version" )
@@ -137,7 +152,10 @@ def __init__(self, url, local, input_func, input_names, output_names,
137
152
def make_input (self , v ):
138
153
"""Allows each input to specify its own function while defaulting to the input_get function"""
139
154
if isinstance (v , dict ):
140
- return _INPUT_FUNC_MAPPING [v ["input_get" ]](v ["shape" ])
155
+ if "input_get" in v :
156
+ return _INPUT_FUNC_MAPPING [v ["input_get" ]](v ["shape" ])
157
+ if "value" in v :
158
+ return np .array (v ["value" ])
141
159
return self .input_func (v )
142
160
143
161
def download_model (self ):
0 commit comments