Skip to content

Commit 1ed0964

Browse files
authored
Merge pull request #47 from onnx/gs/onnx-1.2
update test results
2 parents 315971d + 53ae41a commit 1ed0964

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

tests/run_pretrained_models.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class Test(object):
7979
cache_dir = None
8080

8181
def __init__(self, url, local, make_input, input_names, output_names,
82-
disabled=False, more_inputs=None, rtol=0.01, atol=0.):
82+
disabled=False, more_inputs=None, rtol=0.01, atol=0., check_only_shape=False):
8383
self.url = url
8484
self.make_input = make_input
8585
self.local = local
@@ -89,6 +89,7 @@ def __init__(self, url, local, make_input, input_names, output_names,
8989
self.more_inputs = more_inputs
9090
self.rtol = rtol
9191
self.atol = atol
92+
self.check_only_shape = check_only_shape
9293

9394
def download_file(self):
9495
"""Download file from url."""
@@ -239,7 +240,12 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
239240
print("\trun_onnx OK")
240241

241242
try:
242-
np.testing.assert_allclose(tf_results, onnx_results, rtol=self.rtol, atol=self.atol)
243+
if self.check_only_shape:
244+
for i in range(len(tf_results)):
245+
np.testing.assert_array_equal(tf_results[i].shape, onnx_results[i].shape)
246+
else:
247+
for i in range(len(tf_results)):
248+
np.testing.assert_allclose(tf_results[i], onnx_results[i], rtol=self.rtol, atol=self.atol)
243249
print("\tResults: OK")
244250
return True
245251
except Exception as ex:
@@ -276,7 +282,7 @@ def tests_from_yaml(fname):
276282
input_func = v.get("input_get")
277283
input_func = _INPUT_FUNC_MAPPING[input_func]
278284
kwargs = {}
279-
for kw in ["rtol", "atol", "disabled", "more_inputs"]:
285+
for kw in ["rtol", "atol", "disabled", "more_inputs", "check_only_shape"]:
280286
if v.get(kw) is not None:
281287
kwargs[kw] = v[kw]
282288

tests/unity.yaml

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
3DBall:
2-
# needs RandomNormal
2+
# caffe2: needs RandomNormal
3+
# onnxmsrtnext: int->float cast fails
34
disabled: true
45
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/3DBall/TFModels/3DBall.bytes
56
model: 3DBall.bytes
@@ -32,18 +33,16 @@ BananaIL:
3233
- action:0
3334

3435
BananaRL:
35-
# needs RandomNormal
36-
disabled: true
3736
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/BananaCollectors/TFModels/BananaRL.bytes
3837
model: BananaRL.bytes
38+
check_only_shape: true
3939
input_get: get_random
4040
inputs:
4141
"vector_observation:0": [1, 159]
4242
outputs:
4343
- action_probs:0
4444
- value_estimate:0
4545

46-
4746
Basic:
4847
# needs: onehot
4948
disabled: true
@@ -58,7 +57,8 @@ Basic:
5857
- value_estimate:0
5958

6059
Bouncer:
61-
# needs RandomNormal
60+
# caffe2: needs RandomNormal
61+
# onnxmsrtnext: int->float cast fails
6262
disabled: true
6363
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Bouncer/TFModels/Bouncer.bytes
6464
model: Bouncer.bytes
@@ -70,7 +70,8 @@ Bouncer:
7070
- value_estimate:0
7171

7272
crawler:
73-
# needs RandomNormal
73+
# caffe2: needs RandomNormal
74+
# onnxmsrtnext: int->float cast fails
7475
disabled: true
7576
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Crawler/TFModels/crawler.bytes
7677
model: crawler.bytes
@@ -82,10 +83,9 @@ crawler:
8283
- value_estimate:0
8384

8485
GridWorld_3x3:
85-
# needs Multinomial
86-
disabled: true
8786
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_3x3.bytes
8887
model: GridWorld_3x3.bytes
88+
check_only_shape: true
8989
input_get: get_random
9090
inputs:
9191
"visual_observation_0:0": [1, 84, 84, 3]
@@ -95,11 +95,10 @@ GridWorld_3x3:
9595
- action:0
9696

9797
GridWorld_5x5:
98-
# needs Multinomial
99-
disabled: true
10098
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorld_5x5.bytes
10199
model: GridWorld_5x5.bytes
102100
input_get: get_random
101+
check_only_shape: true
103102
inputs:
104103
"visual_observation_0:0": [1, 84, 84, 3]
105104
outputs:
@@ -115,14 +114,15 @@ Hallway:
115114
input_get: get_random
116115
inputs:
117116
"vector_observation:0": [1, 36]
117+
"recurrent_in:0": [1, 256]
118+
"sequence_length:0": [256]
119+
"prev_action:0": [1]
118120
outputs:
119121
- action_probs:0
120122
- value_estimate:0
121123
- action:0
122124

123125
PushBlock:
124-
# needs: Multinomial
125-
disabled: true
126126
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/PushBlock/TFModels/PushBlock.bytes
127127
model: PushBlock.bytes
128128
input_get: get_random
@@ -134,7 +134,8 @@ PushBlock:
134134
- action:0
135135

136136
Reacher:
137-
# needs RandomNormal
137+
# caffe2: needs RandomNormal
138+
# onnxmsrtnext: int->float cast fails
138139
disabled: true
139140
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Reacher/TFModels/Reacher.bytes
140141
model: Reacher.bytes
@@ -146,10 +147,9 @@ Reacher:
146147
- action:0
147148

148149
Soccer:
149-
# needs Multinomial
150-
disabled: true
151150
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Soccer/TFModels/Soccer.bytes
152151
model: Soccer.bytes
152+
check_only_shape: true
153153
input_get: get_random
154154
inputs:
155155
"GoalieBrain/vector_observation:0": [1, 336]
@@ -163,8 +163,6 @@ Soccer:
163163
- StrikerBrain/value_estimate:0
164164

165165
Tennis:
166-
# needs RandomNormal
167-
disabled: true
168166
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Tennis/TFModels/Tennis.bytes
169167
model: Tennis.bytes
170168
input_get: get_random
@@ -174,10 +172,9 @@ Tennis:
174172
- value_estimate:0
175173

176174
WallJump:
177-
# needs Multinomial
178-
disabled: true
179175
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/WallJump/TFModels/WallJump.bytes
180176
model: WallJump.bytes
177+
check_only_shape: true
181178
input_get: get_random
182179
inputs:
183180
"SmallWallBrain/vector_observation:0": [1, 444]

0 commit comments

Comments
 (0)